diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md new file mode 100644 index 00000000..4590aecc --- /dev/null +++ b/CONTRIBUTORS.md @@ -0,0 +1,40 @@ +# Contributors + +CRF-RNN specific code in this repository was written by: + ++ Sadeep Jayasumana (sadeep.jay@gmail.com)
++ Bernardino Romera-Paredes (bernardino.romeraparedes@eng.ox.ac.uk)
++ Shuai Zheng (kylezheng04@gmail.com)
++ Zhizhong Su (suzhizhong@baidu.com) + +Our code uses the [Permutohedral lattice library](http://graphics.stanford.edu/papers/permutohedral/) and the [Caffe future version](https://github.com/longjon/caffe/tree/future). +We also used parts of the [Dense CRF code](http://www.philkr.net/home/densecrf) while implementing this. + +Permutohedral lattice library (BSD license) is from Andrew Adams, Jongmin Baek, Abe Davis. Fast High-Dimensional Filtering Using the +Permutohedral Lattice. Eurographics 2010. +DenseCRF library is from Philipp Krahenbuhl and Vladlen Koltun. Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials. +NIPS 2011. + +Our software is built on top of the Caffe software library. Below is a copy of its CONTRIBUTORS.md file. +Please note that, due to technical difficulties, we could not keep the history of the original contributors to Caffe code in our git repository. +Please refer to the original [Caffe git repository](https://github.com/BVLC/caffe) for this purpose. + +----------------------------------------------------------------------------------------------------------------- + +Caffe is developed by a core set of BVLC members and the open-source community. + +We thank all of our [contributors](https://github.com/BVLC/caffe/graphs/contributors)! + +**For the detailed history of contributions** of a given file, try + + git blame file + +to see line-by-line credits and + + git log --follow file + +to see the change log even across renames and rewrites. + +Please refer to the [acknowledgements](http://caffe.berkeleyvision.org/#acknowledgements) on the Caffe site for further details. + +**Copyright** is held by the original contributor according to the versioning history; see LICENSE.
diff --git a/README.md b/README.md index f05f1b2f..5f401c69 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,121 @@ -# crfasrnn -This repository contains the source code for the semantic image segmentation method described in the ICCV 2015 paper: Conditional Random Fields as Recurrent Neural Networks. http://crfasrnn.torr.vision/ +# CRF-RNN for Semantic Image Segmentation +![sample](sample.png) + +Live demo: [http://crfasrnn.torr.vision](http://crfasrnn.torr.vision) + +This package contains code for the "CRF-RNN" semantic image segmentation method, published in the ICCV 2015 paper [Conditional Random Fields as Recurrent Neural Networks](http://www.robots.ox.ac.uk/~szheng/papers/CRFasRNN.pdf). Our software is built on top of the [Caffe](http://caffe.berkeleyvision.org/) deep learning library. The current version was developed by: + +[Sadeep Jayasumana](http://www.robots.ox.ac.uk/~sadeep/), +[Shuai Zheng](http://kylezheng.org/), +[Bernardino Romera Paredes](http://romera-paredes.com/), and +[Zhizhong Su](suzhizhong@baidu.com). + +Supervisor: [Philip Torr](http://www.robots.ox.ac.uk/~tvg/) + +Our work allows computers to recognize objects in images. What is distinctive about our work is that we also recover the 2D outline of the object. + +Currently we have trained this model to recognize 20 classes. This software allows you to test our algorithm on your own images. Give it a try and see if you can fool it; if you get some good examples, you can send them to us. + +Why are we doing this? This work is part of a project to build augmented reality glasses for the partially sighted. Please read about it here: [smart-specs](http://www.va-st.com/smart-specs/). + +For a demo and more information about CRF-RNN, please visit the [project website](http://crfasrnn.torr.vision).
+ +If you use this code/model for your research, please consider citing the following paper: +``` +@inproceedings{crfasrnn_ICCV2015, + author = {Shuai Zheng and Sadeep Jayasumana and Bernardino Romera-Paredes and Vibhav Vineet and Zhizhong Su and Dalong Du and Chang Huang and Philip H. S. Torr}, + title = {Conditional Random Fields as Recurrent Neural Networks}, + booktitle = {International Conference on Computer Vision (ICCV)}, + year = {2015} +} +``` + + +#Installation Guide + +You need to compile the modified Caffe library in this repository. Instructions for Ubuntu 14.04 are included below. You can also consult the generic [Caffe installation guide](http://caffe.berkeleyvision.org/installation.html). + + +###1.1 Install dependencies +#####General dependencies +``` +sudo apt-get install libprotobuf-dev libleveldb-dev libsnappy-dev libopencv-dev libhdf5-serial-dev protobuf-compiler +sudo apt-get install --no-install-recommends libboost-all-dev +``` + +#####CUDA +Install the correct CUDA driver and its SDK. Download the CUDA SDK from the Nvidia website. + +In Ubuntu 14.04, you need to make sure the required tools are installed. You might need to blacklist the modules listed below so that they do not interfere with the driver installation. You also need to uninstall your default Nvidia driver first. +``` +sudo apt-get install freeglut3-dev build-essential libx11-dev libxmu-dev libxi-dev libgl1-mesa-glx libglu1-mesa libglu1-mesa-dev +``` +Open /etc/modprobe.d/blacklist.conf and add: +``` +blacklist amd76x_edac +blacklist vga16fb +blacklist nouveau +blacklist rivafb +blacklist nvidiafb +blacklist rivatv +``` +``` +sudo apt-get remove --purge nvidia* +``` + +When you restart your PC, before logging in, press "Ctrl+Alt+F1" to switch to a text-based login. Then try: +``` +sudo service lightdm stop +chmod +x cuda*.run +sudo ./cuda*.run +``` + +#####BLAS +Install ATLAS, OpenBLAS, or MKL.
+ +#####Python +Install the Anaconda Python distribution, or install the default Python distribution with numpy/scipy/... + +#####MATLAB (optional) +Install MATLAB using a standard distribution. + +###1.2 Build the custom Caffe version +Set the paths correctly in Makefile.config. You can copy Makefile.config.example to Makefile.config, as most common parts are filled in already. You need to change it according to your environment. + +After this, in Ubuntu 14.04, try: +``` +make +``` + +If there are no error messages, you can then compile and install the python and matlab wrappers: +``` +make matcaffe +``` + +``` +make pycaffe +``` + +That's it! Enjoy our software! + + +###1.3 Run the demo +Matlab and Python scripts for running the demo are available in the matlab-scripts and python-scripts directories, respectively. You can choose either of them. Note that you should change the paths in the scripts according to your environment. + +# LICENSE +CRF-RNN feature in Caffe is implemented for the paper: +Shuai Zheng, Sadeep Jayasumana, Bernardino Romera-Paredes, Vibhav Vineet, Zhizhong Su, Dalong Du, Chang Huang, Philip H. S. Torr. +Conditional Random Fields as Recurrent Neural Networks. IEEE ICCV 2015. + +Shuai Zheng, Sadeep Jayasumana, Bernardino Romera-Paredes, and Philip H. S. Torr are with University of Oxford. +Vibhav Vineet did this work when he was with the University of Oxford; he is now with Stanford University. +Zhizhong Su, Dalong Du, Chang Huang are with the Baidu Institute of Deep Learning (IDL). + +CRF-RNN uses the Permutohedral lattice library, the DenseCRF library and the Caffe future version. + +Permutohedral lattice library (BSD license) is from Andrew Adams, Jongmin Baek, Abe Davis. Fast High-Dimensional Filtering Using the +Permutohedral Lattice. Eurographics 2010. +DenseCRF library is from Philipp Krahenbuhl and Vladlen Koltun. Efficient Inference in Fully Connected CRFs with Gaussian Edge Potentials. +NIPS 2011.
+ +For more information about CRF-RNN please visit the project website http://crfasrnn.torr.vision. diff --git a/caffe-crfrnn/CMakeLists.txt b/caffe-crfrnn/CMakeLists.txt new file mode 100644 index 00000000..ef599b68 --- /dev/null +++ b/caffe-crfrnn/CMakeLists.txt @@ -0,0 +1,73 @@ +cmake_minimum_required(VERSION 2.8.7) + +# ---[ Caffe project +project(Caffe C CXX) + +# ---[ Using cmake scripts and modules +list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) + +include(ExternalProject) + +include(cmake/Utils.cmake) +include(cmake/Targets.cmake) +include(cmake/Misc.cmake) +include(cmake/Summary.cmake) +include(cmake/ConfigGen.cmake) + +# ---[ Options +caffe_option(CPU_ONLY "Build Caffe without CUDA support" OFF) # TODO: rename to USE_CUDA +caffe_option(USE_CUDNN "Build Caffe with cuDNN library support" ON IF NOT CPU_ONLY) +caffe_option(BUILD_SHARED_LIBS "Build shared libraries" ON) +caffe_option(BUILD_python "Build Python wrapper" ON) +set(python_version "2" CACHE STRING "Specify which python version to use") +caffe_option(BUILD_matlab "Build Matlab wrapper" OFF IF UNIX OR APPLE) +caffe_option(BUILD_docs "Build documentation" ON IF UNIX OR APPLE) +caffe_option(BUILD_python_layer "Build the Caffe python layer" ON) + +# ---[ Dependencies +include(cmake/Dependencies.cmake) + +# ---[ Flags +if(UNIX OR APPLE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -Wall") +endif() + +if(USE_libstdcpp) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libstdc++") + message("-- Warning: forcing libstdc++ (controlled by USE_libstdcpp option in cmake)") +endif() + +add_definitions(-DGTEST_USE_OWN_TR1_TUPLE) + +# ---[ Warnings +caffe_warnings_disable(CMAKE_CXX_FLAGS -Wno-sign-compare -Wno-uninitialized) + +# ---[ Config generation +configure_file(cmake/Templates/caffe_config.h.in "${PROJECT_BINARY_DIR}/caffe_config.h") + +# ---[ Includes +set(Caffe_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include) +include_directories(${Caffe_INCLUDE_DIR} ${PROJECT_BINARY_DIR})
+include_directories(BEFORE src) # This is needed for gtest. + +# ---[ Subdirectories +add_subdirectory(src/gtest) +add_subdirectory(src/caffe) +add_subdirectory(tools) +add_subdirectory(examples) +add_subdirectory(python) +add_subdirectory(matlab) +add_subdirectory(docs) + +# ---[ Linter target +add_custom_target(lint COMMAND ${CMAKE_COMMAND} -P ${PROJECT_SOURCE_DIR}/cmake/lint.cmake) + +# ---[ pytest target +add_custom_target(pytest COMMAND python${python_version} -m unittest discover -s caffe/test WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/python ) +add_dependencies(pytest pycaffe) + +# ---[ Configuration summary +caffe_print_configuration_summary() + +# ---[ Export configs generation +caffe_generate_export_configs() diff --git a/caffe-crfrnn/CMakeScripts/FindAtlas.cmake b/caffe-crfrnn/CMakeScripts/FindAtlas.cmake new file mode 100644 index 00000000..27657a6c --- /dev/null +++ b/caffe-crfrnn/CMakeScripts/FindAtlas.cmake @@ -0,0 +1,61 @@ +# Find the Atlas (and Lapack) libraries +# +# The following variables are optionally searched for defaults +# Atlas_ROOT_DIR: Base directory where all Atlas components are found +# +# The following are set after configuration is done: +# Atlas_FOUND +# Atlas_INCLUDE_DIRS +# Atlas_LIBRARIES +# Atlas_LIBRARY_DIRS + +set(Atlas_INCLUDE_SEARCH_PATHS + /usr/include/atlas + /usr/include/atlas-base + $ENV{Atlas_ROOT_DIR} + $ENV{Atlas_ROOT_DIR}/include +) + +set(Atlas_LIB_SEARCH_PATHS + /usr/lib/atlas + /usr/lib/atlas-base + $ENV{Atlas_ROOT_DIR} + $ENV{Atlas_ROOT_DIR}/lib +) + +find_path(Atlas_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS ${Atlas_INCLUDE_SEARCH_PATHS}) +find_path(Atlas_CLAPACK_INCLUDE_DIR NAMES clapack.h PATHS ${Atlas_INCLUDE_SEARCH_PATHS}) +find_library(Atlas_CBLAS_LIBRARY NAMES ptcblas_r ptcblas cblas_r cblas PATHS ${Atlas_LIB_SEARCH_PATHS}) +find_library(Atlas_BLAS_LIBRARY NAMES atlas_r atlas PATHS ${Atlas_LIB_SEARCH_PATHS}) +find_library(Atlas_LAPACK_LIBRARY NAMES alapack_r alapack lapack_atlas PATHS
${Atlas_LIB_SEARCH_PATHS}) + +set(LOOKED_FOR + + Atlas_CBLAS_INCLUDE_DIR + Atlas_CLAPACK_INCLUDE_DIR + + Atlas_CBLAS_LIBRARY + Atlas_BLAS_LIBRARY + Atlas_LAPACK_LIBRARY +) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(Atlas DEFAULT_MSG ${LOOKED_FOR}) + +if(ATLAS_FOUND) + + mark_as_advanced(${LOOKED_FOR}) + + set(Atlas_INCLUDE_DIR + ${Atlas_CBLAS_INCLUDE_DIR} + ${Atlas_CLAPACK_INCLUDE_DIR} + ) + + set(Atlas_LIBRARIES + ${Atlas_LAPACK_LIBRARY} + ${Atlas_CBLAS_LIBRARY} + ${Atlas_BLAS_LIBRARY} + ) + +endif(ATLAS_FOUND) + diff --git a/caffe-crfrnn/CMakeScripts/FindGFlags.cmake b/caffe-crfrnn/CMakeScripts/FindGFlags.cmake new file mode 100644 index 00000000..f93c5713 --- /dev/null +++ b/caffe-crfrnn/CMakeScripts/FindGFlags.cmake @@ -0,0 +1,48 @@ +# - Try to find GFLAGS +# +# The following variables are optionally searched for defaults +# GFLAGS_ROOT_DIR: Base directory where all GFLAGS components are found +# +# The following are set after configuration is done: +# GFLAGS_FOUND +# GFLAGS_INCLUDE_DIRS +# GFLAGS_LIBRARIES +# GFLAGS_LIBRARY_DIRS + +include(FindPackageHandleStandardArgs) + +set(GFLAGS_ROOT_DIR "" CACHE PATH "Folder contains Gflags") + +# We are testing only a couple of files in the include directories +if(WIN32) + find_path(GFLAGS_INCLUDE_DIR gflags/gflags.h + PATHS ${GFLAGS_ROOT_DIR}/src/windows) +else() + find_path(GFLAGS_INCLUDE_DIR gflags/gflags.h + PATHS ${GFLAGS_ROOT_DIR}) +endif() + +if(MSVC) + find_library(GFLAGS_LIBRARY_RELEASE + NAMES libgflags + PATHS ${GFLAGS_ROOT_DIR} + PATH_SUFFIXES Release) + + find_library(GFLAGS_LIBRARY_DEBUG + NAMES libgflags-debug + PATHS ${GFLAGS_ROOT_DIR} + PATH_SUFFIXES Debug) + + set(GFLAGS_LIBRARY optimized ${GFLAGS_LIBRARY_RELEASE} debug ${GFLAGS_LIBRARY_DEBUG}) +else() + find_library(GFLAGS_LIBRARY gflags) +endif() + +find_package_handle_standard_args(GFLAGS DEFAULT_MSG + GFLAGS_INCLUDE_DIR GFLAGS_LIBRARY) + + +if(GFLAGS_FOUND) + set(GFLAGS_INCLUDE_DIRS ${GFLAGS_INCLUDE_DIR}) +
set(GFLAGS_LIBRARIES ${GFLAGS_LIBRARY}) +endif() diff --git a/caffe-crfrnn/CMakeScripts/FindGlog.cmake b/caffe-crfrnn/CMakeScripts/FindGlog.cmake new file mode 100644 index 00000000..0dc30abd --- /dev/null +++ b/caffe-crfrnn/CMakeScripts/FindGlog.cmake @@ -0,0 +1,48 @@ +# - Try to find Glog +# +# The following variables are optionally searched for defaults +# GLOG_ROOT_DIR: Base directory where all GLOG components are found +# +# The following are set after configuration is done: +# GLOG_FOUND +# GLOG_INCLUDE_DIRS +# GLOG_LIBRARIES +# GLOG_LIBRARY_DIRS + +include(FindPackageHandleStandardArgs) + +set(GLOG_ROOT_DIR "" CACHE PATH "Folder contains Google glog") + +if(WIN32) + find_path(GLOG_INCLUDE_DIR glog/logging.h + PATHS ${GLOG_ROOT_DIR}/src/windows) +else() + find_path(GLOG_INCLUDE_DIR glog/logging.h + PATHS ${GLOG_ROOT_DIR}) +endif() + +if(MSVC) + find_library(GLOG_LIBRARY_RELEASE libglog_static + PATHS ${GLOG_ROOT_DIR} + PATH_SUFFIXES Release) + + find_library(GLOG_LIBRARY_DEBUG libglog_static + PATHS ${GLOG_ROOT_DIR} + PATH_SUFFIXES Debug) + + set(GLOG_LIBRARY optimized ${GLOG_LIBRARY_RELEASE} debug ${GLOG_LIBRARY_DEBUG}) +else() + find_library(GLOG_LIBRARY glog + PATHS ${GLOG_ROOT_DIR} + PATH_SUFFIXES + lib + lib64) +endif() + +find_package_handle_standard_args(GLOG DEFAULT_MSG + GLOG_INCLUDE_DIR GLOG_LIBRARY) + +if(GLOG_FOUND) + set(GLOG_INCLUDE_DIRS ${GLOG_INCLUDE_DIR}) + set(GLOG_LIBRARIES ${GLOG_LIBRARY}) +endif() diff --git a/caffe-crfrnn/CMakeScripts/FindLAPACK.cmake b/caffe-crfrnn/CMakeScripts/FindLAPACK.cmake new file mode 100644 index 00000000..9641c45d --- /dev/null +++ b/caffe-crfrnn/CMakeScripts/FindLAPACK.cmake @@ -0,0 +1,190 @@ +# - Find LAPACK library +# This module finds an installed fortran library that implements the LAPACK +# linear-algebra interface (see http://www.netlib.org/lapack/).
+# +# The approach follows that taken for the autoconf macro file, acx_lapack.m4 +# (distributed at http://ac-archive.sourceforge.net/ac-archive/acx_lapack.html). +# +# This module sets the following variables: +# LAPACK_FOUND - set to true if a library implementing the LAPACK interface is found +# LAPACK_LIBRARIES - list of libraries (using full path name) for LAPACK + +# Note: I do not think it is a good idea to mixup different BLAS/LAPACK versions +# Hence, this script wants to find a Lapack library matching your Blas library + +# Do nothing if LAPACK was found before +IF(NOT LAPACK_FOUND) + +SET(LAPACK_LIBRARIES) +SET(LAPACK_INFO) + +IF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) + FIND_PACKAGE(BLAS) +ELSE(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) + FIND_PACKAGE(BLAS REQUIRED) +ENDIF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) + +# Old search lapack script +include(CheckFortranFunctionExists) + +macro(Check_Lapack_Libraries LIBRARIES _prefix _name _flags _list _blas) + # This macro checks for the existence of the combination of fortran libraries + # given by _list. If the combination is found, this macro checks (using the + # Check_Fortran_Function_Exists macro) whether can link against that library + # combination using the name of a routine given by _name using the linker + # flags given by _flags. If the combination of libraries is found and passes + # the link test, LIBRARIES is set to the list of complete library paths that + # have been found. Otherwise, LIBRARIES is set to FALSE. + # N.B. _prefix is the prefix applied to the names of all cached variables that + # are generated internally and marked advanced by this macro. 
+ set(_libraries_work TRUE) + set(${LIBRARIES}) + set(_combined_name) + foreach(_library ${_list}) + set(_combined_name ${_combined_name}_${_library}) + if(_libraries_work) + if (WIN32) + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} PATHS ENV LIB PATHS ENV PATH) + else (WIN32) + if(APPLE) + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} + PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 + ENV DYLD_LIBRARY_PATH) + else(APPLE) + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} + PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 + ENV LD_LIBRARY_PATH) + endif(APPLE) + endif(WIN32) + mark_as_advanced(${_prefix}_${_library}_LIBRARY) + set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY}) + set(_libraries_work ${${_prefix}_${_library}_LIBRARY}) + endif(_libraries_work) + endforeach(_library ${_list}) + if(_libraries_work) + # Test this combination of libraries. + set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}} ${_blas}) + if (CMAKE_Fortran_COMPILER_WORKS) + check_fortran_function_exists(${_name} ${_prefix}${_combined_name}_WORKS) + else (CMAKE_Fortran_COMPILER_WORKS) + check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS) + endif (CMAKE_Fortran_COMPILER_WORKS) + set(CMAKE_REQUIRED_LIBRARIES) + mark_as_advanced(${_prefix}${_combined_name}_WORKS) + set(_libraries_work ${${_prefix}${_combined_name}_WORKS}) + endif(_libraries_work) + if(NOT _libraries_work) + set(${LIBRARIES} FALSE) + endif(NOT _libraries_work) +endmacro(Check_Lapack_Libraries) + + +if(BLAS_FOUND) + + # Intel MKL + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "mkl")) + IF(MKL_LAPACK_LIBRARIES) + SET(LAPACK_LIBRARIES ${MKL_LAPACK_LIBRARIES} ${MKL_LIBRARIES}) + ELSE(MKL_LAPACK_LIBRARIES) + SET(LAPACK_LIBRARIES ${MKL_LIBRARIES}) + ENDIF(MKL_LAPACK_LIBRARIES) + SET(LAPACK_INCLUDE_DIR ${MKL_INCLUDE_DIR}) + SET(LAPACK_INFO "mkl") + ENDIF() + + # OpenBlas + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "open")) + 
SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + check_function_exists("cheev_" OPEN_LAPACK_WORKS) + if(OPEN_LAPACK_WORKS) + SET(LAPACK_INFO "open") + else() + message(STATUS "It seems OpenBlas has not been compiled with Lapack support") + endif() + endif() + + # GotoBlas + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "goto")) + SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + check_function_exists("cheev_" GOTO_LAPACK_WORKS) + if(GOTO_LAPACK_WORKS) + SET(LAPACK_INFO "goto") + else() + message(STATUS "It seems GotoBlas has not been compiled with Lapack support") + endif() + endif() + + # ACML + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "acml")) + SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + check_function_exists("cheev_" ACML_LAPACK_WORKS) + if(ACML_LAPACK_WORKS) + SET(LAPACK_INFO "acml") + else() + message(STATUS "Strangely, this ACML library does not support Lapack?!") + endif() + endif() + + # Accelerate + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "accelerate")) + SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + check_function_exists("cheev_" ACCELERATE_LAPACK_WORKS) + if(ACCELERATE_LAPACK_WORKS) + SET(LAPACK_INFO "accelerate") + else() + message(STATUS "Strangely, this Accelerate library does not support Lapack?!") + endif() + endif() + + # vecLib + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "veclib")) + SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + check_function_exists("cheev_" VECLIB_LAPACK_WORKS) + if(VECLIB_LAPACK_WORKS) + SET(LAPACK_INFO "veclib") + else() + message(STATUS "Strangely, this vecLib library does not support Lapack?!") + endif() + endif() + + # Generic LAPACK library? 
+ IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "generic")) + check_lapack_libraries( + LAPACK_LIBRARIES + LAPACK + cheev + "" + "lapack" + "${BLAS_LIBRARIES}" + ) + if(LAPACK_LIBRARIES) + SET(LAPACK_INFO "generic") + endif(LAPACK_LIBRARIES) + endif() + +else(BLAS_FOUND) + message(STATUS "LAPACK requires BLAS") +endif(BLAS_FOUND) + +if(LAPACK_INFO) + set(LAPACK_FOUND TRUE) +else(LAPACK_INFO) + set(LAPACK_FOUND FALSE) +endif(LAPACK_INFO) + +IF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED) + message(FATAL_ERROR "Cannot find a library with LAPACK API. Please specify library location.") +ENDIF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED) +IF(NOT LAPACK_FIND_QUIETLY) + IF(LAPACK_FOUND) + MESSAGE(STATUS "Found a library with LAPACK API. (${LAPACK_INFO})") + ELSE(LAPACK_FOUND) + MESSAGE(STATUS "Cannot find a library with LAPACK API. Not using LAPACK.") + ENDIF(LAPACK_FOUND) +ENDIF(NOT LAPACK_FIND_QUIETLY) + +# Do nothing if LAPACK was found before +ENDIF(NOT LAPACK_FOUND) diff --git a/caffe-crfrnn/CMakeScripts/FindLMDB.cmake b/caffe-crfrnn/CMakeScripts/FindLMDB.cmake new file mode 100644 index 00000000..e615f542 --- /dev/null +++ b/caffe-crfrnn/CMakeScripts/FindLMDB.cmake @@ -0,0 +1,28 @@ +# Try to find the LMBD libraries and headers +# LMDB_FOUND - system has LMDB lib +# LMDB_INCLUDE_DIR - the LMDB include directory +# LMDB_LIBRARIES - Libraries needed to use LMDB + +# FindCWD based on FindGMP by: +# Copyright (c) 2006, Laurent Montel, +# +# Redistribution and use is allowed according to the terms of the BSD license. 
+ +# Adapted from FindCWD by: +# Copyright 2013 Conrad Steenberg +# Aug 31, 2013 + +if (LMDB_INCLUDE_DIR AND LMDB_LIBRARIES) + # Already in cache, be silent + set(LMDB_FIND_QUIETLY TRUE) +endif (LMDB_INCLUDE_DIR AND LMDB_LIBRARIES) + +find_path(LMDB_INCLUDE_DIR NAMES "lmdb.h" HINTS "$ENV{LMDB_DIR}/include") +find_library(LMDB_LIBRARIES NAMES lmdb HINTS $ENV{LMDB_DIR}/lib ) +MESSAGE(STATUS "LMDB lib: " ${LMDB_LIBRARIES} ) +MESSAGE(STATUS "LMDB include: " ${LMDB_INCLUDE_DIR} ) + +include(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS(LMDB DEFAULT_MSG LMDB_INCLUDE_DIR LMDB_LIBRARIES) + +mark_as_advanced(LMDB_INCLUDE_DIR LMDB_LIBRARIES) diff --git a/caffe-crfrnn/CMakeScripts/FindLevelDB.cmake b/caffe-crfrnn/CMakeScripts/FindLevelDB.cmake new file mode 100644 index 00000000..f3386f26 --- /dev/null +++ b/caffe-crfrnn/CMakeScripts/FindLevelDB.cmake @@ -0,0 +1,37 @@ +# - Find LevelDB +# +# LEVELDB_INCLUDE - Where to find leveldb/db.h +# LEVELDB_LIBS - List of libraries when using LevelDB. +# LEVELDB_FOUND - True if LevelDB found. + +get_filename_component(module_file_path ${CMAKE_CURRENT_LIST_FILE} PATH) + +# Look for the header file. +find_path(LEVELDB_INCLUDE NAMES leveldb/db.h PATHS $ENV{LEVELDB_ROOT}/include /opt/local/include /usr/local/include /usr/include DOC "Path in which the file leveldb/db.h is located." ) +mark_as_advanced(LEVELDB_INCLUDE) + +# Look for the library. +# Does this work on UNIX systems? (LINUX) +find_library(LEVELDB_LIBS NAMES leveldb PATHS /usr/lib $ENV{LEVELDB_ROOT}/lib DOC "Path to leveldb library." ) +mark_as_advanced(LEVELDB_LIBS) + +# Copy the results to the output variables. +if (LEVELDB_INCLUDE AND LEVELDB_LIBS) + message(STATUS "Found leveldb in ${LEVELDB_INCLUDE} ${LEVELDB_LIBS}") + set(LEVELDB_FOUND 1) + include(CheckCXXSourceCompiles) + set(CMAKE_REQUIRED_LIBRARIES ${LEVELDB_LIBS} pthread) + set(CMAKE_REQUIRED_INCLUDES ${LEVELDB_INCLUDE}) + else () + set(LEVELDB_FOUND 0) + endif () + + # Report the results.
+ if (NOT LEVELDB_FOUND) + set(LEVELDB_DIR_MESSAGE "LEVELDB was not found. Make sure LEVELDB_LIBS and LEVELDB_INCLUDE are set.") + if (LEVELDB_FIND_REQUIRED) + message(FATAL_ERROR "${LEVELDB_DIR_MESSAGE}") + elseif (NOT LEVELDB_FIND_QUIETLY) + message(STATUS "${LEVELDB_DIR_MESSAGE}") + endif () + endif () \ No newline at end of file diff --git a/caffe-crfrnn/CMakeScripts/FindMKL.cmake b/caffe-crfrnn/CMakeScripts/FindMKL.cmake new file mode 100644 index 00000000..eb2d9f88 --- /dev/null +++ b/caffe-crfrnn/CMakeScripts/FindMKL.cmake @@ -0,0 +1,113 @@ +# - Find Intel MKL +# Find the MKL libraries +# +# Options: +# +# MKL_STATAIC : use static linking +# MKL_MULTI_THREADED: use multi-threading +# MKL_SDL : Single Dynamic Library interface +# +# This module defines the following variables: +# +# MKL_FOUND : True if MKL_INCLUDE_DIR are found +# MKL_INCLUDE_DIR : where to find mkl.h, etc. +# MKL_INCLUDE_DIRS : set when MKL_INCLUDE_DIR found +# MKL_LIBRARIES : the library to link against. + + +include(FindPackageHandleStandardArgs) + +set(INTEL_ROOT "/opt/intel" CACHE PATH "Folder contains intel libs") +set(MKL_ROOT ${INTEL_ROOT}/mkl CACHE PATH "Folder contains MKL") + +# Find include dir +find_path(MKL_INCLUDE_DIR mkl.h + PATHS ${MKL_ROOT}/include) + +# Find include directory +# There is no include folder under linux +if(WIN32) + find_path(INTEL_INCLUDE_DIR omp.h + PATHS ${INTEL_ROOT}/include) + set(MKL_INCLUDE_DIR ${MKL_INCLUDE_DIR} ${INTEL_INCLUDE_DIR}) +endif() + +# Find libraries + +# Handle suffix +set(_MKL_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES}) + +if(WIN32) + if(MKL_STATAIC) + set(CMAKE_FIND_LIBRARY_SUFFIXES .lib) + else() + set(CMAKE_FIND_LIBRARY_SUFFIXES _dll.lib) + endif() +else() + if(MKL_STATAIC) + set(CMAKE_FIND_LIBRARY_SUFFIXES .a) + else() + set(CMAKE_FIND_LIBRARY_SUFFIXES .so) + endif() +endif() + + +# MKL is composed by four layers: Interface, Threading, Computational and RTL + +if(MKL_SDL) + find_library(MKL_LIBRARY mkl_rt + PATHS 
${MKL_ROOT}/lib/ia32/) + + set(MKL_MINIMAL_LIBRARY ${MKL_LIBRARY}) +else() + ######################### Interface layer ####################### + if(WIN32) + set(MKL_INTERFACE_LIBNAME mkl_intel_c) + else() + set(MKL_INTERFACE_LIBNAME mkl_intel) + endif() + + find_library(MKL_INTERFACE_LIBRARY ${MKL_INTERFACE_LIBNAME} + PATHS ${MKL_ROOT}/lib/ia32/) + + ######################## Threading layer ######################## + if(MKL_MULTI_THREADED) + set(MKL_THREADING_LIBNAME mkl_intel_thread) + else() + set(MKL_THREADING_LIBNAME mkl_sequential) + endif() + + find_library(MKL_THREADING_LIBRARY ${MKL_THREADING_LIBNAME} + PATHS ${MKL_ROOT}/lib/ia32/) + + ####################### Computational layer ##################### + find_library(MKL_CORE_LIBRARY mkl_core + PATHS ${MKL_ROOT}/lib/ia32/) + find_library(MKL_FFT_LIBRARY mkl_cdft_core + PATHS ${MKL_ROOT}/lib/ia32/) + find_library(MKL_SCALAPACK_LIBRARY mkl_scalapack_core + PATHS ${MKL_ROOT}/lib/ia32/) + + ############################ RTL layer ########################## + if(WIN32) + set(MKL_RTL_LIBNAME libiomp5md) + else() + set(MKL_RTL_LIBNAME libiomp5) + endif() + find_library(MKL_RTL_LIBRARY ${MKL_RTL_LIBNAME} + PATHS ${INTEL_RTL_ROOT}/lib) + + set(MKL_LIBRARY ${MKL_INTERFACE_LIBRARY} ${MKL_THREADING_LIBRARY} ${MKL_CORE_LIBRARY} ${MKL_FFT_LIBRARY} ${MKL_SCALAPACK_LIBRARY} ${MKL_RTL_LIBRARY}) + set(MKL_MINIMAL_LIBRARY ${MKL_INTERFACE_LIBRARY} ${MKL_THREADING_LIBRARY} ${MKL_CORE_LIBRARY} ${MKL_RTL_LIBRARY}) +endif() + +set(CMAKE_FIND_LIBRARY_SUFFIXES ${_MKL_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES}) + +find_package_handle_standard_args(MKL DEFAULT_MSG + MKL_INCLUDE_DIR MKL_LIBRARY MKL_MINIMAL_LIBRARY) + +if(MKL_FOUND) + set(MKL_INCLUDE_DIRS ${MKL_INCLUDE_DIR}) + set(MKL_LIBRARIES ${MKL_LIBRARY}) + set(MKL_MINIMAL_LIBRARIES ${MKL_MINIMAL_LIBRARY}) +endif() diff --git a/caffe-crfrnn/CMakeScripts/FindNumPy.cmake b/caffe-crfrnn/CMakeScripts/FindNumPy.cmake new file mode 100644 index 00000000..baf21541 --- /dev/null +++
b/caffe-crfrnn/CMakeScripts/FindNumPy.cmake @@ -0,0 +1,103 @@ +# - Find the NumPy libraries +# This module finds if NumPy is installed, and sets the following variables +# indicating where it is. +# +# TODO: Update to provide the libraries and paths for linking npymath lib. +# +# NUMPY_FOUND - was NumPy found +# NUMPY_VERSION - the version of NumPy found as a string +# NUMPY_VERSION_MAJOR - the major version number of NumPy +# NUMPY_VERSION_MINOR - the minor version number of NumPy +# NUMPY_VERSION_PATCH - the patch version number of NumPy +# NUMPY_VERSION_DECIMAL - e.g. version 1.6.1 is 10601 +# NUMPY_INCLUDE_DIRS - path to the NumPy include files + +#============================================================================ +# Copyright 2012 Continuum Analytics, Inc. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files +# (the "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to permit +# persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be included +# in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +# OTHER DEALINGS IN THE SOFTWARE. 
+# +#============================================================================ + +# Finding NumPy involves calling the Python interpreter +if(NumPy_FIND_REQUIRED) + find_package(PythonInterp REQUIRED) +else() + find_package(PythonInterp) +endif() + +if(NOT PYTHONINTERP_FOUND) + set(NUMPY_FOUND FALSE) + return() +endif() + +execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c" + "import numpy as n; print(n.__version__); print(n.get_include());" + RESULT_VARIABLE _NUMPY_SEARCH_SUCCESS + OUTPUT_VARIABLE _NUMPY_VALUES_OUTPUT + ERROR_VARIABLE _NUMPY_ERROR_VALUE + OUTPUT_STRIP_TRAILING_WHITESPACE) + +if(NOT _NUMPY_SEARCH_SUCCESS MATCHES 0) + if(NumPy_FIND_REQUIRED) + message(FATAL_ERROR + "NumPy import failure:\n${_NUMPY_ERROR_VALUE}") + endif() + set(NUMPY_FOUND FALSE) + return() +endif() + +# Convert the process output into a list +string(REGEX REPLACE ";" "\\\\;" _NUMPY_VALUES ${_NUMPY_VALUES_OUTPUT}) +string(REGEX REPLACE "\n" ";" _NUMPY_VALUES ${_NUMPY_VALUES}) +# Just in case there is unexpected output from the Python command. +list(GET _NUMPY_VALUES -2 NUMPY_VERSION) +list(GET _NUMPY_VALUES -1 NUMPY_INCLUDE_DIRS) + +string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" _VER_CHECK "${NUMPY_VERSION}") +if("${_VER_CHECK}" STREQUAL "") + # The output from Python was unexpected. Raise an error always + # here, because we found NumPy, but it appears to be corrupted somehow. + message(FATAL_ERROR + "Requested version and include path from NumPy, got instead:\n${_NUMPY_VALUES_OUTPUT}\n") + return() +endif() + +# Make sure all directory separators are '/' +string(REGEX REPLACE "\\\\" "/" NUMPY_INCLUDE_DIRS ${NUMPY_INCLUDE_DIRS}) + +# Get the major and minor version numbers +string(REGEX REPLACE "\\." 
";" _NUMPY_VERSION_LIST ${NUMPY_VERSION}) +list(GET _NUMPY_VERSION_LIST 0 NUMPY_VERSION_MAJOR) +list(GET _NUMPY_VERSION_LIST 1 NUMPY_VERSION_MINOR) +list(GET _NUMPY_VERSION_LIST 2 NUMPY_VERSION_PATCH) +string(REGEX MATCH "[0-9]*" NUMPY_VERSION_PATCH ${NUMPY_VERSION_PATCH}) +math(EXPR NUMPY_VERSION_DECIMAL + "(${NUMPY_VERSION_MAJOR} * 10000) + (${NUMPY_VERSION_MINOR} * 100) + ${NUMPY_VERSION_PATCH}") + +find_package_message(NUMPY + "Found NumPy: version \"${NUMPY_VERSION}\" ${NUMPY_INCLUDE_DIRS}" + "${NUMPY_INCLUDE_DIRS}${NUMPY_VERSION}") + +set(NUMPY_FOUND TRUE) + + diff --git a/caffe-crfrnn/CMakeScripts/FindOpenBLAS.cmake b/caffe-crfrnn/CMakeScripts/FindOpenBLAS.cmake new file mode 100644 index 00000000..b8434927 --- /dev/null +++ b/caffe-crfrnn/CMakeScripts/FindOpenBLAS.cmake @@ -0,0 +1,62 @@ + + +SET(Open_BLAS_INCLUDE_SEARCH_PATHS + /usr/include + /usr/include/openblas-base + /usr/local/include + /usr/local/include/openblas-base + /opt/OpenBLAS/include + $ENV{OpenBLAS_HOME} + $ENV{OpenBLAS_HOME}/include +) + +SET(Open_BLAS_LIB_SEARCH_PATHS + /lib/ + /lib/openblas-base + /lib64/ + /usr/lib + /usr/lib/openblas-base + /usr/lib64 + /usr/local/lib + /usr/local/lib64 + /opt/OpenBLAS/lib + $ENV{OpenBLAS}cd + $ENV{OpenBLAS}/lib + $ENV{OpenBLAS_HOME} + $ENV{OpenBLAS_HOME}/lib + ) + +FIND_PATH(OpenBLAS_INCLUDE_DIR NAMES cblas.h PATHS ${Open_BLAS_INCLUDE_SEARCH_PATHS}) +FIND_LIBRARY(OpenBLAS_LIB NAMES openblas PATHS ${Open_BLAS_LIB_SEARCH_PATHS}) + +SET(OpenBLAS_FOUND ON) + +# Check include files +IF(NOT OpenBLAS_INCLUDE_DIR) + SET(OpenBLAS_FOUND OFF) + MESSAGE(STATUS "Could not find OpenBLAS include. Turning OpenBLAS_FOUND off") +ENDIF() + +# Check libraries +IF(NOT OpenBLAS_LIB) + SET(OpenBLAS_FOUND OFF) + MESSAGE(STATUS "Could not find OpenBLAS lib. 
Turning OpenBLAS_FOUND off") +ENDIF() + +IF (OpenBLAS_FOUND) + IF (NOT OpenBLAS_FIND_QUIETLY) + MESSAGE(STATUS "Found OpenBLAS libraries: ${OpenBLAS_LIB}") + MESSAGE(STATUS "Found OpenBLAS include: ${OpenBLAS_INCLUDE_DIR}") + ENDIF (NOT OpenBLAS_FIND_QUIETLY) +ELSE (OpenBLAS_FOUND) + IF (OpenBLAS_FIND_REQUIRED) + MESSAGE(FATAL_ERROR "Could not find OpenBLAS") + ENDIF (OpenBLAS_FIND_REQUIRED) +ENDIF (OpenBLAS_FOUND) + +MARK_AS_ADVANCED( + OpenBLAS_INCLUDE_DIR + OpenBLAS_LIB + OpenBLAS +) + diff --git a/caffe-crfrnn/CMakeScripts/FindProtobuf.cmake b/caffe-crfrnn/CMakeScripts/FindProtobuf.cmake new file mode 100644 index 00000000..0f94f498 --- /dev/null +++ b/caffe-crfrnn/CMakeScripts/FindProtobuf.cmake @@ -0,0 +1,152 @@ +# Locate and configure the Google Protocol Buffers library. +# Defines the following variables: +# +# PROTOBUF_FOUND - Found the Google Protocol Buffers library +# PROTOBUF_INCLUDE_DIRS - Include directories for Google Protocol Buffers +# PROTOBUF_LIBRARIES - The protobuf library +# +# The following cache variables are also defined: +# PROTOBUF_LIBRARY - The protobuf library +# PROTOBUF_PROTOC_LIBRARY - The protoc library +# PROTOBUF_INCLUDE_DIR - The include directory for protocol buffers +# PROTOBUF_PROTOC_EXECUTABLE - The protoc compiler +# +# ==================================================================== +# Example: +# +# find_package(Protobuf REQUIRED) +# include_directories(${PROTOBUF_INCLUDE_DIRS}) +# +# include_directories(${CMAKE_CURRENT_BINARY_DIR}) +# PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS foo.proto) +# add_executable(bar bar.cc ${PROTO_SRCS} ${PROTO_HDRS}) +# target_link_libraries(bar ${PROTOBUF_LIBRARY}) +# +# NOTE: You may need to link against pthreads, depending +# on the platform. 
+# ==================================================================== +# +# PROTOBUF_GENERATE_CPP (public function) +# SRCS = Variable to define with autogenerated +# source files +# HDRS = Variable to define with autogenerated +# header files +# ARGN = proto files +# +# ==================================================================== + + +#============================================================================= +# Copyright 2009 Kitware, Inc. +# Copyright 2009 Philip Lowman +# Copyright 2008 Esben Mose Hansen, Ange Optimization ApS +# +# Distributed under the OSI-approved BSD License (the "License"); +# see accompanying file Copyright.txt for details. +# +# This software is distributed WITHOUT ANY WARRANTY; without even the +# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +# See the License for more information. +#============================================================================= +# (To distribute this file outside of CMake, substitute the full +# License text for the above reference.) 
+ +function(PROTOBUF_GENERATE_PYTHON SRCS) + if(NOT ARGN) + message(SEND_ERROR "Error: PROTOBUF_GENERATE_PYTHON() called without any proto files") + return() + endif(NOT ARGN) + + set(${SRCS}) + foreach(FIL ${ARGN}) + get_filename_component(ABS_FIL ${FIL} ABSOLUTE) + get_filename_component(FIL_WE ${FIL} NAME_WE) + + + list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py") + + add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS --python_out ${CMAKE_CURRENT_BINARY_DIR} --proto_path ${CMAKE_CURRENT_SOURCE_DIR} +${ABS_FIL} + DEPENDS ${ABS_FIL} + COMMENT "Running Python protocol buffer compiler on ${FIL}" + VERBATIM ) + endforeach() + + + set_source_files_properties(${${SRCS}} PROPERTIES GENERATED TRUE) + set(${SRCS} ${${SRCS}} PARENT_SCOPE) +endfunction() + + +function(PROTOBUF_GENERATE_CPP SRCS HDRS) + if(NOT ARGN) + message(SEND_ERROR "Error: PROTOBUF_GENERATE_CPP() called without any proto files") + return() + endif(NOT ARGN) + + set(${SRCS}) + set(${HDRS}) + foreach(FIL ${ARGN}) + get_filename_component(ABS_FIL ${FIL} ABSOLUTE) + get_filename_component(FIL_WE ${FIL} NAME_WE) + + list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc") + list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h") + + add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc" + "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} + ARGS --cpp_out ${CMAKE_CURRENT_BINARY_DIR} --proto_path ${CMAKE_CURRENT_SOURCE_DIR} +${ABS_FIL} + DEPENDS ${ABS_FIL} + COMMENT "Running C++ protocol buffer compiler on ${FIL}" + VERBATIM ) + endforeach() + + set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) + set(${SRCS} ${${SRCS}} PARENT_SCOPE) + set(${HDRS} ${${HDRS}} PARENT_SCOPE) +endfunction() + + +find_path(PROTOBUF_INCLUDE_DIR google/protobuf/service.h) + +# Google's provided vcproj files generate libraries with a "lib" +# 
prefix on Windows +if(WIN32) + set(PROTOBUF_ORIG_FIND_LIBRARY_PREFIXES "${CMAKE_FIND_LIBRARY_PREFIXES}") + set(CMAKE_FIND_LIBRARY_PREFIXES "lib" "") +endif() + +find_library(PROTOBUF_LIBRARY NAMES protobuf + DOC "The Google Protocol Buffers Library" +) +find_library(PROTOBUF_PROTOC_LIBRARY NAMES protoc + DOC "The Google Protocol Buffers Compiler Library" +) +find_program(PROTOBUF_PROTOC_EXECUTABLE NAMES protoc + DOC "The Google Protocol Buffers Compiler" +) + +mark_as_advanced(PROTOBUF_INCLUDE_DIR + PROTOBUF_LIBRARY + PROTOBUF_PROTOC_LIBRARY + PROTOBUF_PROTOC_EXECUTABLE) + +# Restore original find library prefixes +if(WIN32) + set(CMAKE_FIND_LIBRARY_PREFIXES "${PROTOBUF_ORIG_FIND_LIBRARY_PREFIXES}") +endif() + +include(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS(PROTOBUF DEFAULT_MSG + PROTOBUF_LIBRARY PROTOBUF_INCLUDE_DIR) + +if(PROTOBUF_FOUND) + set(PROTOBUF_INCLUDE_DIRS ${PROTOBUF_INCLUDE_DIR}) + set(PROTOBUF_LIBRARIES ${PROTOBUF_LIBRARY}) +endif() diff --git a/caffe-crfrnn/CMakeScripts/FindSnappy.cmake b/caffe-crfrnn/CMakeScripts/FindSnappy.cmake new file mode 100644 index 00000000..d769b442 --- /dev/null +++ b/caffe-crfrnn/CMakeScripts/FindSnappy.cmake @@ -0,0 +1,33 @@ +# Find the Snappy libraries +# +# The following variables are optionally searched for defaults +# Snappy_ROOT_DIR: Base directory where all Snappy components are found +# +# The following are set after configuration is done: +# Snappy_FOUND +# Snappy_INCLUDE_DIRS +# Snappy_LIBS + +find_path(SNAPPY_INCLUDE_DIR + NAMES snappy.h + HINTS ${SNAPPY_ROOT_DIR} + ${SNAPPY_ROOT_DIR}/include +) + +find_library(SNAPPY_LIBS + NAMES snappy + HINTS ${SNAPPY_ROOT_DIR} + ${SNAPPY_ROOT_DIR}/lib +) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(Snappy + DEFAULT_MSG + SNAPPY_LIBS + SNAPPY_INCLUDE_DIR +) + +mark_as_advanced( + SNAPPY_LIBS + SNAPPY_INCLUDE_DIR +) diff --git a/caffe-crfrnn/CMakeScripts/lint.cmake b/caffe-crfrnn/CMakeScripts/lint.cmake new file mode 
100644 index 00000000..04df3409 --- /dev/null +++ b/caffe-crfrnn/CMakeScripts/lint.cmake @@ -0,0 +1,48 @@ + +set(CMAKE_SOURCE_DIR ../) +set(LINT_COMMAND ${CMAKE_SOURCE_DIR}/scripts/cpp_lint.py) +set(SRC_FILE_EXTENSIONS h hpp hu c cpp cu cc) +set(EXCLUDE_FILE_EXTENSTIONS pb.h pb.cc) +set(LINT_DIRS include src/caffe examples tools python matlab) + +# find all files of interest +foreach(ext ${SRC_FILE_EXTENSIONS}) + foreach(dir ${LINT_DIRS}) + file(GLOB_RECURSE FOUND_FILES ${CMAKE_SOURCE_DIR}/${dir}/*.${ext}) + set(LINT_SOURCES ${LINT_SOURCES} ${FOUND_FILES}) + endforeach() +endforeach() + +# find all files that should be excluded +foreach(ext ${EXCLUDE_FILE_EXTENSTIONS}) + file(GLOB_RECURSE FOUND_FILES ${CMAKE_SOURCE_DIR}/*.${ext}) + set(EXCLUDED_FILES ${EXCLUDED_FILES} ${FOUND_FILES}) +endforeach() + +# exclude generated pb files +list(REMOVE_ITEM LINT_SOURCES ${EXCLUDED_FILES}) + +execute_process( + COMMAND ${LINT_COMMAND} ${LINT_SOURCES} + ERROR_VARIABLE LINT_OUTPUT + ERROR_STRIP_TRAILING_WHITESPACE +) + +string(REPLACE "\n" ";" LINT_OUTPUT ${LINT_OUTPUT}) + +list(GET LINT_OUTPUT -1 LINT_RESULT) +list(REMOVE_AT LINT_OUTPUT -1) +string(REPLACE " " ";" LINT_RESULT ${LINT_RESULT}) +list(GET LINT_RESULT -1 NUM_ERRORS) +if(NUM_ERRORS GREATER 0) + foreach(msg ${LINT_OUTPUT}) + string(FIND ${msg} "Done" result) + if(result LESS 0) + message(STATUS ${msg}) + endif() + endforeach() + message(FATAL_ERROR "Lint found ${NUM_ERRORS} errors!") +else() + message(STATUS "Lint did not find any errors!") +endif() + diff --git a/caffe-crfrnn/INSTALL.md b/caffe-crfrnn/INSTALL.md new file mode 100644 index 00000000..42fcf027 --- /dev/null +++ b/caffe-crfrnn/INSTALL.md @@ -0,0 +1,7 @@ +# Installation + +See http://caffe.berkeleyvision.org/installation.html for the latest +installation instructions. 
+ +Check the issue tracker in case you need help: +https://github.com/BVLC/caffe/issues diff --git a/caffe-crfrnn/LICENSE b/caffe-crfrnn/LICENSE new file mode 100644 index 00000000..9a715d03 --- /dev/null +++ b/caffe-crfrnn/LICENSE @@ -0,0 +1,49 @@ +COPYRIGHT +All contributions by the University of Oxford: +Copyright (c) 2015, All rights reserved. + +All contributions by Baidu Institute of Deep Learning: +Copyright (c) 2015, All rights reserved. + +All contributions by the University of California: +Copyright (c) 2014, The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright over +their contributions to Caffe. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +CONTRIBUTION AGREEMENT + +By contributing to the BVLC/caffe repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. diff --git a/caffe-crfrnn/Makefile b/caffe-crfrnn/Makefile new file mode 100644 index 00000000..ea97392d --- /dev/null +++ b/caffe-crfrnn/Makefile @@ -0,0 +1,584 @@ +PROJECT := caffe + +CONFIG_FILE := Makefile.config +include $(CONFIG_FILE) + +BUILD_DIR_LINK := $(BUILD_DIR) +RELEASE_BUILD_DIR ?= .$(BUILD_DIR)_release +DEBUG_BUILD_DIR ?= .$(BUILD_DIR)_debug + +DEBUG ?= 0 +ifeq ($(DEBUG), 1) + BUILD_DIR := $(DEBUG_BUILD_DIR) + OTHER_BUILD_DIR := $(RELEASE_BUILD_DIR) +else + BUILD_DIR := $(RELEASE_BUILD_DIR) + OTHER_BUILD_DIR := $(DEBUG_BUILD_DIR) +endif + +# All of the directories containing code. +SRC_DIRS := $(shell find * -type d -exec bash -c "find {} -maxdepth 1 \ + \( -name '*.cpp' -o -name '*.proto' \) | grep -q ." \; -print) + +# The target shared library name +LIB_BUILD_DIR := $(BUILD_DIR)/lib +STATIC_NAME := $(LIB_BUILD_DIR)/lib$(PROJECT).a +DYNAMIC_NAME := $(LIB_BUILD_DIR)/lib$(PROJECT).so + +############################## +# Get all source files +############################## +# CXX_SRCS are the source files excluding the test ones. +CXX_SRCS := $(shell find src/$(PROJECT) ! -name "test_*.cpp" -name "*.cpp") +# CU_SRCS are the cuda source files +CU_SRCS := $(shell find src/$(PROJECT) ! 
-name "test_*.cu" -name "*.cu") +# TEST_SRCS are the test source files +TEST_MAIN_SRC := src/$(PROJECT)/test/test_caffe_main.cpp +TEST_SRCS := $(shell find src/$(PROJECT) -name "test_*.cpp") +TEST_SRCS := $(filter-out $(TEST_MAIN_SRC), $(TEST_SRCS)) +TEST_CU_SRCS := $(shell find src/$(PROJECT) -name "test_*.cu") +GTEST_SRC := src/gtest/gtest-all.cpp +# TOOL_SRCS are the source files for the tool binaries +TOOL_SRCS := $(shell find tools -name "*.cpp") +# EXAMPLE_SRCS are the source files for the example binaries +EXAMPLE_SRCS := $(shell find examples -name "*.cpp") +# BUILD_INCLUDE_DIR contains any generated header files we want to include. +BUILD_INCLUDE_DIR := $(BUILD_DIR)/src +# PROTO_SRCS are the protocol buffer definitions +PROTO_SRC_DIR := src/$(PROJECT)/proto +PROTO_SRCS := $(wildcard $(PROTO_SRC_DIR)/*.proto) +# PROTO_BUILD_DIR will contain the .cc and obj files generated from +# PROTO_SRCS; PROTO_BUILD_INCLUDE_DIR will contain the .h header files +PROTO_BUILD_DIR := $(BUILD_DIR)/$(PROTO_SRC_DIR) +PROTO_BUILD_INCLUDE_DIR := $(BUILD_INCLUDE_DIR)/$(PROJECT)/proto +# NONGEN_CXX_SRCS includes all source/header files except those generated +# automatically (e.g., by proto). 
+NONGEN_CXX_SRCS := $(shell find \ + src/$(PROJECT) \ + include/$(PROJECT) \ + python/$(PROJECT) \ + matlab/$(PROJECT) \ + examples \ + tools \ + -name "*.cpp" -or -name "*.hpp" -or -name "*.cu" -or -name "*.cuh") +LINT_SCRIPT := scripts/cpp_lint.py +LINT_OUTPUT_DIR := $(BUILD_DIR)/.lint +LINT_EXT := lint.txt +LINT_OUTPUTS := $(addsuffix .$(LINT_EXT), $(addprefix $(LINT_OUTPUT_DIR)/, $(NONGEN_CXX_SRCS))) +EMPTY_LINT_REPORT := $(BUILD_DIR)/.$(LINT_EXT) +NONEMPTY_LINT_REPORT := $(BUILD_DIR)/$(LINT_EXT) +# PY$(PROJECT)_SRC is the python wrapper for $(PROJECT) +PY$(PROJECT)_SRC := python/$(PROJECT)/_$(PROJECT).cpp +PY$(PROJECT)_HXX_SRC := python/$(PROJECT)/_$(PROJECT).hpp +PY$(PROJECT)_SO := python/$(PROJECT)/_$(PROJECT).so +# MAT$(PROJECT)_SRC is the matlab wrapper for $(PROJECT) +MAT$(PROJECT)_SRC := matlab/$(PROJECT)/mat$(PROJECT).cpp +ifneq ($(MATLAB_DIR),) + MAT_SO_EXT := $(shell $(MATLAB_DIR)/bin/mexext) +endif +MAT$(PROJECT)_SO := matlab/$(PROJECT)/$(PROJECT).$(MAT_SO_EXT) + +############################## +# Derive generated files +############################## +# The generated files for protocol buffers +PROTO_GEN_HEADER_SRCS := $(addprefix $(PROTO_BUILD_DIR)/, \ + $(notdir ${PROTO_SRCS:.proto=.pb.h})) +PROTO_GEN_HEADER := $(addprefix $(PROTO_BUILD_INCLUDE_DIR)/, \ + $(notdir ${PROTO_SRCS:.proto=.pb.h})) +PROTO_GEN_CC := $(addprefix $(BUILD_DIR)/, ${PROTO_SRCS:.proto=.pb.cc}) +PY_PROTO_BUILD_DIR := python/$(PROJECT)/proto +PY_PROTO_INIT := python/$(PROJECT)/proto/__init__.py +PROTO_GEN_PY := $(foreach file,${PROTO_SRCS:.proto=_pb2.py}, \ + $(PY_PROTO_BUILD_DIR)/$(notdir $(file))) +# The objects corresponding to the source files +# These objects will be linked into the final shared library, so we +# exclude the tool, example, and test objects. 
+CXX_OBJS := $(addprefix $(BUILD_DIR)/, ${CXX_SRCS:.cpp=.o}) +CU_OBJS := $(addprefix $(BUILD_DIR)/cuda/, ${CU_SRCS:.cu=.o}) +PROTO_OBJS := ${PROTO_GEN_CC:.cc=.o} +OBJS := $(PROTO_OBJS) $(CXX_OBJS) $(CU_OBJS) +# tool, example, and test objects +TOOL_OBJS := $(addprefix $(BUILD_DIR)/, ${TOOL_SRCS:.cpp=.o}) +TOOL_BUILD_DIR := $(BUILD_DIR)/tools +TEST_CXX_BUILD_DIR := $(BUILD_DIR)/src/$(PROJECT)/test +TEST_CU_BUILD_DIR := $(BUILD_DIR)/cuda/src/$(PROJECT)/test +TEST_CXX_OBJS := $(addprefix $(BUILD_DIR)/, ${TEST_SRCS:.cpp=.o}) +TEST_CU_OBJS := $(addprefix $(BUILD_DIR)/cuda/, ${TEST_CU_SRCS:.cu=.o}) +TEST_OBJS := $(TEST_CXX_OBJS) $(TEST_CU_OBJS) +GTEST_OBJ := $(addprefix $(BUILD_DIR)/, ${GTEST_SRC:.cpp=.o}) +EXAMPLE_OBJS := $(addprefix $(BUILD_DIR)/, ${EXAMPLE_SRCS:.cpp=.o}) +# Output files for automatic dependency generation +DEPS := ${CXX_OBJS:.o=.d} ${CU_OBJS:.o=.d} ${TEST_CXX_OBJS:.o=.d} \ + ${TEST_CU_OBJS:.o=.d} +# tool, example, and test bins +TOOL_BINS := ${TOOL_OBJS:.o=.bin} +EXAMPLE_BINS := ${EXAMPLE_OBJS:.o=.bin} +# symlinks to tool bins without the ".bin" extension +TOOL_BIN_LINKS := ${TOOL_BINS:.bin=} +# Put the test binaries in build/test for convenience. +TEST_BIN_DIR := $(BUILD_DIR)/test +TEST_CU_BINS := $(addsuffix .testbin,$(addprefix $(TEST_BIN_DIR)/, \ + $(foreach obj,$(TEST_CU_OBJS),$(basename $(notdir $(obj)))))) +TEST_CXX_BINS := $(addsuffix .testbin,$(addprefix $(TEST_BIN_DIR)/, \ + $(foreach obj,$(TEST_CXX_OBJS),$(basename $(notdir $(obj)))))) +TEST_BINS := $(TEST_CXX_BINS) $(TEST_CU_BINS) +# TEST_ALL_BIN is the test binary that links caffe statically. +TEST_ALL_BIN := $(TEST_BIN_DIR)/test_all.testbin +# TEST_ALL_DYNLINK_BIN is the test binary that links caffe as a dynamic library. 
+TEST_ALL_DYNLINK_BIN := $(TEST_BIN_DIR)/test_all_dynamic_link.testbin + +############################## +# Derive compiler warning dump locations +############################## +WARNS_EXT := warnings.txt +CXX_WARNS := $(addprefix $(BUILD_DIR)/, ${CXX_SRCS:.cpp=.o.$(WARNS_EXT)}) +CU_WARNS := $(addprefix $(BUILD_DIR)/cuda/, ${CU_SRCS:.cu=.o.$(WARNS_EXT)}) +TOOL_WARNS := $(addprefix $(BUILD_DIR)/, ${TOOL_SRCS:.cpp=.o.$(WARNS_EXT)}) +EXAMPLE_WARNS := $(addprefix $(BUILD_DIR)/, ${EXAMPLE_SRCS:.cpp=.o.$(WARNS_EXT)}) +TEST_WARNS := $(addprefix $(BUILD_DIR)/, ${TEST_SRCS:.cpp=.o.$(WARNS_EXT)}) +TEST_CU_WARNS := $(addprefix $(BUILD_DIR)/cuda/, ${TEST_CU_SRCS:.cu=.o.$(WARNS_EXT)}) +ALL_CXX_WARNS := $(CXX_WARNS) $(TOOL_WARNS) $(EXAMPLE_WARNS) $(TEST_WARNS) +ALL_CU_WARNS := $(CU_WARNS) $(TEST_CU_WARNS) +ALL_WARNS := $(ALL_CXX_WARNS) $(ALL_CU_WARNS) + +EMPTY_WARN_REPORT := $(BUILD_DIR)/.$(WARNS_EXT) +NONEMPTY_WARN_REPORT := $(BUILD_DIR)/$(WARNS_EXT) + +############################## +# Derive include and lib directories +############################## +CUDA_INCLUDE_DIR := $(CUDA_DIR)/include + +CUDA_LIB_DIR := +# add /lib64 only if it exists +ifneq ("$(wildcard $(CUDA_DIR)/lib64)","") + CUDA_LIB_DIR += $(CUDA_DIR)/lib64 +endif +CUDA_LIB_DIR += $(CUDA_DIR)/lib + +INCLUDE_DIRS += $(BUILD_INCLUDE_DIR) ./src ./include +ifneq ($(CPU_ONLY), 1) + INCLUDE_DIRS += $(CUDA_INCLUDE_DIR) + LIBRARY_DIRS += $(CUDA_LIB_DIR) + LIBRARIES := cudart cublas curand +endif +LIBRARIES += glog gflags protobuf leveldb snappy \ + lmdb boost_system hdf5_hl hdf5 m \ + opencv_core opencv_highgui opencv_imgproc +PYTHON_LIBRARIES := boost_python python2.7 +WARNINGS := -Wall -Wno-sign-compare + +############################## +# Set build directories +############################## + +DISTRIBUTE_SUBDIRS := $(DISTRIBUTE_DIR)/bin $(DISTRIBUTE_DIR)/lib +DIST_ALIASES := dist +ifneq ($(strip $(DISTRIBUTE_DIR)),distribute) + DIST_ALIASES += distribute +endif + +ALL_BUILD_DIRS := $(sort $(BUILD_DIR) $(addprefix 
$(BUILD_DIR)/, $(SRC_DIRS)) \ + $(addprefix $(BUILD_DIR)/cuda/, $(SRC_DIRS)) \ + $(LIB_BUILD_DIR) $(TEST_BIN_DIR) $(PY_PROTO_BUILD_DIR) $(LINT_OUTPUT_DIR) \ + $(DISTRIBUTE_SUBDIRS) $(PROTO_BUILD_INCLUDE_DIR)) + +############################## +# Set directory for Doxygen-generated documentation +############################## +DOXYGEN_CONFIG_FILE ?= ./.Doxyfile +# should be the same as OUTPUT_DIRECTORY in the .Doxyfile +DOXYGEN_OUTPUT_DIR ?= ./doxygen +DOXYGEN_COMMAND ?= doxygen +# All the files that might have Doxygen documentation. +DOXYGEN_SOURCES := $(shell find \ + src/$(PROJECT) \ + include/$(PROJECT) \ + python/ \ + matlab/ \ + examples \ + tools \ + -name "*.cpp" -or -name "*.hpp" -or -name "*.cu" -or -name "*.cuh" -or \ + -name "*.py" -or -name "*.m") +DOXYGEN_SOURCES += $(DOXYGEN_CONFIG_FILE) + + +############################## +# Configure build +############################## + +# Determine platform +UNAME := $(shell uname -s) +ifeq ($(UNAME), Linux) + LINUX := 1 +else ifeq ($(UNAME), Darwin) + OSX := 1 +endif + +# Linux +ifeq ($(LINUX), 1) + CXX ?= /usr/bin/g++ + GCCVERSION := $(shell $(CXX) -dumpversion | cut -f1,2 -d.) + # older versions of gcc are too dumb to build boost with -Wuninitialized + ifeq ($(shell echo $(GCCVERSION) \< 4.6 | bc), 1) + WARNINGS += -Wno-uninitialized + endif + # boost::thread is reasonably called boost_thread (compare OS X) + # We will also explicitly add stdc++ to the link target. 
+ LIBRARIES += boost_thread stdc++ +endif + +# OS X: +# clang++ instead of g++ +# libstdc++ instead of libc++ for CUDA compatibility on 10.9 +ifeq ($(OSX), 1) + CXX := /usr/bin/clang++ + # clang throws this warning for cuda headers + WARNINGS += -Wno-unneeded-internal-declaration + ifneq ($(findstring 10.9, $(shell sw_vers -productVersion)),) + CXXFLAGS += -stdlib=libstdc++ + LINKFLAGS += -stdlib=libstdc++ + endif + # boost::thread is called boost_thread-mt to mark multithreading on OS X + LIBRARIES += boost_thread-mt + NVCCFLAGS += -DOSX +endif + +# Custom compiler +ifdef CUSTOM_CXX + CXX := $(CUSTOM_CXX) +endif + +# Static linking +ifneq (,$(findstring clang++,$(CXX))) + STATIC_LINK_COMMAND := -Wl,-force_load $(STATIC_NAME) +else ifneq (,$(findstring g++,$(CXX))) + STATIC_LINK_COMMAND := -Wl,--whole-archive $(STATIC_NAME) -Wl,--no-whole-archive +else + $(error Cannot static link with the $(CXX) compiler.) +endif + +# Debugging +ifeq ($(DEBUG), 1) + COMMON_FLAGS += -DDEBUG -g -O0 + NVCCFLAGS += -G +else + COMMON_FLAGS += -DNDEBUG -O2 +endif + +# cuDNN acceleration configuration. 
+ifeq ($(USE_CUDNN), 1) + LIBRARIES += cudnn + COMMON_FLAGS += -DUSE_CUDNN +endif + +# CPU-only configuration +ifeq ($(CPU_ONLY), 1) + OBJS := $(PROTO_OBJS) $(CXX_OBJS) + TEST_OBJS := $(TEST_CXX_OBJS) + TEST_BINS := $(TEST_CXX_BINS) + ALL_WARNS := $(ALL_CXX_WARNS) + TEST_FILTER := --gtest_filter="-*GPU*" + COMMON_FLAGS += -DCPU_ONLY +endif + +# BLAS configuration (default = ATLAS) +BLAS ?= atlas +ifeq ($(BLAS), mkl) + # MKL + LIBRARIES += mkl_rt + COMMON_FLAGS += -DUSE_MKL + MKL_DIR ?= /opt/intel/mkl + BLAS_INCLUDE ?= $(MKL_DIR)/include + BLAS_LIB ?= $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64 +else ifeq ($(BLAS), open) + # OpenBLAS + LIBRARIES += openblas +else + # ATLAS + ifeq ($(LINUX), 1) + ifeq ($(BLAS), atlas) + # Linux simply has cblas and atlas + LIBRARIES += cblas atlas + endif + else ifeq ($(OSX), 1) + # OS X packages atlas as the vecLib framework + BLAS_INCLUDE ?= /System/Library/Frameworks/vecLib.framework/Versions/Current/Headers/ + LIBRARIES += cblas + LDFLAGS += -framework vecLib + endif +endif +INCLUDE_DIRS += $(BLAS_INCLUDE) +LIBRARY_DIRS += $(BLAS_LIB) + +LIBRARY_DIRS += $(LIB_BUILD_DIR) + +# Automatic dependency generation (nvcc is handled separately) +CXXFLAGS += -MMD -MP + +# Complete build flags. 
+COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) +CXXFLAGS += -pthread -fPIC $(COMMON_FLAGS) $(WARNINGS) +NVCCFLAGS += -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS) +# mex may invoke an older gcc that is too liberal with -Wuninitialized +MATLAB_CXXFLAGS := $(CXXFLAGS) -Wno-uninitialized +LINKFLAGS += -pthread -fPIC $(COMMON_FLAGS) $(WARNINGS) + +USE_PKG_CONFIG ?= 0 +ifeq ($(USE_PKG_CONFIG), 1) + PKG_CONFIG := $(shell pkg-config opencv --libs) +else + PKG_CONFIG := +endif +LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) $(PKG_CONFIG) \ + $(foreach library,$(LIBRARIES),-l$(library)) +PYTHON_LDFLAGS := $(LDFLAGS) $(foreach library,$(PYTHON_LIBRARIES),-l$(library)) + +# 'superclean' target recursively* deletes all files ending with an extension +# in $(SUPERCLEAN_EXTS) below. This may be useful if you've built older +# versions of Caffe that do not place all generated files in a location known +# to the 'clean' target. +# +# 'supercleanlist' will list the files to be deleted by make superclean. +# +# * Recursive with the exception that symbolic links are never followed, per the +# default behavior of 'find'. 
+SUPERCLEAN_EXTS := .so .a .o .bin .testbin .pb.cc .pb.h _pb2.py .cuo + +############################## +# Define build targets +############################## +.PHONY: all test clean docs linecount lint lintclean tools examples $(DIST_ALIASES) \ + py mat py$(PROJECT) mat$(PROJECT) proto runtest \ + superclean supercleanlist supercleanfiles warn everything + +all: $(STATIC_NAME) $(DYNAMIC_NAME) tools examples + +everything: all py$(PROJECT) mat$(PROJECT) test warn lint runtest + +linecount: + cloc --read-lang-def=$(PROJECT).cloc \ + src/$(PROJECT) include/$(PROJECT) tools examples \ + python matlab + +lint: $(EMPTY_LINT_REPORT) + +lintclean: + @ $(RM) -r $(LINT_OUTPUT_DIR) $(EMPTY_LINT_REPORT) $(NONEMPTY_LINT_REPORT) + +docs: $(DOXYGEN_OUTPUT_DIR) + @ cd ./docs ; ln -sfn ../$(DOXYGEN_OUTPUT_DIR)/html doxygen + +$(DOXYGEN_OUTPUT_DIR): $(DOXYGEN_CONFIG_FILE) $(DOXYGEN_SOURCES) + $(DOXYGEN_COMMAND) $(DOXYGEN_CONFIG_FILE) + +$(EMPTY_LINT_REPORT): $(LINT_OUTPUTS) | $(BUILD_DIR) + @ cat $(LINT_OUTPUTS) > $@ + @ if [ -s "$@" ]; then \ + cat $@; \ + mv $@ $(NONEMPTY_LINT_REPORT); \ + echo "Found one or more lint errors."; \ + exit 1; \ + fi; \ + $(RM) $(NONEMPTY_LINT_REPORT); \ + echo "No lint errors!"; + +$(LINT_OUTPUTS): $(LINT_OUTPUT_DIR)/%.lint.txt : % $(LINT_SCRIPT) | $(LINT_OUTPUT_DIR) + @ mkdir -p $(dir $@) + @ python $(LINT_SCRIPT) $< 2>&1 \ + | grep -v "^Done processing " \ + | grep -v "^Total errors found: 0" \ + > $@ \ + || true + +test: $(TEST_ALL_BIN) $(TEST_ALL_DYNLINK_BIN) $(TEST_BINS) + +tools: $(TOOL_BINS) $(TOOL_BIN_LINKS) + +examples: $(EXAMPLE_BINS) + +py$(PROJECT): py + +py: $(PY$(PROJECT)_SO) $(PROTO_GEN_PY) + +$(PY$(PROJECT)_SO): $(PY$(PROJECT)_SRC) $(STATIC_NAME) $(PY$(PROJECT)_HXX_SRC) + @ echo CXX $< + $(Q)$(CXX) -shared -o $@ $(PY$(PROJECT)_SRC) \ + $(STATIC_LINK_COMMAND) $(LINKFLAGS) $(PYTHON_LDFLAGS) + +mat$(PROJECT): mat + +mat: $(MAT$(PROJECT)_SO) + +$(MAT$(PROJECT)_SO): $(MAT$(PROJECT)_SRC) $(STATIC_NAME) + @ if [ -z "$(MATLAB_DIR)" ]; then 
\ + echo "MATLAB_DIR must be specified in $(CONFIG_FILE)" \ + "to build mat$(PROJECT)."; \ + exit 1; \ + fi + @ echo MEX $< + $(Q)$(MATLAB_DIR)/bin/mex $(MAT$(PROJECT)_SRC) \ + CXX="$(CXX)" \ + CXXFLAGS="\$$CXXFLAGS $(MATLAB_CXXFLAGS)" \ + CXXLIBS="\$$CXXLIBS $(STATIC_LINK_COMMAND) $(LDFLAGS)" -output $@ + +runtest: $(TEST_ALL_BIN) $(TEST_ALL_DYNLINK_BIN) + $(TEST_ALL_BIN) $(TEST_GPUID) --gtest_shuffle $(TEST_FILTER) && \ + $(TEST_ALL_DYNLINK_BIN) $(TEST_GPUID) --gtest_shuffle $(TEST_FILTER) + +warn: $(EMPTY_WARN_REPORT) + +$(EMPTY_WARN_REPORT): $(ALL_WARNS) | $(BUILD_DIR) + @ cat $(ALL_WARNS) > $@ + @ if [ -s "$@" ]; then \ + cat $@; \ + mv $@ $(NONEMPTY_WARN_REPORT); \ + echo "Compiler produced one or more warnings."; \ + exit 1; \ + fi; \ + $(RM) $(NONEMPTY_WARN_REPORT); \ + echo "No compiler warnings!"; + +$(ALL_WARNS): %.o.$(WARNS_EXT) : %.o + +$(BUILD_DIR_LINK): $(BUILD_DIR)/.linked + +# Create a target ".linked" in this BUILD_DIR to tell Make that the "build" link +# is currently correct, then delete the one in the OTHER_BUILD_DIR in case it +# exists and $(DEBUG) is toggled later. 
+$(BUILD_DIR)/.linked: + @ mkdir -p $(BUILD_DIR) + @ $(RM) $(OTHER_BUILD_DIR)/.linked + @ $(RM) -r $(BUILD_DIR_LINK) + @ ln -s $(BUILD_DIR) $(BUILD_DIR_LINK) + @ touch $@ + +$(ALL_BUILD_DIRS): | $(BUILD_DIR_LINK) + @ mkdir -p $@ + +$(DYNAMIC_NAME): $(OBJS) | $(LIB_BUILD_DIR) + @ echo LD $< + $(Q)$(CXX) -shared -o $@ $(OBJS) $(LINKFLAGS) $(LDFLAGS) + +$(STATIC_NAME): $(OBJS) | $(LIB_BUILD_DIR) + @ echo AR $< + $(Q)ar rcs $@ $(OBJS) + +$(BUILD_DIR)/%.o: %.cpp | $(ALL_BUILD_DIRS) + @ echo CXX $< + $(Q)$(CXX) $< $(CXXFLAGS) -c -o $@ 2> $@.$(WARNS_EXT) \ + || (cat $@.$(WARNS_EXT); exit 1) + @ cat $@.$(WARNS_EXT) + +$(PROTO_BUILD_DIR)/%.pb.o: $(PROTO_BUILD_DIR)/%.pb.cc $(PROTO_GEN_HEADER) \ + | $(PROTO_BUILD_DIR) + @ echo CXX $< + $(Q)$(CXX) $< $(CXXFLAGS) -c -o $@ 2> $@.$(WARNS_EXT) \ + || (cat $@.$(WARNS_EXT); exit 1) + @ cat $@.$(WARNS_EXT) + +$(BUILD_DIR)/cuda/%.o: %.cu | $(ALL_BUILD_DIRS) + @ echo NVCC $< + $(Q)$(CUDA_DIR)/bin/nvcc $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \ + -odir $(@D) + $(Q)$(CUDA_DIR)/bin/nvcc $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@ 2> $@.$(WARNS_EXT) \ + || (cat $@.$(WARNS_EXT); exit 1) + @ cat $@.$(WARNS_EXT) + +$(TEST_ALL_BIN): $(TEST_MAIN_SRC) $(TEST_OBJS) $(GTEST_OBJ) $(STATIC_NAME) \ + | $(TEST_BIN_DIR) + @ echo CXX/LD -o $@ $< + $(Q)$(CXX) $(TEST_MAIN_SRC) $(TEST_OBJS) $(GTEST_OBJ) $(STATIC_LINK_COMMAND) \ + -o $@ $(LINKFLAGS) $(LDFLAGS) + +$(TEST_ALL_DYNLINK_BIN): $(TEST_MAIN_SRC) $(TEST_OBJS) $(GTEST_OBJ) $(DYNAMIC_NAME) \ + | $(TEST_BIN_DIR) + @ echo CXX/LD -o $@ $< + $(Q)$(CXX) $(TEST_MAIN_SRC) $(TEST_OBJS) $(GTEST_OBJ) \ + -o $@ $(LINKFLAGS) $(LDFLAGS) -l$(PROJECT) -Wl,-rpath,$(LIB_BUILD_DIR) + +$(TEST_CU_BINS): $(TEST_BIN_DIR)/%.testbin: $(TEST_CU_BUILD_DIR)/%.o \ + $(GTEST_OBJ) $(STATIC_NAME) | $(TEST_BIN_DIR) + @ echo LD $< + $(Q)$(CXX) $(TEST_MAIN_SRC) $< $(GTEST_OBJ) $(STATIC_LINK_COMMAND) \ + -o $@ $(LINKFLAGS) $(LDFLAGS) + +$(TEST_CXX_BINS): $(TEST_BIN_DIR)/%.testbin: $(TEST_CXX_BUILD_DIR)/%.o \ + $(GTEST_OBJ) 
$(STATIC_NAME) | $(TEST_BIN_DIR) + @ echo LD $< + $(Q)$(CXX) $(TEST_MAIN_SRC) $< $(GTEST_OBJ) $(STATIC_LINK_COMMAND) \ + -o $@ $(LINKFLAGS) $(LDFLAGS) + +# Target for extension-less symlinks to tool binaries with extension '*.bin'. +$(TOOL_BUILD_DIR)/%: $(TOOL_BUILD_DIR)/%.bin | $(TOOL_BUILD_DIR) + @ $(RM) $@ + @ ln -s $(abspath $<) $@ + +$(TOOL_BINS) $(EXAMPLE_BINS): %.bin : %.o $(STATIC_NAME) + @ echo LD $< + $(Q)$(CXX) $< $(STATIC_LINK_COMMAND) -o $@ $(LINKFLAGS) $(LDFLAGS) + +proto: $(PROTO_GEN_CC) $(PROTO_GEN_HEADER) + +$(PROTO_BUILD_DIR)/%.pb.cc $(PROTO_BUILD_DIR)/%.pb.h : \ + $(PROTO_SRC_DIR)/%.proto | $(PROTO_BUILD_DIR) + @ echo PROTOC $< + $(Q)protoc --proto_path=$(PROTO_SRC_DIR) --cpp_out=$(PROTO_BUILD_DIR) $< + +$(PY_PROTO_BUILD_DIR)/%_pb2.py : $(PROTO_SRC_DIR)/%.proto \ + $(PY_PROTO_INIT) | $(PY_PROTO_BUILD_DIR) + @ echo PROTOC \(python\) $< + $(Q)protoc --proto_path=$(PROTO_SRC_DIR) --python_out=$(PY_PROTO_BUILD_DIR) $< + +$(PY_PROTO_INIT): | $(PY_PROTO_BUILD_DIR) + touch $(PY_PROTO_INIT) + +clean: + @- $(RM) -rf $(ALL_BUILD_DIRS) + @- $(RM) -rf $(OTHER_BUILD_DIR) + @- $(RM) -rf $(BUILD_DIR_LINK) + @- $(RM) -rf $(DISTRIBUTE_DIR) + @- $(RM) $(PY$(PROJECT)_SO) + @- $(RM) $(MAT$(PROJECT)_SO) + +supercleanfiles: + $(eval SUPERCLEAN_FILES := $(strip \ + $(foreach ext,$(SUPERCLEAN_EXTS), $(shell find . 
-name '*$(ext)' \ + -not -path './data/*')))) + +supercleanlist: supercleanfiles + @ \ + if [ -z "$(SUPERCLEAN_FILES)" ]; then \ + echo "No generated files found."; \ + else \ + echo $(SUPERCLEAN_FILES) | tr ' ' '\n'; \ + fi + +superclean: clean supercleanfiles + @ \ + if [ -z "$(SUPERCLEAN_FILES)" ]; then \ + echo "No generated files found."; \ + else \ + echo "Deleting the following generated files:"; \ + echo $(SUPERCLEAN_FILES) | tr ' ' '\n'; \ + $(RM) $(SUPERCLEAN_FILES); \ + fi + +$(DIST_ALIASES): $(DISTRIBUTE_DIR) + +$(DISTRIBUTE_DIR): all py | $(DISTRIBUTE_SUBDIRS) + # add include + cp -r include $(DISTRIBUTE_DIR)/ + mkdir -p $(DISTRIBUTE_DIR)/include/caffe/proto + cp $(PROTO_GEN_HEADER_SRCS) $(DISTRIBUTE_DIR)/include/caffe/proto + # add tool and example binaries + cp $(TOOL_BINS) $(DISTRIBUTE_DIR)/bin + cp $(EXAMPLE_BINS) $(DISTRIBUTE_DIR)/bin + # add libraries + cp $(STATIC_NAME) $(DISTRIBUTE_DIR)/lib + cp $(DYNAMIC_NAME) $(DISTRIBUTE_DIR)/lib + # add python - it's not the standard way, indeed... + cp -r python $(DISTRIBUTE_DIR)/python + +-include $(DEPS) diff --git a/caffe-crfrnn/Makefile.config b/caffe-crfrnn/Makefile.config new file mode 100755 index 00000000..e38918e3 --- /dev/null +++ b/caffe-crfrnn/Makefile.config @@ -0,0 +1,78 @@ +## Refer to http://caffe.berkeleyvision.org/installation.html +# Contributions simplifying and improving our build system are welcome! + +# cuDNN acceleration switch (uncomment to build with cuDNN). +USE_CUDNN := 1 + +# CPU-only switch (uncomment to build without GPU support). +# CPU_ONLY := 1 + +# To customize your choice of compiler, uncomment and set the following. +# N.B. the default for Linux is g++ and the default for OSX is clang++ +# CUSTOM_CXX := g++ + +# CUDA directory contains bin/ and lib/ directories that we need. 
+CUDA_DIR := /usr/local/cuda +# On Ubuntu 14.04, if cuda tools are installed via +# "sudo apt-get install nvidia-cuda-toolkit" then use this instead: +# CUDA_DIR := /usr + +# CUDA architecture setting: going with all of them (up to CUDA 5.5 compatible). +# For the latest architecture, you need to install CUDA >= 6.0 and uncomment +# the *_50 lines below. +CUDA_ARCH := -gencode arch=compute_30,code=sm_30 \ + -gencode arch=compute_35,code=sm_35 + #-gencode arch=compute_50,code=sm_50 \ + #-gencode arch=compute_50,code=compute_50 + +# BLAS choice: +# atlas for ATLAS (default) +# mkl for MKL +# open for OpenBlas +BLAS := mkl +# Custom (MKL/ATLAS/OpenBLAS) include and lib directories. +# Leave commented to accept the defaults for your choice of BLAS +# (which should work)! +BLAS_INCLUDE := /home/bittnt/intel/compilers_and_libraries_2016.0.109/linux/mkl/include +BLAS_LIB := /home/bittnt/intel/compilers_and_libraries_2016.0.109/linux/mkl/lib/intel64 + +# This is required only if you will compile the matlab interface. +# MATLAB directory should contain the mex binary in /bin. +MATLAB_DIR := /usr/local/MATLAB/R2015a +# MATLAB_DIR := /Applications/MATLAB_R2012b.app + +# NOTE: this is required only if you will compile the python interface. +# We need to be able to find Python.h and numpy/arrayobject.h. +#PYTHON_INCLUDE := /usr/include/python2.7 \ +# /usr/lib/python2.7/dist-packages/numpy/core/include +# Anaconda Python distribution is quite popular. Include path: +# Verify anaconda location, sometimes it's in root. +ANACONDA_HOME := $(HOME)/anaconda +PYTHON_INCLUDE := $(ANACONDA_HOME)/include \ + $(ANACONDA_HOME)/include/python2.7 \ + $(ANACONDA_HOME)/lib/python2.7/site-packages/numpy/core/include \ + +# We need to be able to find libpythonX.X.so or .dylib. +#PYTHON_LIB := /usr/lib +PYTHON_LIB := $(ANACONDA_HOME)/lib + +# Whatever else you find you need goes here. 
+INCLUDE_DIRS := $(PYTHON_INCLUDE) /home/bittnt/common/include /home/bittnt/crf-rnn/cuda/include /usr/local/cuda/include +LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib /home/bittnt/common/lib /home/bittnt/crf-rnn/cuda/lib64 /usr/local/cuda/lib +#LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib /home/bittnt/common/lib /home/bittnt/Documents/ConvMean/caffe-fcn-sadeep/cudnnv2rc3 + +# Uncomment to use `pkg-config` to specify OpenCV library paths. +# (Usually not necessary -- OpenCV libraries are normally installed in one of the above $LIBRARY_DIRS.) +# USE_PKG_CONFIG := 1 + +BUILD_DIR := build +DISTRIBUTE_DIR := distribute + +# Uncomment for debugging. Does not work on OSX due to https://github.com/BVLC/caffe/issues/171 +#DEBUG := 1 + +# The ID of the GPU that 'make runtest' will use to run unit tests. +TEST_GPUID := 0 + +# enable pretty build (comment to see full commands) +Q ?= @ diff --git a/caffe-crfrnn/Makefile.config.example b/caffe-crfrnn/Makefile.config.example new file mode 100644 index 00000000..0c996038 --- /dev/null +++ b/caffe-crfrnn/Makefile.config.example @@ -0,0 +1,79 @@ +## Refer to http://caffe.berkeleyvision.org/installation.html +# Contributions simplifying and improving our build system are welcome! + +# cuDNN acceleration switch (uncomment to build with cuDNN). +# USE_CUDNN := 1 + +# CPU-only switch (uncomment to build without GPU support). +# CPU_ONLY := 1 + +# To customize your choice of compiler, uncomment and set the following. +# N.B. the default for Linux is g++ and the default for OSX is clang++ +# CUSTOM_CXX := g++ + +# CUDA directory contains bin/ and lib/ directories that we need. +CUDA_DIR := /usr/local/cuda +# On Ubuntu 14.04, if cuda tools are installed via +# "sudo apt-get install nvidia-cuda-toolkit" then use this instead: +# CUDA_DIR := /usr + +# CUDA architecture setting: going with all of them (up to CUDA 5.5 compatible). 
+# For the latest architecture, you need to install CUDA >= 6.0 and uncomment +# the *_50 lines below. +CUDA_ARCH := -gencode arch=compute_20,code=sm_20 \ + -gencode arch=compute_20,code=sm_21 \ + -gencode arch=compute_30,code=sm_30 \ + -gencode arch=compute_35,code=sm_35 \ + #-gencode arch=compute_50,code=sm_50 \ + #-gencode arch=compute_50,code=compute_50 + +# BLAS choice: +# atlas for ATLAS (default) +# mkl for MKL +# open for OpenBlas +BLAS := atlas +# Custom (MKL/ATLAS/OpenBLAS) include and lib directories. +# Leave commented to accept the defaults for your choice of BLAS +# (which should work)! +# BLAS_INCLUDE := /path/to/your/blas +# BLAS_LIB := /path/to/your/blas + +# This is required only if you will compile the matlab interface. +# MATLAB directory should contain the mex binary in /bin. +# MATLAB_DIR := /usr/local +# MATLAB_DIR := /Applications/MATLAB_R2012b.app + +# NOTE: this is required only if you will compile the python interface. +# We need to be able to find Python.h and numpy/arrayobject.h. +PYTHON_INCLUDE := /usr/include/python2.7 \ + /usr/lib/python2.7/dist-packages/numpy/core/include +# Anaconda Python distribution is quite popular. Include path: +# Verify anaconda location, sometimes it's in root. +# ANACONDA_HOME := $(HOME)/anaconda +# PYTHON_INCLUDE := $(ANACONDA_HOME)/include \ + # $(ANACONDA_HOME)/include/python2.7 \ + # $(ANACONDA_HOME)/lib/python2.7/site-packages/numpy/core/include \ + +# We need to be able to find libpythonX.X.so or .dylib. +PYTHON_LIB := /usr/lib +# PYTHON_LIB := $(ANACONDA_HOME)/lib + +# Whatever else you find you need goes here. +INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include +LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib + +# Uncomment to use `pkg-config` to specify OpenCV library paths. +# (Usually not necessary -- OpenCV libraries are normally installed in one of the above $LIBRARY_DIRS.) +# USE_PKG_CONFIG := 1 + +BUILD_DIR := build +DISTRIBUTE_DIR := distribute + +# Uncomment for debugging. 
Does not work on OSX due to https://github.com/BVLC/caffe/issues/171 +# DEBUG := 1 + +# The ID of the GPU that 'make runtest' will use to run unit tests. +TEST_GPUID := 0 + +# enable pretty build (comment to see full commands) +Q ?= @ diff --git a/caffe-crfrnn/README.md b/caffe-crfrnn/README.md new file mode 100644 index 00000000..5b5cb1f4 --- /dev/null +++ b/caffe-crfrnn/README.md @@ -0,0 +1,5 @@ +This is Caffe with several unmerged PRs. + +Everything here is subject to change, including the history of this branch. + +See `future.sh` for details. diff --git a/caffe-crfrnn/caffe.cloc b/caffe-crfrnn/caffe.cloc new file mode 100644 index 00000000..a36ab619 --- /dev/null +++ b/caffe-crfrnn/caffe.cloc @@ -0,0 +1,53 @@ +Bourne Shell + filter remove_matches ^\s*# + filter remove_inline #.*$ + extension sh + script_exe sh +C + filter remove_matches ^\s*// + filter call_regexp_common C + filter remove_inline //.*$ + extension c + extension ec + extension pgc +C++ + filter remove_matches ^\s*// + filter remove_inline //.*$ + filter call_regexp_common C + extension C + extension cc + extension cpp + extension cxx + extension pcc +C/C++ Header + filter remove_matches ^\s*// + filter call_regexp_common C + filter remove_inline //.*$ + extension H + extension h + extension hh + extension hpp +CUDA + filter remove_matches ^\s*// + filter remove_inline //.*$ + filter call_regexp_common C + extension cu +Python + filter remove_matches ^\s*# + filter docstring_to_C + filter call_regexp_common C + filter remove_inline #.*$ + extension py +make + filter remove_matches ^\s*# + filter remove_inline #.*$ + extension Gnumakefile + extension Makefile + extension am + extension gnumakefile + extension makefile + filename Gnumakefile + filename Makefile + filename gnumakefile + filename makefile + script_exe make diff --git a/caffe-crfrnn/cmake/ConfigGen.cmake b/caffe-crfrnn/cmake/ConfigGen.cmake new file mode 100644 index 00000000..566d6ca0 --- /dev/null +++ 
b/caffe-crfrnn/cmake/ConfigGen.cmake @@ -0,0 +1,104 @@ + +################################################################################################ +# Helper function to fetch caffe includes which will be passed to dependent projects +# Usage: +# caffe_get_current_includes(<includes_variable>) +function(caffe_get_current_includes includes_variable) + get_property(current_includes DIRECTORY PROPERTY INCLUDE_DIRECTORIES) + caffe_convert_absolute_paths(current_includes) + + # remove at most one ${PROJECT_BINARY_DIR} include added for caffe_config.h + list(FIND current_includes ${PROJECT_BINARY_DIR} __index) + list(REMOVE_AT current_includes ${__index}) + + # removing numpy includes (since not required for client libs) + set(__toremove "") + foreach(__i ${current_includes}) + if(${__i} MATCHES "python") + list(APPEND __toremove ${__i}) + endif() + endforeach() + if(__toremove) + list(REMOVE_ITEM current_includes ${__toremove}) + endif() + + caffe_list_unique(current_includes) + set(${includes_variable} ${current_includes} PARENT_SCOPE) +endfunction() + +################################################################################################ +# Helper function to get all list items that begin with the given prefix +# Usage: +# caffe_get_items_with_prefix(<prefix> <list_variable> <output_variable>) +function(caffe_get_items_with_prefix prefix list_variable output_variable) + set(__result "") + foreach(__e ${${list_variable}}) + if(__e MATCHES "^${prefix}.*") + list(APPEND __result ${__e}) + endif() + endforeach() + set(${output_variable} ${__result} PARENT_SCOPE) +endfunction() + +################################################################################################ +# Function for generating Caffe build-tree and install-tree export config files +# Usage: +# caffe_generate_export_configs() +function(caffe_generate_export_configs) + set(install_cmake_suffix "share/Caffe") + + # ---[ Configure build-tree CaffeConfig.cmake file ]--- + caffe_get_current_includes(Caffe_INCLUDE_DIRS) + + set(Caffe_DEFINITIONS "") +
if(NOT HAVE_CUDA) + set(HAVE_CUDA FALSE) + list(APPEND Caffe_DEFINITIONS -DCPU_ONLY) + endif() + + if(NOT HAVE_CUDNN) + set(HAVE_CUDNN FALSE) + else() + list(APPEND Caffe_DEFINITIONS -DUSE_CUDNN) + endif() + + if(BLAS STREQUAL "MKL" OR BLAS STREQUAL "mkl") + list(APPEND Caffe_DEFINITIONS -DUSE_MKL) + endif() + + configure_file("cmake/Templates/CaffeConfig.cmake.in" "${PROJECT_BINARY_DIR}/CaffeConfig.cmake" @ONLY) + + # Add targets to the build-tree export set + export(TARGETS caffe proto FILE "${PROJECT_BINARY_DIR}/CaffeTargets.cmake") + export(PACKAGE Caffe) + + # ---[ Configure install-tree CaffeConfig.cmake file ]--- + + # remove source and build dir includes + caffe_get_items_with_prefix(${PROJECT_SOURCE_DIR} Caffe_INCLUDE_DIRS __insource) + caffe_get_items_with_prefix(${PROJECT_BINARY_DIR} Caffe_INCLUDE_DIRS __inbinary) + list(REMOVE_ITEM Caffe_INCLUDE_DIRS ${__insource} ${__inbinary}) + + # add `install` include folder + set(lines + "get_filename_component(__caffe_include \"\${Caffe_CMAKE_DIR}/../../include\" ABSOLUTE)\n" + "list(APPEND Caffe_INCLUDE_DIRS \${__caffe_include})\n" + "unset(__caffe_include)\n") + string(REPLACE ";" "" Caffe_INSTALL_INCLUDE_DIR_APPEND_COMMAND ${lines}) + + configure_file("cmake/Templates/CaffeConfig.cmake.in" "${PROJECT_BINARY_DIR}/cmake/CaffeConfig.cmake" @ONLY) + + # Install the CaffeConfig.cmake and export set to use with install-tree + install(FILES "${PROJECT_BINARY_DIR}/cmake/CaffeConfig.cmake" DESTINATION ${install_cmake_suffix}) + install(EXPORT CaffeTargets DESTINATION ${install_cmake_suffix}) + + # ---[ Configure and install version file ]--- + + # TODO: Lines below are commented out because Caffe doesn't declare its version in headers.
+ # When the declarations are added, modify `caffe_extract_caffe_version()` macro and uncomment + + # configure_file(cmake/Templates/CaffeConfigVersion.cmake.in "${PROJECT_BINARY_DIR}/CaffeConfigVersion.cmake" @ONLY) + # install(FILES "${PROJECT_BINARY_DIR}/CaffeConfigVersion.cmake" DESTINATION ${install_cmake_suffix}) +endfunction() + + diff --git a/caffe-crfrnn/cmake/Cuda.cmake b/caffe-crfrnn/cmake/Cuda.cmake new file mode 100644 index 00000000..ff58d31c --- /dev/null +++ b/caffe-crfrnn/cmake/Cuda.cmake @@ -0,0 +1,254 @@ +if(CPU_ONLY) + return() +endif() + +# Known NVIDIA GPU architectures Caffe can be compiled for. +# This list will be used for CUDA_ARCH_NAME = All option +set(Caffe_known_gpu_archs "20 21(20) 30 35 50") + +################################################################################################ +# A function for automatic detection of GPUs installed (if autodetection is enabled) +# Usage: +# caffe_detect_installed_gpus(out_variable) +function(caffe_detect_installed_gpus out_variable) + if(NOT CUDA_gpu_detect_output) + set(__cufile ${PROJECT_BINARY_DIR}/detect_cuda_archs.cu) + + file(WRITE ${__cufile} "" + "#include <cstdio>\n" + "int main()\n" + "{\n" + " int count = 0;\n" + " if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n" + " if (count == 0) return -1;\n" + " for (int device = 0; device < count; ++device)\n" + " {\n" + " cudaDeviceProp prop;\n" + " if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n" + " std::printf(\"%d.%d \", prop.major, prop.minor);\n" + " }\n" + " return 0;\n" + "}\n") + + execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" "--run" "${__cufile}" + WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/" + RESULT_VARIABLE __nvcc_res OUTPUT_VARIABLE __nvcc_out + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(__nvcc_res EQUAL 0) + string(REPLACE "2.1" "2.1(2.0)" __nvcc_out "${__nvcc_out}") + set(CUDA_gpu_detect_output ${__nvcc_out} CACHE INTERNAL "Returned GPU architectures from caffe_detect_gpus tool"
FORCE) + endif() + endif() + + if(NOT CUDA_gpu_detect_output) + message(STATUS "Automatic GPU detection failed. Building for all known architectures.") + set(${out_variable} ${Caffe_known_gpu_archs} PARENT_SCOPE) + else() + set(${out_variable} ${CUDA_gpu_detect_output} PARENT_SCOPE) + endif() +endfunction() + + +################################################################################################ +# Function for selecting GPU arch flags for nvcc based on CUDA_ARCH_NAME +# Usage: +# caffe_select_nvcc_arch_flags(out_variable) +function(caffe_select_nvcc_arch_flags out_variable) + # List of arch names + set(__archs_names "Fermi" "Kepler" "Maxwell" "All" "Manual") + set(__archs_name_default "All") + if(NOT CMAKE_CROSSCOMPILING) + list(APPEND __archs_names "Auto") + set(__archs_name_default "Auto") + endif() + + # set CUDA_ARCH_NAME strings (so it will be seen as a drop-down list in CMake-Gui) + set(CUDA_ARCH_NAME ${__archs_name_default} CACHE STRING "Select target NVIDIA GPU architecture.") + set_property( CACHE CUDA_ARCH_NAME PROPERTY STRINGS "" ${__archs_names} ) + mark_as_advanced(CUDA_ARCH_NAME) + + # verify CUDA_ARCH_NAME value + if(NOT ";${__archs_names};" MATCHES ";${CUDA_ARCH_NAME};") + string(REPLACE ";" ", " __archs_names "${__archs_names}") + message(FATAL_ERROR "Only ${__archs_names} architecture names are supported.") + endif() + + if(${CUDA_ARCH_NAME} STREQUAL "Manual") + set(CUDA_ARCH_BIN ${Caffe_known_gpu_archs} CACHE STRING "Specify 'real' GPU architectures to build binaries for, BIN(PTX) format is supported") + set(CUDA_ARCH_PTX "50" CACHE STRING "Specify 'virtual' PTX architectures to build PTX intermediate code for") + mark_as_advanced(CUDA_ARCH_BIN CUDA_ARCH_PTX) + else() + unset(CUDA_ARCH_BIN CACHE) + unset(CUDA_ARCH_PTX CACHE) + endif() + + if(${CUDA_ARCH_NAME} STREQUAL "Fermi") + set(__cuda_arch_bin "20 21(20)") + elseif(${CUDA_ARCH_NAME} STREQUAL "Kepler") + set(__cuda_arch_bin "30 35") + elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell") +
set(__cuda_arch_bin "50") + elseif(${CUDA_ARCH_NAME} STREQUAL "All") + set(__cuda_arch_bin ${Caffe_known_gpu_archs}) + elseif(${CUDA_ARCH_NAME} STREQUAL "Auto") + caffe_detect_installed_gpus(__cuda_arch_bin) + else() # (${CUDA_ARCH_NAME} STREQUAL "Manual") + set(__cuda_arch_bin ${CUDA_ARCH_BIN}) + endif() + + # remove dots and convert to lists + string(REGEX REPLACE "\\." "" __cuda_arch_bin "${__cuda_arch_bin}") + string(REGEX REPLACE "\\." "" __cuda_arch_ptx "${CUDA_ARCH_PTX}") + string(REGEX MATCHALL "[0-9()]+" __cuda_arch_bin "${__cuda_arch_bin}") + string(REGEX MATCHALL "[0-9]+" __cuda_arch_ptx "${__cuda_arch_ptx}") + caffe_list_unique(__cuda_arch_bin __cuda_arch_ptx) + + set(__nvcc_flags "") + set(__nvcc_archs_readable "") + + # Tell NVCC to add binaries for the specified GPUs + foreach(__arch ${__cuda_arch_bin}) + if(__arch MATCHES "([0-9]+)\\(([0-9]+)\\)") + # User explicitly specified PTX for the concrete BIN + list(APPEND __nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1}) + list(APPEND __nvcc_archs_readable sm_${CMAKE_MATCH_1}) + else() + # User didn't explicitly specify PTX for the concrete BIN, we assume PTX=BIN + list(APPEND __nvcc_flags -gencode arch=compute_${__arch},code=sm_${__arch}) + list(APPEND __nvcc_archs_readable sm_${__arch}) + endif() + endforeach() + + # Tell NVCC to add PTX intermediate code for the specified architectures + foreach(__arch ${__cuda_arch_ptx}) + list(APPEND __nvcc_flags -gencode arch=compute_${__arch},code=compute_${__arch}) + list(APPEND __nvcc_archs_readable compute_${__arch}) + endforeach() + + string(REPLACE ";" " " __nvcc_archs_readable "${__nvcc_archs_readable}") + set(${out_variable} ${__nvcc_flags} PARENT_SCOPE) + set(${out_variable}_readable ${__nvcc_archs_readable} PARENT_SCOPE) +endfunction() + +################################################################################################ +# Short command for cuda compilation +# Usage: +# caffe_cuda_compile(<objlist_variable> <cuda_files>)
+macro(caffe_cuda_compile objlist_variable) + foreach(var CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS_DEBUG) + set(${var}_backup_in_cuda_compile_ "${${var}}") + + # we remove /EHa as it generates warnings under windows + string(REPLACE "/EHa" "" ${var} "${${var}}") + + endforeach() + + if(UNIX OR APPLE) + list(APPEND CUDA_NVCC_FLAGS -Xcompiler -fPIC) + endif() + + if(APPLE) + list(APPEND CUDA_NVCC_FLAGS -Xcompiler -Wno-unused-function) + endif() + + cuda_compile(cuda_objcs ${ARGN}) + + foreach(var CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS_DEBUG) + set(${var} "${${var}_backup_in_cuda_compile_}") + unset(${var}_backup_in_cuda_compile_) + endforeach() + + set(${objlist_variable} ${cuda_objcs}) +endmacro() + +################################################################################################ +# Short command for cuDNN detection. We believe it will soon be part of the CUDA toolkit distribution. +# That is why this is a plain function rather than a FindcuDNN.cmake file +# Usage: +# detect_cuDNN() +function(detect_cuDNN) + set(CUDNN_ROOT "" CACHE PATH "CUDNN root folder") + + find_path(CUDNN_INCLUDE cudnn.h + PATHS ${CUDNN_ROOT} $ENV{CUDNN_ROOT} ${CUDA_TOOLKIT_INCLUDE} + DOC "Path to cuDNN include directory."
) + + get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH) + find_library(CUDNN_LIBRARY NAMES libcudnn.so # libcudnn_static.a + PATHS ${CUDNN_ROOT} $ENV{CUDNN_ROOT} ${CUDNN_INCLUDE} ${__libpath_hist} + DOC "Path to cuDNN library.") + + if(CUDNN_INCLUDE AND CUDNN_LIBRARY) + set(HAVE_CUDNN TRUE PARENT_SCOPE) + set(CUDNN_FOUND TRUE PARENT_SCOPE) + + mark_as_advanced(CUDNN_INCLUDE CUDNN_LIBRARY CUDNN_ROOT) + message(STATUS "Found cuDNN (include: ${CUDNN_INCLUDE}, library: ${CUDNN_LIBRARY})") + endif() +endfunction() + + +################################################################################################ +### Non macro section +################################################################################################ + +find_package(CUDA 5.5 QUIET) +find_cuda_helper_libs(curand) # cmake 2.8.7 compatibility which doesn't search for curand + +if(NOT CUDA_FOUND) + return() +endif() + +set(HAVE_CUDA TRUE) +message(STATUS "CUDA detected: " ${CUDA_VERSION}) +include_directories(SYSTEM ${CUDA_INCLUDE_DIRS}) +list(APPEND Caffe_LINKER_LIBS ${CUDA_CUDART_LIBRARY} + ${CUDA_curand_LIBRARY} ${CUDA_CUBLAS_LIBRARIES}) + +# cudnn detection +if(USE_CUDNN) + detect_cuDNN() + if(HAVE_CUDNN) + add_definitions(-DUSE_CUDNN) + include_directories(SYSTEM ${CUDNN_INCLUDE}) + list(APPEND Caffe_LINKER_LIBS ${CUDNN_LIBRARY}) + endif() +endif() + +# setting nvcc arch flags +caffe_select_nvcc_arch_flags(NVCC_FLAGS_EXTRA) +list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA}) +message(STATUS "Added CUDA NVCC flags for: ${NVCC_FLAGS_EXTRA_readable}") + +# Boost 1.55 workaround, see https://svn.boost.org/trac/boost/ticket/9392 or +# https://github.com/ComputationalRadiationPhysics/picongpu/blob/master/src/picongpu/CMakeLists.txt +if(Boost_VERSION EQUAL 105500) + message(STATUS "Cuda + Boost 1.55: Applying noinline work around") + # avoid warning for CMake >= 2.8.12 + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} \"-DBOOST_NOINLINE=__attribute__((noinline))\" ") +endif() + +#
disable some nvcc diagnostics that appear in boost, glog, gflags, opencv, etc. +foreach(diag cc_clobber_ignored integer_sign_change useless_using_declaration set_but_not_used) + list(APPEND CUDA_NVCC_FLAGS -Xcudafe --diag_suppress=${diag}) +endforeach() + +# setting default testing device +if(NOT CUDA_TEST_DEVICE) + set(CUDA_TEST_DEVICE -1) +endif() + +mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD) +mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION) + +# Handle clang/libc++ issue +if(APPLE) + caffe_detect_darwin_version(OSX_VERSION) + + # OSX 10.9 and higher use clang/libc++ by default, which is incompatible with old CUDA toolkits + if(OSX_VERSION VERSION_GREATER 10.8) + # enabled by default if and only if CUDA version is less than 7.0 + caffe_option(USE_libstdcpp "Use libstdc++ instead of libc++" (CUDA_VERSION VERSION_LESS 7.0)) + endif() +endif() diff --git a/caffe-crfrnn/cmake/Dependencies.cmake b/caffe-crfrnn/cmake/Dependencies.cmake new file mode 100644 index 00000000..7c86dd55 --- /dev/null +++ b/caffe-crfrnn/cmake/Dependencies.cmake @@ -0,0 +1,158 @@ +# This list is required for static linking and exported to CaffeConfig.cmake +set(Caffe_LINKER_LIBS "") + +# ---[ Boost +find_package(Boost 1.46 REQUIRED COMPONENTS system thread) +include_directories(SYSTEM ${Boost_INCLUDE_DIR}) +list(APPEND Caffe_LINKER_LIBS ${Boost_LIBRARIES}) + +# ---[ Threads +find_package(Threads REQUIRED) +list(APPEND Caffe_LINKER_LIBS ${CMAKE_THREAD_LIBS_INIT}) + +# ---[ Google-glog +include("cmake/External/glog.cmake") +include_directories(SYSTEM ${GLOG_INCLUDE_DIRS}) +list(APPEND Caffe_LINKER_LIBS ${GLOG_LIBRARIES}) + +# ---[ Google-gflags +include("cmake/External/gflags.cmake") +include_directories(SYSTEM ${GFLAGS_INCLUDE_DIRS}) +list(APPEND Caffe_LINKER_LIBS ${GFLAGS_LIBRARIES}) + +# ---[ Google-protobuf +include(cmake/ProtoBuf.cmake) + +# ---[ HDF5 +find_package(HDF5 COMPONENTS HL REQUIRED) +include_directories(SYSTEM ${HDF5_INCLUDE_DIRS}
${HDF5_HL_INCLUDE_DIR}) +list(APPEND Caffe_LINKER_LIBS ${HDF5_LIBRARIES}) + +# ---[ LMDB +find_package(LMDB REQUIRED) +include_directories(SYSTEM ${LMDB_INCLUDE_DIR}) +list(APPEND Caffe_LINKER_LIBS ${LMDB_LIBRARIES}) + +# ---[ LevelDB +find_package(LevelDB REQUIRED) +include_directories(SYSTEM ${LevelDB_INCLUDE}) +list(APPEND Caffe_LINKER_LIBS ${LevelDB_LIBRARIES}) + +# ---[ Snappy +find_package(Snappy REQUIRED) +include_directories(SYSTEM ${Snappy_INCLUDE_DIR}) +list(APPEND Caffe_LINKER_LIBS ${Snappy_LIBRARIES}) + +# ---[ CUDA +include(cmake/Cuda.cmake) +if(NOT HAVE_CUDA) + if(CPU_ONLY) + message("-- CUDA is disabled. Building without it...") + else() + message("-- CUDA is not detected by cmake. Building without it...") + endif() + + # TODO: remove this not cross platform define in future. Use caffe_config.h instead. + add_definitions(-DCPU_ONLY) +endif() + +# ---[ OpenCV +find_package(OpenCV QUIET COMPONENTS core highgui imgproc imgcodecs) +if(NOT OpenCV_FOUND) # if not OpenCV 3.x, then imgcodecs are not found + find_package(OpenCV REQUIRED COMPONENTS core highgui imgproc) +endif() +include_directories(SYSTEM ${OpenCV_INCLUDE_DIRS}) +list(APPEND Caffe_LINKER_LIBS ${OpenCV_LIBS}) +message(STATUS "OpenCV found (${OpenCV_CONFIG_PATH})") + +# ---[ BLAS +if(NOT APPLE) + set(BLAS "Atlas" CACHE STRING "Selected BLAS library") + set_property(CACHE BLAS PROPERTY STRINGS "Atlas;Open;MKL") + + if(BLAS STREQUAL "Atlas" OR BLAS STREQUAL "atlas") + find_package(Atlas REQUIRED) + include_directories(SYSTEM ${Atlas_INCLUDE_DIR}) + list(APPEND Caffe_LINKER_LIBS ${Atlas_LIBRARIES}) + elseif(BLAS STREQUAL "Open" OR BLAS STREQUAL "open") + find_package(OpenBLAS REQUIRED) + include_directories(SYSTEM ${OpenBLAS_INCLUDE_DIR}) + list(APPEND Caffe_LINKER_LIBS ${OpenBLAS_LIB}) + elseif(BLAS STREQUAL "MKL" OR BLAS STREQUAL "mkl") + find_package(MKL REQUIRED) + include_directories(SYSTEM ${MKL_INCLUDE_DIR}) + list(APPEND Caffe_LINKER_LIBS ${MKL_LIBRARIES}) + add_definitions(-DUSE_MKL) + 
endif() +elseif(APPLE) + find_package(vecLib REQUIRED) + include_directories(SYSTEM ${vecLib_INCLUDE_DIR}) + list(APPEND Caffe_LINKER_LIBS ${vecLib_LINKER_LIBS}) +endif() + +# ---[ Python +if(BUILD_python) + if(NOT "${python_version}" VERSION_LESS "3.0.0") + # use python3 + find_package(PythonInterp 3.0) + find_package(PythonLibs 3.0) + find_package(NumPy 1.7.1) + # Find the matching boost python implementation + set(version ${PYTHONLIBS_VERSION_STRING}) + + STRING( REPLACE "." "" boost_py_version ${version} ) + find_package(Boost 1.46 COMPONENTS "python-py${boost_py_version}") + set(Boost_PYTHON_FOUND ${Boost_PYTHON-PY${boost_py_version}_FOUND}) + + while(NOT "${version}" STREQUAL "" AND NOT Boost_PYTHON_FOUND) + STRING( REGEX REPLACE "([0-9.]+).[0-9]+" "\\1" version ${version} ) + + STRING( REPLACE "." "" boost_py_version ${version} ) + find_package(Boost 1.46 COMPONENTS "python-py${boost_py_version}") + set(Boost_PYTHON_FOUND ${Boost_PYTHON-PY${boost_py_version}_FOUND}) + + STRING( REGEX MATCHALL "([0-9.]+).[0-9]+" has_more_version ${version} ) + if("${has_more_version}" STREQUAL "") + break() + endif() + endwhile() + if(NOT Boost_PYTHON_FOUND) + find_package(Boost 1.46 COMPONENTS python) + endif() + else() + # disable Python 3 search + find_package(PythonInterp 2.7) + find_package(PythonLibs 2.7) + find_package(NumPy 1.7.1) + find_package(Boost 1.46 COMPONENTS python) + endif() + if(PYTHONLIBS_FOUND AND NUMPY_FOUND AND Boost_PYTHON_FOUND) + set(HAVE_PYTHON TRUE) + if(BUILD_python_layer) + add_definitions(-DWITH_PYTHON_LAYER) + include_directories(SYSTEM ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${Boost_INCLUDE_DIRS}) + list(APPEND Caffe_LINKER_LIBS ${PYTHON_LIBRARIES} ${Boost_LIBRARIES}) + endif() + endif() +endif() + +# ---[ Matlab +if(BUILD_matlab) + find_package(MatlabMex) + if(MATLABMEX_FOUND) + set(HAVE_MATLAB TRUE) + endif() + + # sudo apt-get install liboctave-dev + find_program(Octave_compiler NAMES mkoctfile DOC "Octave C++ compiler") + + 
if(HAVE_MATLAB AND Octave_compiler) + set(Matlab_build_mex_using "Matlab" CACHE STRING "Select Matlab or Octave if both detected") + set_property(CACHE Matlab_build_mex_using PROPERTY STRINGS "Matlab;Octave") + endif() +endif() + +# ---[ Doxygen +if(BUILD_docs) + find_package(Doxygen) +endif() diff --git a/caffe-crfrnn/cmake/External/gflags.cmake b/caffe-crfrnn/cmake/External/gflags.cmake new file mode 100644 index 00000000..e3dba04f --- /dev/null +++ b/caffe-crfrnn/cmake/External/gflags.cmake @@ -0,0 +1,56 @@ +if (NOT __GFLAGS_INCLUDED) # guard against multiple includes + set(__GFLAGS_INCLUDED TRUE) + + # use the system-wide gflags if present + find_package(GFlags) + if (GFLAGS_FOUND) + set(GFLAGS_EXTERNAL FALSE) + else() + # gflags will use pthreads if it's available in the system, so we must link with it + find_package(Threads) + + # build directory + set(gflags_PREFIX ${CMAKE_BINARY_DIR}/external/gflags-prefix) + # install directory + set(gflags_INSTALL ${CMAKE_BINARY_DIR}/external/gflags-install) + + # we build gflags statically, but want to link it into the caffe shared library + # this requires position-independent code + if (UNIX) + set(GFLAGS_EXTRA_COMPILER_FLAGS "-fPIC") + endif() + + set(GFLAGS_CXX_FLAGS ${CMAKE_CXX_FLAGS} ${GFLAGS_EXTRA_COMPILER_FLAGS}) + set(GFLAGS_C_FLAGS ${CMAKE_C_FLAGS} ${GFLAGS_EXTRA_COMPILER_FLAGS}) + + ExternalProject_Add(gflags + PREFIX ${gflags_PREFIX} + GIT_REPOSITORY "https://github.com/gflags/gflags.git" + GIT_TAG "v2.1.2" + UPDATE_COMMAND "" + INSTALL_DIR ${gflags_INSTALL} + CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + -DCMAKE_INSTALL_PREFIX=${gflags_INSTALL} + -DBUILD_SHARED_LIBS=OFF + -DBUILD_STATIC_LIBS=ON + -DBUILD_PACKAGING=OFF + -DBUILD_TESTING=OFF + -DBUILD_NC_TESTS=OFF + -DBUILD_CONFIG_TESTS=OFF + -DINSTALL_HEADERS=ON + -DCMAKE_C_FLAGS=${GFLAGS_C_FLAGS} + -DCMAKE_CXX_FLAGS=${GFLAGS_CXX_FLAGS} + LOG_DOWNLOAD 1 + LOG_INSTALL 1 + ) + + set(GFLAGS_FOUND TRUE) + set(GFLAGS_INCLUDE_DIRS ${gflags_INSTALL}/include) +
set(GFLAGS_LIBRARIES ${gflags_INSTALL}/lib/libgflags.a ${CMAKE_THREAD_LIBS_INIT}) + set(GFLAGS_LIBRARY_DIRS ${gflags_INSTALL}/lib) + set(GFLAGS_EXTERNAL TRUE) + + list(APPEND external_project_dependencies gflags) + endif() + +endif() diff --git a/caffe-crfrnn/cmake/External/glog.cmake b/caffe-crfrnn/cmake/External/glog.cmake new file mode 100644 index 00000000..a44672f2 --- /dev/null +++ b/caffe-crfrnn/cmake/External/glog.cmake @@ -0,0 +1,56 @@ +# glog depends on gflags +include("cmake/External/gflags.cmake") + +if (NOT __GLOG_INCLUDED) + set(__GLOG_INCLUDED TRUE) + + # try the system-wide glog first + find_package(Glog) + if (GLOG_FOUND) + set(GLOG_EXTERNAL FALSE) + else() + # fetch and build glog from github + + # build directory + set(glog_PREFIX ${CMAKE_BINARY_DIR}/external/glog-prefix) + # install directory + set(glog_INSTALL ${CMAKE_BINARY_DIR}/external/glog-install) + + # we build glog statically, but want to link it into the caffe shared library + # this requires position-independent code + if (UNIX) + set(GLOG_EXTRA_COMPILER_FLAGS "-fPIC") + endif() + + set(GLOG_CXX_FLAGS ${CMAKE_CXX_FLAGS} ${GLOG_EXTRA_COMPILER_FLAGS}) + set(GLOG_C_FLAGS ${CMAKE_C_FLAGS} ${GLOG_EXTRA_COMPILER_FLAGS}) + + # depend on gflags if we're also building it + if (GFLAGS_EXTERNAL) + set(GLOG_DEPENDS gflags) + endif() + + ExternalProject_Add(glog + DEPENDS ${GLOG_DEPENDS} + PREFIX ${glog_PREFIX} + GIT_REPOSITORY "https://github.com/google/glog" + GIT_TAG "v0.3.4" + UPDATE_COMMAND "" + INSTALL_DIR ${glog_INSTALL} + CONFIGURE_COMMAND env "CFLAGS=${GLOG_C_FLAGS}" "CXXFLAGS=${GLOG_CXX_FLAGS}" ${glog_PREFIX}/src/glog/configure --prefix=${glog_INSTALL} --enable-shared=no --enable-static=yes --with-gflags=${GFLAGS_LIBRARY_DIRS}/..
+ LOG_DOWNLOAD 1 + LOG_CONFIGURE 1 + LOG_INSTALL 1 + ) + + set(GLOG_FOUND TRUE) + set(GLOG_INCLUDE_DIRS ${glog_INSTALL}/include) + set(GLOG_LIBRARIES ${GFLAGS_LIBRARIES} ${glog_INSTALL}/lib/libglog.a) + set(GLOG_LIBRARY_DIRS ${glog_INSTALL}/lib) + set(GLOG_EXTERNAL TRUE) + + list(APPEND external_project_dependencies glog) + endif() + +endif() + diff --git a/caffe-crfrnn/cmake/Misc.cmake b/caffe-crfrnn/cmake/Misc.cmake new file mode 100644 index 00000000..9dd2609b --- /dev/null +++ b/caffe-crfrnn/cmake/Misc.cmake @@ -0,0 +1,52 @@ +# ---[ Configuration types +set(CMAKE_CONFIGURATION_TYPES "Debug;Release" CACHE STRING "Possible configurations" FORCE) +mark_as_advanced(CMAKE_CONFIGURATION_TYPES) + +if(DEFINED CMAKE_BUILD_TYPE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS ${CMAKE_CONFIGURATION_TYPES}) +endif() + +# --[ If user doesn't specify build type then assume release +if("${CMAKE_BUILD_TYPE}" STREQUAL "") + set(CMAKE_BUILD_TYPE Release) +endif() + +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + set(CMAKE_COMPILER_IS_CLANGXX TRUE) +endif() + +# ---[ Solution folders +caffe_option(USE_PROJECT_FOLDERS "IDE Solution folders" (MSVC_IDE OR CMAKE_GENERATOR MATCHES Xcode) ) + +if(USE_PROJECT_FOLDERS) + set_property(GLOBAL PROPERTY USE_FOLDERS ON) + set_property(GLOBAL PROPERTY PREDEFINED_TARGETS_FOLDER "CMakeTargets") +endif() + +# ---[ Install options +if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) + set(CMAKE_INSTALL_PREFIX "${PROJECT_BINARY_DIR}/install" CACHE PATH "Default install path" FORCE) +endif() + +# ---[ RPATH settings +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE CACHE BOOLEAN "Use link paths for shared library rpath") +set(CMAKE_MACOSX_RPATH TRUE) + +list(FIND CMAKE_PLATFORM_IMPLICIT_LINK_DIRECTORIES ${CMAKE_INSTALL_PREFIX}/lib __is_system_dir) +if(${__is_system_dir} STREQUAL -1) + set(CMAKE_INSTALL_RPATH ${CMAKE_INSTALL_PREFIX}/lib) +endif() + +# ---[ Funny target +if(UNIX OR APPLE) + add_custom_target(symlink_to_build COMMAND "ln" "-sf"
"${PROJECT_BINARY_DIR}" "${PROJECT_SOURCE_DIR}/build" + COMMENT "Adding symlink: /build -> ${PROJECT_BINARY_DIR}" ) +endif() + +# ---[ Set debug postfix +set(Caffe_DEBUG_POSTFIX "-d") + +set(Caffe_POSTFIX "") +if(CMAKE_BUILD_TYPE MATCHES "Debug") + set(Caffe_POSTFIX ${Caffe_DEBUG_POSTFIX}) +endif() diff --git a/caffe-crfrnn/cmake/Modules/FindAtlas.cmake b/caffe-crfrnn/cmake/Modules/FindAtlas.cmake new file mode 100644 index 00000000..6e156435 --- /dev/null +++ b/caffe-crfrnn/cmake/Modules/FindAtlas.cmake @@ -0,0 +1,52 @@ +# Find the Atlas (and Lapack) libraries +# +# The following variables are optionally searched for defaults +# Atlas_ROOT_DIR: Base directory where all Atlas components are found +# +# The following are set after configuration is done: +# Atlas_FOUND +# Atlas_INCLUDE_DIR +# Atlas_LIBRARIES +# Atlas_LIBRARY_DIRS + +set(Atlas_INCLUDE_SEARCH_PATHS + /usr/include/atlas + /usr/include/atlas-base + $ENV{Atlas_ROOT_DIR} + $ENV{Atlas_ROOT_DIR}/include +) + +set(Atlas_LIB_SEARCH_PATHS + /usr/lib/atlas + /usr/lib/atlas-base + $ENV{Atlas_ROOT_DIR} + $ENV{Atlas_ROOT_DIR}/lib +) + +find_path(Atlas_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS ${Atlas_INCLUDE_SEARCH_PATHS}) +find_path(Atlas_CLAPACK_INCLUDE_DIR NAMES clapack.h PATHS ${Atlas_INCLUDE_SEARCH_PATHS}) + +find_library(Atlas_CBLAS_LIBRARY NAMES ptcblas_r ptcblas cblas_r cblas PATHS ${Atlas_LIB_SEARCH_PATHS}) +find_library(Atlas_BLAS_LIBRARY NAMES atlas_r atlas PATHS ${Atlas_LIB_SEARCH_PATHS}) +find_library(Atlas_LAPACK_LIBRARY NAMES alapack_r alapack lapack_atlas PATHS ${Atlas_LIB_SEARCH_PATHS}) + +set(LOOKED_FOR + Atlas_CBLAS_INCLUDE_DIR + Atlas_CLAPACK_INCLUDE_DIR + + Atlas_CBLAS_LIBRARY + Atlas_BLAS_LIBRARY + Atlas_LAPACK_LIBRARY +) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(Atlas DEFAULT_MSG ${LOOKED_FOR}) + +if(ATLAS_FOUND) + set(Atlas_INCLUDE_DIR ${Atlas_CBLAS_INCLUDE_DIR} ${Atlas_CLAPACK_INCLUDE_DIR}) + set(Atlas_LIBRARIES ${Atlas_LAPACK_LIBRARY}
${Atlas_CBLAS_LIBRARY} ${Atlas_BLAS_LIBRARY}) + mark_as_advanced(${LOOKED_FOR}) + + message(STATUS "Found Atlas (include: ${Atlas_CBLAS_INCLUDE_DIR}, library: ${Atlas_BLAS_LIBRARY})") +endif(ATLAS_FOUND) + diff --git a/caffe-crfrnn/cmake/Modules/FindGFlags.cmake b/caffe-crfrnn/cmake/Modules/FindGFlags.cmake new file mode 100644 index 00000000..29b60f05 --- /dev/null +++ b/caffe-crfrnn/cmake/Modules/FindGFlags.cmake @@ -0,0 +1,50 @@ +# - Try to find GFLAGS +# +# The following variables are optionally searched for defaults +# GFLAGS_ROOT_DIR: Base directory where all GFLAGS components are found +# +# The following are set after configuration is done: +# GFLAGS_FOUND +# GFLAGS_INCLUDE_DIRS +# GFLAGS_LIBRARIES +# GFLAGS_LIBRARY_DIRS + +include(FindPackageHandleStandardArgs) + +set(GFLAGS_ROOT_DIR "" CACHE PATH "Folder contains Gflags") + +# We are testing only a couple of files in the include directories +if(WIN32) + find_path(GFLAGS_INCLUDE_DIR gflags/gflags.h + PATHS ${GFLAGS_ROOT_DIR}/src/windows) +else() + find_path(GFLAGS_INCLUDE_DIR gflags/gflags.h + PATHS ${GFLAGS_ROOT_DIR}) +endif() + +if(MSVC) + find_library(GFLAGS_LIBRARY_RELEASE + NAMES libgflags + PATHS ${GFLAGS_ROOT_DIR} + PATH_SUFFIXES Release) + + find_library(GFLAGS_LIBRARY_DEBUG + NAMES libgflags-debug + PATHS ${GFLAGS_ROOT_DIR} + PATH_SUFFIXES Debug) + + set(GFLAGS_LIBRARY optimized ${GFLAGS_LIBRARY_RELEASE} debug ${GFLAGS_LIBRARY_DEBUG}) +else() + find_library(GFLAGS_LIBRARY gflags) +endif() + +find_package_handle_standard_args(GFlags DEFAULT_MSG GFLAGS_INCLUDE_DIR GFLAGS_LIBRARY) + + +if(GFLAGS_FOUND) + set(GFLAGS_INCLUDE_DIRS ${GFLAGS_INCLUDE_DIR}) + set(GFLAGS_LIBRARIES ${GFLAGS_LIBRARY}) + message(STATUS "Found gflags (include: ${GFLAGS_INCLUDE_DIR}, library: ${GFLAGS_LIBRARY})") + mark_as_advanced(GFLAGS_LIBRARY_DEBUG GFLAGS_LIBRARY_RELEASE + GFLAGS_LIBRARY GFLAGS_INCLUDE_DIR GFLAGS_ROOT_DIR) +endif() diff --git a/caffe-crfrnn/cmake/Modules/FindGlog.cmake
b/caffe-crfrnn/cmake/Modules/FindGlog.cmake new file mode 100644 index 00000000..99abbe47 --- /dev/null +++ b/caffe-crfrnn/cmake/Modules/FindGlog.cmake @@ -0,0 +1,48 @@ +# - Try to find Glog +# +# The following variables are optionally searched for defaults +# GLOG_ROOT_DIR: Base directory where all GLOG components are found +# +# The following are set after configuration is done: +# GLOG_FOUND +# GLOG_INCLUDE_DIRS +# GLOG_LIBRARIES +# GLOG_LIBRARY_DIRS + +include(FindPackageHandleStandardArgs) + +set(GLOG_ROOT_DIR "" CACHE PATH "Folder contains Google glog") + +if(WIN32) + find_path(GLOG_INCLUDE_DIR glog/logging.h + PATHS ${GLOG_ROOT_DIR}/src/windows) +else() + find_path(GLOG_INCLUDE_DIR glog/logging.h + PATHS ${GLOG_ROOT_DIR}) +endif() + +if(MSVC) + find_library(GLOG_LIBRARY_RELEASE libglog_static + PATHS ${GLOG_ROOT_DIR} + PATH_SUFFIXES Release) + + find_library(GLOG_LIBRARY_DEBUG libglog_static + PATHS ${GLOG_ROOT_DIR} + PATH_SUFFIXES Debug) + + set(GLOG_LIBRARY optimized ${GLOG_LIBRARY_RELEASE} debug ${GLOG_LIBRARY_DEBUG}) +else() + find_library(GLOG_LIBRARY glog + PATHS ${GLOG_ROOT_DIR} + PATH_SUFFIXES lib lib64) +endif() + +find_package_handle_standard_args(Glog DEFAULT_MSG GLOG_INCLUDE_DIR GLOG_LIBRARY) + +if(GLOG_FOUND) + set(GLOG_INCLUDE_DIRS ${GLOG_INCLUDE_DIR}) + set(GLOG_LIBRARIES ${GLOG_LIBRARY}) + message(STATUS "Found glog (include: ${GLOG_INCLUDE_DIR}, library: ${GLOG_LIBRARY})") + mark_as_advanced(GLOG_ROOT_DIR GLOG_LIBRARY_RELEASE GLOG_LIBRARY_DEBUG + GLOG_LIBRARY GLOG_INCLUDE_DIR) +endif() diff --git a/caffe-crfrnn/cmake/Modules/FindLAPACK.cmake b/caffe-crfrnn/cmake/Modules/FindLAPACK.cmake new file mode 100644 index 00000000..9641c45d --- /dev/null +++ b/caffe-crfrnn/cmake/Modules/FindLAPACK.cmake @@ -0,0 +1,190 @@ +# - Find LAPACK library +# This module finds an installed fortran library that implements the LAPACK +# linear-algebra interface (see http://www.netlib.org/lapack/).
+# +# The approach follows that taken for the autoconf macro file, acx_lapack.m4 +# (distributed at http://ac-archive.sourceforge.net/ac-archive/acx_lapack.html). +# +# This module sets the following variables: +# LAPACK_FOUND - set to true if a library implementing the LAPACK interface is found +# LAPACK_LIBRARIES - list of libraries (using full path name) for LAPACK + +# Note: I do not think it is a good idea to mixup different BLAS/LAPACK versions +# Hence, this script wants to find a Lapack library matching your Blas library + +# Do nothing if LAPACK was found before +IF(NOT LAPACK_FOUND) + +SET(LAPACK_LIBRARIES) +SET(LAPACK_INFO) + +IF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) + FIND_PACKAGE(BLAS) +ELSE(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) + FIND_PACKAGE(BLAS REQUIRED) +ENDIF(LAPACK_FIND_QUIETLY OR NOT LAPACK_FIND_REQUIRED) + +# Old search lapack script +include(CheckFortranFunctionExists) + +macro(Check_Lapack_Libraries LIBRARIES _prefix _name _flags _list _blas) + # This macro checks for the existence of the combination of fortran libraries + # given by _list. If the combination is found, this macro checks (using the + # Check_Fortran_Function_Exists macro) whether can link against that library + # combination using the name of a routine given by _name using the linker + # flags given by _flags. If the combination of libraries is found and passes + # the link test, LIBRARIES is set to the list of complete library paths that + # have been found. Otherwise, LIBRARIES is set to FALSE. + # N.B. _prefix is the prefix applied to the names of all cached variables that + # are generated internally and marked advanced by this macro. 
+ set(_libraries_work TRUE) + set(${LIBRARIES}) + set(_combined_name) + foreach(_library ${_list}) + set(_combined_name ${_combined_name}_${_library}) + if(_libraries_work) + if (WIN32) + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} PATHS ENV LIB PATHS ENV PATH) + else (WIN32) + if(APPLE) + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} + PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 + ENV DYLD_LIBRARY_PATH) + else(APPLE) + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} + PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 + ENV LD_LIBRARY_PATH) + endif(APPLE) + endif(WIN32) + mark_as_advanced(${_prefix}_${_library}_LIBRARY) + set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY}) + set(_libraries_work ${${_prefix}_${_library}_LIBRARY}) + endif(_libraries_work) + endforeach(_library ${_list}) + if(_libraries_work) + # Test this combination of libraries. + set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}} ${_blas}) + if (CMAKE_Fortran_COMPILER_WORKS) + check_fortran_function_exists(${_name} ${_prefix}${_combined_name}_WORKS) + else (CMAKE_Fortran_COMPILER_WORKS) + check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS) + endif (CMAKE_Fortran_COMPILER_WORKS) + set(CMAKE_REQUIRED_LIBRARIES) + mark_as_advanced(${_prefix}${_combined_name}_WORKS) + set(_libraries_work ${${_prefix}${_combined_name}_WORKS}) + endif(_libraries_work) + if(NOT _libraries_work) + set(${LIBRARIES} FALSE) + endif(NOT _libraries_work) +endmacro(Check_Lapack_Libraries) + + +if(BLAS_FOUND) + + # Intel MKL + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "mkl")) + IF(MKL_LAPACK_LIBRARIES) + SET(LAPACK_LIBRARIES ${MKL_LAPACK_LIBRARIES} ${MKL_LIBRARIES}) + ELSE(MKL_LAPACK_LIBRARIES) + SET(LAPACK_LIBRARIES ${MKL_LIBRARIES}) + ENDIF(MKL_LAPACK_LIBRARIES) + SET(LAPACK_INCLUDE_DIR ${MKL_INCLUDE_DIR}) + SET(LAPACK_INFO "mkl") + ENDIF() + + # OpenBlas + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "open")) + 
SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + check_function_exists("cheev_" OPEN_LAPACK_WORKS) + if(OPEN_LAPACK_WORKS) + SET(LAPACK_INFO "open") + else() + message(STATUS "It seems OpenBlas has not been compiled with Lapack support") + endif() + endif() + + # GotoBlas + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "goto")) + SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + check_function_exists("cheev_" GOTO_LAPACK_WORKS) + if(GOTO_LAPACK_WORKS) + SET(LAPACK_INFO "goto") + else() + message(STATUS "It seems GotoBlas has not been compiled with Lapack support") + endif() + endif() + + # ACML + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "acml")) + SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + check_function_exists("cheev_" ACML_LAPACK_WORKS) + if(ACML_LAPACK_WORKS) + SET(LAPACK_INFO "acml") + else() + message(STATUS "Strangely, this ACML library does not support Lapack?!") + endif() + endif() + + # Accelerate + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "accelerate")) + SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + check_function_exists("cheev_" ACCELERATE_LAPACK_WORKS) + if(ACCELERATE_LAPACK_WORKS) + SET(LAPACK_INFO "accelerate") + else() + message(STATUS "Strangely, this Accelerate library does not support Lapack?!") + endif() + endif() + + # vecLib + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "veclib")) + SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + check_function_exists("cheev_" VECLIB_LAPACK_WORKS) + if(VECLIB_LAPACK_WORKS) + SET(LAPACK_INFO "veclib") + else() + message(STATUS "Strangely, this vecLib library does not support Lapack?!") + endif() + endif() + + # Generic LAPACK library? 
+ IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "generic")) + check_lapack_libraries( + LAPACK_LIBRARIES + LAPACK + cheev + "" + "lapack" + "${BLAS_LIBRARIES}" + ) + if(LAPACK_LIBRARIES) + SET(LAPACK_INFO "generic") + endif(LAPACK_LIBRARIES) + endif() + +else(BLAS_FOUND) + message(STATUS "LAPACK requires BLAS") +endif(BLAS_FOUND) + +if(LAPACK_INFO) + set(LAPACK_FOUND TRUE) +else(LAPACK_INFO) + set(LAPACK_FOUND FALSE) +endif(LAPACK_INFO) + +IF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED) + message(FATAL_ERROR "Cannot find a library with LAPACK API. Please specify library location.") +ENDIF (NOT LAPACK_FOUND AND LAPACK_FIND_REQUIRED) +IF(NOT LAPACK_FIND_QUIETLY) + IF(LAPACK_FOUND) + MESSAGE(STATUS "Found a library with LAPACK API. (${LAPACK_INFO})") + ELSE(LAPACK_FOUND) + MESSAGE(STATUS "Cannot find a library with LAPACK API. Not using LAPACK.") + ENDIF(LAPACK_FOUND) +ENDIF(NOT LAPACK_FIND_QUIETLY) + +# Do nothing if LAPACK was found before +ENDIF(NOT LAPACK_FOUND) diff --git a/caffe-crfrnn/cmake/Modules/FindLMDB.cmake b/caffe-crfrnn/cmake/Modules/FindLMDB.cmake new file mode 100644 index 00000000..8a817fd6 --- /dev/null +++ b/caffe-crfrnn/cmake/Modules/FindLMDB.cmake @@ -0,0 +1,28 @@ +# Try to find the LMDB libraries and headers +# LMDB_FOUND - system has LMDB lib +# LMDB_INCLUDE_DIR - the LMDB include directory +# LMDB_LIBRARIES - Libraries needed to use LMDB + +# FindCWD based on FindGMP by: +# Copyright (c) 2006, Laurent Montel, +# +# Redistribution and use is allowed according to the terms of the BSD license.
+ +# Adapted from FindCWD by: +# Copyright 2013 Conrad Steenberg +# Aug 31, 2013 + +find_path(LMDB_INCLUDE_DIR NAMES lmdb.h PATHS "$ENV{LMDB_DIR}/include") +find_library(LMDB_LIBRARIES NAMES lmdb PATHS "$ENV{LMDB_DIR}/lib" ) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(LMDB DEFAULT_MSG LMDB_INCLUDE_DIR LMDB_LIBRARIES) + +if(LMDB_FOUND) + message(STATUS "Found lmdb (include: ${LMDB_INCLUDE_DIR}, library: ${LMDB_LIBRARIES})") + mark_as_advanced(LMDB_INCLUDE_DIR LMDB_LIBRARIES) + + caffe_parse_header(${LMDB_INCLUDE_DIR}/lmdb.h + LMDB_VERSION_LINES MDB_VERSION_MAJOR MDB_VERSION_MINOR MDB_VERSION_PATCH) + set(LMDB_VERSION "${MDB_VERSION_MAJOR}.${MDB_VERSION_MINOR}.${MDB_VERSION_PATCH}") +endif() diff --git a/caffe-crfrnn/cmake/Modules/FindLevelDB.cmake b/caffe-crfrnn/cmake/Modules/FindLevelDB.cmake new file mode 100644 index 00000000..97f08ac9 --- /dev/null +++ b/caffe-crfrnn/cmake/Modules/FindLevelDB.cmake @@ -0,0 +1,44 @@ +# - Find LevelDB +# +# LevelDB_INCLUDES - List of LevelDB includes +# LevelDB_LIBRARIES - List of libraries when using LevelDB. +# LevelDB_FOUND - True if LevelDB found. + +# Look for the header file. +find_path(LevelDB_INCLUDE NAMES leveldb/db.h + PATHS $ENV{LEVELDB_ROOT}/include /opt/local/include /usr/local/include /usr/include + DOC "Path in which the file leveldb/db.h is located." ) + +# Look for the library. +find_library(LevelDB_LIBRARY NAMES leveldb + PATHS /usr/lib $ENV{LEVELDB_ROOT}/lib + DOC "Path to leveldb library." 
) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(LevelDB DEFAULT_MSG LevelDB_INCLUDE LevelDB_LIBRARY) + +if(LEVELDB_FOUND) + message(STATUS "Found LevelDB (include: ${LevelDB_INCLUDE}, library: ${LevelDB_LIBRARY})") + set(LevelDB_INCLUDES ${LevelDB_INCLUDE}) + set(LevelDB_LIBRARIES ${LevelDB_LIBRARY}) + mark_as_advanced(LevelDB_INCLUDE LevelDB_LIBRARY) + + if(EXISTS "${LevelDB_INCLUDE}/leveldb/db.h") + file(STRINGS "${LevelDB_INCLUDE}/leveldb/db.h" __version_lines + REGEX "static const int k[^V]+Version[ \t]+=[ \t]+[0-9]+;") + + foreach(__line ${__version_lines}) + if(__line MATCHES "[^k]+kMajorVersion[ \t]+=[ \t]+([0-9]+);") + set(LEVELDB_VERSION_MAJOR ${CMAKE_MATCH_1}) + elseif(__line MATCHES "[^k]+kMinorVersion[ \t]+=[ \t]+([0-9]+);") + set(LEVELDB_VERSION_MINOR ${CMAKE_MATCH_1}) + endif() + endforeach() + + if(LEVELDB_VERSION_MAJOR AND LEVELDB_VERSION_MINOR) + set(LEVELDB_VERSION "${LEVELDB_VERSION_MAJOR}.${LEVELDB_VERSION_MINOR}") + endif() + + caffe_clear_vars(__line __version_lines) + endif() +endif() diff --git a/caffe-crfrnn/cmake/Modules/FindMKL.cmake b/caffe-crfrnn/cmake/Modules/FindMKL.cmake new file mode 100644 index 00000000..d2012db5 --- /dev/null +++ b/caffe-crfrnn/cmake/Modules/FindMKL.cmake @@ -0,0 +1,110 @@ +# Find the MKL libraries +# +# Options: +# +# MKL_USE_SINGLE_DYNAMIC_LIBRARY : use single dynamic library interface +# MKL_USE_STATIC_LIBS : use static libraries +# MKL_MULTI_THREADED : use multi-threading +# +# This module defines the following variables: +# +# MKL_FOUND : True if MKL is found +# MKL_INCLUDE_DIR : include directory +# MKL_LIBRARIES : the libraries to link against.
+ + +# ---[ Options +caffe_option(MKL_USE_SINGLE_DYNAMIC_LIBRARY "Use single dynamic library interface" ON) +caffe_option(MKL_USE_STATIC_LIBS "Use static libraries" OFF IF NOT MKL_USE_SINGLE_DYNAMIC_LIBRARY) +caffe_option(MKL_MULTI_THREADED "Use multi-threading" ON IF NOT MKL_USE_SINGLE_DYNAMIC_LIBRARY) + +# ---[ Root folders +set(INTEL_ROOT "/opt/intel" CACHE PATH "Folder contains intel libs") +find_path(MKL_ROOT include/mkl.h PATHS $ENV{MKL_ROOT} ${INTEL_ROOT}/mkl + DOC "Folder contains MKL") + +# ---[ Find include dir +find_path(MKL_INCLUDE_DIR mkl.h PATHS ${MKL_ROOT} PATH_SUFFIXES include) +set(__looked_for MKL_INCLUDE_DIR) + +# ---[ Find libraries +if(CMAKE_SIZEOF_VOID_P EQUAL 4) + set(__path_suffixes lib lib/ia32) +else() + set(__path_suffixes lib lib/intel64) +endif() + +set(__mkl_libs "") +if(MKL_USE_SINGLE_DYNAMIC_LIBRARY) + list(APPEND __mkl_libs rt) +else() + if(CMAKE_SIZEOF_VOID_P EQUAL 4) + if(WIN32) + list(APPEND __mkl_libs intel_c) + else() + list(APPEND __mkl_libs intel gf) + endif() + else() + list(APPEND __mkl_libs intel_lp64 gf_lp64) + endif() + + if(MKL_MULTI_THREADED) + list(APPEND __mkl_libs intel_thread) + else() + list(APPEND __mkl_libs sequential) + endif() + + list(APPEND __mkl_libs core cdft_core) +endif() + + +foreach (__lib ${__mkl_libs}) + set(__mkl_lib "mkl_${__lib}") + string(TOUPPER ${__mkl_lib} __mkl_lib_upper) + + if(MKL_USE_STATIC_LIBS) + set(__mkl_lib "lib${__mkl_lib}.a") + endif() + + find_library(${__mkl_lib_upper}_LIBRARY + NAMES ${__mkl_lib} + PATHS ${MKL_ROOT} "${MKL_INCLUDE_DIR}/.." 
+ PATH_SUFFIXES ${__path_suffixes} + DOC "The path to Intel(R) MKL ${__mkl_lib} library") + mark_as_advanced(${__mkl_lib_upper}_LIBRARY) + + list(APPEND __looked_for ${__mkl_lib_upper}_LIBRARY) + list(APPEND MKL_LIBRARIES ${${__mkl_lib_upper}_LIBRARY}) +endforeach() + + +if(NOT MKL_USE_SINGLE_DYNAMIC_LIBRARY) + if (MKL_USE_STATIC_LIBS) + set(__iomp5_libs iomp5 libiomp5mt.lib) + else() + set(__iomp5_libs iomp5 libiomp5md.lib) + endif() + + if(WIN32) + find_path(INTEL_INCLUDE_DIR omp.h PATHS ${INTEL_ROOT} PATH_SUFFIXES include) + list(APPEND __looked_for INTEL_INCLUDE_DIR) + endif() + + find_library(MKL_RTL_LIBRARY ${__iomp5_libs} + PATHS ${INTEL_RTL_ROOT} ${INTEL_ROOT}/compiler ${MKL_ROOT}/.. ${MKL_ROOT}/../compiler + PATH_SUFFIXES ${__path_suffixes} + DOC "Path to OpenMP runtime library") + + list(APPEND __looked_for MKL_RTL_LIBRARY) + list(APPEND MKL_LIBRARIES ${MKL_RTL_LIBRARY}) +endif() + + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(MKL DEFAULT_MSG ${__looked_for}) + +if(MKL_FOUND) + message(STATUS "Found MKL (include: ${MKL_INCLUDE_DIR}, lib: ${MKL_LIBRARIES})") +endif() + +caffe_clear_vars(__looked_for __mkl_libs __path_suffixes __lib_suffix __iomp5_libs) diff --git a/caffe-crfrnn/cmake/Modules/FindMatlabMex.cmake b/caffe-crfrnn/cmake/Modules/FindMatlabMex.cmake new file mode 100644 index 00000000..28ae65e7 --- /dev/null +++ b/caffe-crfrnn/cmake/Modules/FindMatlabMex.cmake @@ -0,0 +1,48 @@ +# This module looks for MatlabMex compiler +# Defines variables: +# Matlab_DIR - Matlab root dir +# Matlab_mex - path to mex compiler +# Matlab_mexext - path to mexext + +if(MSVC) + foreach(__ver "9.30" "7.14" "7.11" "7.10" "7.9" "7.8" "7.7") + get_filename_component(__matlab_root "[HKEY_LOCAL_MACHINE\\SOFTWARE\\MathWorks\\MATLAB\\${__ver};MATLABROOT]" ABSOLUTE) + if(__matlab_root) + break() + endif() + endforeach() +endif() + +if(APPLE) + foreach(__ver "R2014b" "R2014a" "R2013b" "R2013a" "R2012b" "R2012a" "R2011b" "R2011a" "R2010b"
"R2010a") + if(EXISTS /Applications/MATLAB_${__ver}.app) + set(__matlab_root /Applications/MATLAB_${__ver}.app) + break() + endif() + endforeach() +endif() + +if(UNIX) + execute_process(COMMAND which matlab OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE __out RESULT_VARIABLE __res) + + if(__res MATCHES 0) # Suppress `readlink` warning if `which` returned nothing + execute_process(COMMAND which matlab COMMAND xargs readlink + COMMAND xargs dirname COMMAND xargs dirname COMMAND xargs echo -n + OUTPUT_VARIABLE __matlab_root OUTPUT_STRIP_TRAILING_WHITESPACE) + endif() +endif() + + +find_path(Matlab_DIR NAMES bin/mex bin/mexext PATHS ${__matlab_root} + DOC "Matlab directory" NO_DEFAULT_PATH) + +find_program(Matlab_mex NAMES mex mex.bat HINTS ${Matlab_DIR} PATH_SUFFIXES bin NO_DEFAULT_PATH) +find_program(Matlab_mexext NAMES mexext mexext.bat HINTS ${Matlab_DIR} PATH_SUFFIXES bin NO_DEFAULT_PATH) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(MatlabMex DEFAULT_MSG Matlab_mex Matlab_mexext) + +if(MATLABMEX_FOUND) + mark_as_advanced(Matlab_mex Matlab_mexext) +endif() diff --git a/caffe-crfrnn/cmake/Modules/FindNumPy.cmake b/caffe-crfrnn/cmake/Modules/FindNumPy.cmake new file mode 100644 index 00000000..a671494c --- /dev/null +++ b/caffe-crfrnn/cmake/Modules/FindNumPy.cmake @@ -0,0 +1,58 @@ +# - Find the NumPy libraries +# This module finds if NumPy is installed, and sets the following variables +# indicating where it is. +# +# TODO: Update to provide the libraries and paths for linking npymath lib. +# +# NUMPY_FOUND - was NumPy found +# NUMPY_VERSION - the version of NumPy found as a string +# NUMPY_VERSION_MAJOR - the major version number of NumPy +# NUMPY_VERSION_MINOR - the minor version number of NumPy +# NUMPY_VERSION_PATCH - the patch version number of NumPy +# NUMPY_VERSION_DECIMAL - e.g. 
version 1.6.1 is 10601 +# NUMPY_INCLUDE_DIR - path to the NumPy include files + +unset(NUMPY_VERSION) +unset(NUMPY_INCLUDE_DIR) + +if(PYTHONINTERP_FOUND) + execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c" + "import numpy as n; print(n.__version__); print(n.get_include());" + RESULT_VARIABLE __result + OUTPUT_VARIABLE __output + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(__result MATCHES 0) + string(REGEX REPLACE ";" "\\\\;" __values ${__output}) + string(REGEX REPLACE "\r?\n" ";" __values ${__values}) + list(GET __values 0 NUMPY_VERSION) + list(GET __values 1 NUMPY_INCLUDE_DIR) + + string(REGEX MATCH "^([0-9]+)\\.([0-9]+)\\.([0-9]+)" __ver_check "${NUMPY_VERSION}") + if(NOT "${__ver_check}" STREQUAL "") + set(NUMPY_VERSION_MAJOR ${CMAKE_MATCH_1}) + set(NUMPY_VERSION_MINOR ${CMAKE_MATCH_2}) + set(NUMPY_VERSION_PATCH ${CMAKE_MATCH_3}) + math(EXPR NUMPY_VERSION_DECIMAL + "(${NUMPY_VERSION_MAJOR} * 10000) + (${NUMPY_VERSION_MINOR} * 100) + ${NUMPY_VERSION_PATCH}") + string(REGEX REPLACE "\\\\" "/" NUMPY_INCLUDE_DIR ${NUMPY_INCLUDE_DIR}) + else() + unset(NUMPY_VERSION) + unset(NUMPY_INCLUDE_DIR) + message(STATUS "Requested NumPy version and include path, but got instead:\n${__output}\n") + endif() + endif() +else() + message(STATUS "A Python interpreter is required to find NumPy.") +endif() + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NumPy REQUIRED_VARS NUMPY_INCLUDE_DIR NUMPY_VERSION + VERSION_VAR NUMPY_VERSION) + +if(NUMPY_FOUND) + message(STATUS "NumPy ver.
${NUMPY_VERSION} found (include: ${NUMPY_INCLUDE_DIR})") +endif() + +caffe_clear_vars(__result __output __error_value __values __ver_check) + diff --git a/caffe-crfrnn/cmake/Modules/FindOpenBLAS.cmake b/caffe-crfrnn/cmake/Modules/FindOpenBLAS.cmake new file mode 100644 index 00000000..a6512ae7 --- /dev/null +++ b/caffe-crfrnn/cmake/Modules/FindOpenBLAS.cmake @@ -0,0 +1,64 @@ + + +SET(Open_BLAS_INCLUDE_SEARCH_PATHS + /usr/include + /usr/include/openblas + /usr/include/openblas-base + /usr/local/include + /usr/local/include/openblas + /usr/local/include/openblas-base + /opt/OpenBLAS/include + $ENV{OpenBLAS_HOME} + $ENV{OpenBLAS_HOME}/include +) + +SET(Open_BLAS_LIB_SEARCH_PATHS + /lib/ + /lib/openblas-base + /lib64/ + /usr/lib + /usr/lib/openblas-base + /usr/lib64 + /usr/local/lib + /usr/local/lib64 + /opt/OpenBLAS/lib + $ENV{OpenBLAS} + $ENV{OpenBLAS}/lib + $ENV{OpenBLAS_HOME} + $ENV{OpenBLAS_HOME}/lib + ) + +FIND_PATH(OpenBLAS_INCLUDE_DIR NAMES cblas.h PATHS ${Open_BLAS_INCLUDE_SEARCH_PATHS}) +FIND_LIBRARY(OpenBLAS_LIB NAMES openblas PATHS ${Open_BLAS_LIB_SEARCH_PATHS}) + +SET(OpenBLAS_FOUND ON) + +# Check include files +IF(NOT OpenBLAS_INCLUDE_DIR) + SET(OpenBLAS_FOUND OFF) + MESSAGE(STATUS "Could not find OpenBLAS include. Turning OpenBLAS_FOUND off") +ENDIF() + +# Check libraries +IF(NOT OpenBLAS_LIB) + SET(OpenBLAS_FOUND OFF) + MESSAGE(STATUS "Could not find OpenBLAS lib.
Turning OpenBLAS_FOUND off") +ENDIF() + +IF (OpenBLAS_FOUND) + IF (NOT OpenBLAS_FIND_QUIETLY) + MESSAGE(STATUS "Found OpenBLAS libraries: ${OpenBLAS_LIB}") + MESSAGE(STATUS "Found OpenBLAS include: ${OpenBLAS_INCLUDE_DIR}") + ENDIF (NOT OpenBLAS_FIND_QUIETLY) +ELSE (OpenBLAS_FOUND) + IF (OpenBLAS_FIND_REQUIRED) + MESSAGE(FATAL_ERROR "Could not find OpenBLAS") + ENDIF (OpenBLAS_FIND_REQUIRED) +ENDIF (OpenBLAS_FOUND) + +MARK_AS_ADVANCED( + OpenBLAS_INCLUDE_DIR + OpenBLAS_LIB + OpenBLAS +) + diff --git a/caffe-crfrnn/cmake/Modules/FindSnappy.cmake b/caffe-crfrnn/cmake/Modules/FindSnappy.cmake new file mode 100644 index 00000000..eff2a864 --- /dev/null +++ b/caffe-crfrnn/cmake/Modules/FindSnappy.cmake @@ -0,0 +1,28 @@ +# Find the Snappy libraries +# +# The following variables are optionally searched for defaults +# SNAPPY_ROOT_DIR: Base directory where all Snappy components are found +# +# The following are set after configuration is done: +# SNAPPY_FOUND +# Snappy_INCLUDE_DIR +# Snappy_LIBRARIES + +find_path(Snappy_INCLUDE_DIR NAMES snappy.h + PATHS ${SNAPPY_ROOT_DIR} ${SNAPPY_ROOT_DIR}/include) + +find_library(Snappy_LIBRARIES NAMES snappy + PATHS ${SNAPPY_ROOT_DIR} ${SNAPPY_ROOT_DIR}/lib) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(Snappy DEFAULT_MSG Snappy_INCLUDE_DIR Snappy_LIBRARIES) + +if(SNAPPY_FOUND) + message(STATUS "Found Snappy (include: ${Snappy_INCLUDE_DIR}, library: ${Snappy_LIBRARIES})") + mark_as_advanced(Snappy_INCLUDE_DIR Snappy_LIBRARIES) + + caffe_parse_header(${Snappy_INCLUDE_DIR}/snappy-stubs-public.h + SNAPPY_VERSION_LINES SNAPPY_MAJOR SNAPPY_MINOR SNAPPY_PATCHLEVEL) + set(Snappy_VERSION "${SNAPPY_MAJOR}.${SNAPPY_MINOR}.${SNAPPY_PATCHLEVEL}") +endif() + diff --git a/caffe-crfrnn/cmake/Modules/FindvecLib.cmake b/caffe-crfrnn/cmake/Modules/FindvecLib.cmake new file mode 100644 index 00000000..9600da43 --- /dev/null +++ b/caffe-crfrnn/cmake/Modules/FindvecLib.cmake @@ -0,0 +1,34 @@ +# Find the vecLib libraries as
part of Accelerate.framework or as standalone framework +# +# The following are set after configuration is done: +# VECLIB_FOUND +# vecLib_INCLUDE_DIR +# vecLib_LINKER_LIBS + + +if(NOT APPLE) + return() +endif() + +set(__veclib_include_suffix "Frameworks/vecLib.framework/Versions/Current/Headers") + +find_path(vecLib_INCLUDE_DIR vecLib.h + DOC "vecLib include directory" + PATHS /System/Library/${__veclib_include_suffix} + /System/Library/Frameworks/Accelerate.framework/Versions/Current/${__veclib_include_suffix} + /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX10.9.sdk/System/Library/Frameworks/Accelerate.framework/Versions/Current/Frameworks/vecLib.framework/Headers/) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(vecLib DEFAULT_MSG vecLib_INCLUDE_DIR) + +if(VECLIB_FOUND) + if(vecLib_INCLUDE_DIR MATCHES "^/System/Library/Frameworks/vecLib.framework.*") + set(vecLib_LINKER_LIBS -lcblas "-framework vecLib") + message(STATUS "Found standalone vecLib.framework") + else() + set(vecLib_LINKER_LIBS -lcblas "-framework Accelerate") + message(STATUS "Found vecLib as part of Accelerate.framework") + endif() + + mark_as_advanced(vecLib_INCLUDE_DIR) +endif() diff --git a/caffe-crfrnn/cmake/ProtoBuf.cmake b/caffe-crfrnn/cmake/ProtoBuf.cmake new file mode 100644 index 00000000..fc799bd3 --- /dev/null +++ b/caffe-crfrnn/cmake/ProtoBuf.cmake @@ -0,0 +1,90 @@ +# Finds Google Protocol Buffers library and compilers and extends +# the standard cmake script with version and python generation support + +find_package( Protobuf REQUIRED ) +include_directories(SYSTEM ${PROTOBUF_INCLUDE_DIR}) +list(APPEND Caffe_LINKER_LIBS ${PROTOBUF_LIBRARIES}) + +# As of Ubuntu 14.04 protoc is no longer a part of libprotobuf-dev package +# and should be installed separately as in: sudo apt-get install protobuf-compiler +if(EXISTS ${PROTOBUF_PROTOC_EXECUTABLE}) + message(STATUS "Found PROTOBUF Compiler: ${PROTOBUF_PROTOC_EXECUTABLE}")
+else() + message(FATAL_ERROR "Could not find PROTOBUF Compiler") +endif() + +if(PROTOBUF_FOUND) + # fetches protobuf version + caffe_parse_header(${PROTOBUF_INCLUDE_DIR}/google/protobuf/stubs/common.h VERION_LINE GOOGLE_PROTOBUF_VERSION) + string(REGEX MATCH "([0-9])00([0-9])00([0-9])" PROTOBUF_VERSION ${GOOGLE_PROTOBUF_VERSION}) + set(PROTOBUF_VERSION "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}.${CMAKE_MATCH_3}") + unset(GOOGLE_PROTOBUF_VERSION) +endif() + +# place where to generate protobuf sources +set(proto_gen_folder "${PROJECT_BINARY_DIR}/include/caffe/proto") +include_directories(SYSTEM "${PROJECT_BINARY_DIR}/include") + +set(PROTOBUF_GENERATE_CPP_APPEND_PATH TRUE) + +################################################################################################ +# Modification of standard 'protobuf_generate_cpp()' with output dir parameter and python support +# Usage: +# caffe_protobuf_generate_cpp_py( ) +function(caffe_protobuf_generate_cpp_py output_dir srcs_var hdrs_var python_var) + if(NOT ARGN) + message(SEND_ERROR "Error: caffe_protobuf_generate_cpp_py() called without any proto files") + return() + endif() + + if(PROTOBUF_GENERATE_CPP_APPEND_PATH) + # Create an include path for each file specified + foreach(fil ${ARGN}) + get_filename_component(abs_fil ${fil} ABSOLUTE) + get_filename_component(abs_path ${abs_fil} PATH) + list(FIND _protoc_include ${abs_path} _contains_already) + if(${_contains_already} EQUAL -1) + list(APPEND _protoc_include -I ${abs_path}) + endif() + endforeach() + else() + set(_protoc_include -I ${CMAKE_CURRENT_SOURCE_DIR}) + endif() + + if(DEFINED PROTOBUF_IMPORT_DIRS) + foreach(dir ${PROTOBUF_IMPORT_DIRS}) + get_filename_component(abs_path ${dir} ABSOLUTE) + list(FIND _protoc_include ${abs_path} _contains_already) + if(${_contains_already} EQUAL -1) + list(APPEND _protoc_include -I ${abs_path}) + endif() + endforeach() + endif() + + set(${srcs_var}) + set(${hdrs_var}) + set(${python_var}) + foreach(fil ${ARGN}) + 
get_filename_component(abs_fil ${fil} ABSOLUTE) + get_filename_component(fil_we ${fil} NAME_WE) + + list(APPEND ${srcs_var} "${output_dir}/${fil_we}.pb.cc") + list(APPEND ${hdrs_var} "${output_dir}/${fil_we}.pb.h") + list(APPEND ${python_var} "${output_dir}/${fil_we}_pb2.py") + + add_custom_command( + OUTPUT "${output_dir}/${fil_we}.pb.cc" + "${output_dir}/${fil_we}.pb.h" + "${output_dir}/${fil_we}_pb2.py" + COMMAND ${CMAKE_COMMAND} -E make_directory "${output_dir}" + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} --cpp_out ${output_dir} ${_protoc_include} ${abs_fil} + COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} --python_out ${output_dir} ${_protoc_include} ${abs_fil} + DEPENDS ${abs_fil} + COMMENT "Running C++/Python protocol buffer compiler on ${fil}" VERBATIM ) + endforeach() + + set_source_files_properties(${${srcs_var}} ${${hdrs_var}} ${${python_var}} PROPERTIES GENERATED TRUE) + set(${srcs_var} ${${srcs_var}} PARENT_SCOPE) + set(${hdrs_var} ${${hdrs_var}} PARENT_SCOPE) + set(${python_var} ${${python_var}} PARENT_SCOPE) +endfunction() diff --git a/caffe-crfrnn/cmake/Summary.cmake b/caffe-crfrnn/cmake/Summary.cmake new file mode 100644 index 00000000..e094ac00 --- /dev/null +++ b/caffe-crfrnn/cmake/Summary.cmake @@ -0,0 +1,168 @@ +################################################################################################ +# Caffe status report function. +# Automatically align right column and selects text based on condition. 
+# Usage: +# caffe_status() +# caffe_status( [ ...]) +# caffe_status( THEN ELSE ) +function(caffe_status text) + set(status_cond) + set(status_then) + set(status_else) + + set(status_current_name "cond") + foreach(arg ${ARGN}) + if(arg STREQUAL "THEN") + set(status_current_name "then") + elseif(arg STREQUAL "ELSE") + set(status_current_name "else") + else() + list(APPEND status_${status_current_name} ${arg}) + endif() + endforeach() + + if(DEFINED status_cond) + set(status_placeholder_length 23) + string(RANDOM LENGTH ${status_placeholder_length} ALPHABET " " status_placeholder) + string(LENGTH "${text}" status_text_length) + if(status_text_length LESS status_placeholder_length) + string(SUBSTRING "${text}${status_placeholder}" 0 ${status_placeholder_length} status_text) + elseif(DEFINED status_then OR DEFINED status_else) + message(STATUS "${text}") + set(status_text "${status_placeholder}") + else() + set(status_text "${text}") + endif() + + if(DEFINED status_then OR DEFINED status_else) + if(${status_cond}) + string(REPLACE ";" " " status_then "${status_then}") + string(REGEX REPLACE "^[ \t]+" "" status_then "${status_then}") + message(STATUS "${status_text} ${status_then}") + else() + string(REPLACE ";" " " status_else "${status_else}") + string(REGEX REPLACE "^[ \t]+" "" status_else "${status_else}") + message(STATUS "${status_text} ${status_else}") + endif() + else() + string(REPLACE ";" " " status_cond "${status_cond}") + string(REGEX REPLACE "^[ \t]+" "" status_cond "${status_cond}") + message(STATUS "${status_text} ${status_cond}") + endif() + else() + message(STATUS "${text}") + endif() +endfunction() + + +################################################################################################ +# Function for fetching Caffe version from git and headers +# Usage: +# caffe_extract_caffe_version() +function(caffe_extract_caffe_version) + set(Caffe_GIT_VERSION "unknown") + find_package(Git) + if(GIT_FOUND) + execute_process(COMMAND ${GIT_EXECUTABLE} 
describe --tags --always --dirty + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE + WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" + OUTPUT_VARIABLE Caffe_GIT_VERSION + RESULT_VARIABLE __git_result) + if(NOT ${__git_result} EQUAL 0) + set(Caffe_GIT_VERSION "unknown") + endif() + endif() + + set(Caffe_GIT_VERSION ${Caffe_GIT_VERSION} PARENT_SCOPE) + set(Caffe_VERSION " (Caffe doesn't declare its version in headers)" PARENT_SCOPE) + + # caffe_parse_header(${Caffe_INCLUDE_DIR}/caffe/version.hpp Caffe_VERSION_LINES CAFFE_MAJOR CAFFE_MINOR CAFFE_PATCH) + # set(Caffe_VERSION "${CAFFE_MAJOR}.${CAFFE_MINOR}.${CAFFE_PATCH}" PARENT_SCOPE) + + # or for #define Caffe_VERSION "x.x.x" + # caffe_parse_header_single_define(Caffe ${Caffe_INCLUDE_DIR}/caffe/version.hpp Caffe_VERSION) + # set(Caffe_VERSION ${Caffe_VERSION_STRING} PARENT_SCOPE) + +endfunction() + + +################################################################################################ +# Prints accumulated caffe configuration summary +# Usage: +# caffe_print_configuration_summary() + +function(caffe_print_configuration_summary) + caffe_extract_caffe_version() + set(Caffe_VERSION ${Caffe_VERSION} PARENT_SCOPE) + + caffe_merge_flag_lists(__flags_rel CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS) + caffe_merge_flag_lists(__flags_deb CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS) + + caffe_status("") + caffe_status("******************* Caffe Configuration Summary *******************") + caffe_status("General:") + caffe_status(" Version : ${Caffe_VERSION}") + caffe_status(" Git : ${Caffe_GIT_VERSION}") + caffe_status(" System : ${CMAKE_SYSTEM_NAME}") + caffe_status(" C++ compiler : ${CMAKE_CXX_COMPILER}") + caffe_status(" Release CXX flags : ${__flags_rel}") + caffe_status(" Debug CXX flags : ${__flags_deb}") + caffe_status(" Build type : ${CMAKE_BUILD_TYPE}") + caffe_status("") + caffe_status(" BUILD_SHARED_LIBS : ${BUILD_SHARED_LIBS}") + caffe_status(" BUILD_python : ${BUILD_python}") + caffe_status(" BUILD_matlab : 
${BUILD_matlab}") + caffe_status(" BUILD_docs : ${BUILD_docs}") + caffe_status(" CPU_ONLY : ${CPU_ONLY}") + caffe_status("") + caffe_status("Dependencies:") + caffe_status(" BLAS : " APPLE THEN "Yes (vecLib)" ELSE "Yes (${BLAS})") + caffe_status(" Boost : Yes (ver. ${Boost_MAJOR_VERSION}.${Boost_MINOR_VERSION})") + caffe_status(" glog : Yes") + caffe_status(" gflags : Yes") + caffe_status(" protobuf : " PROTOBUF_FOUND THEN "Yes (ver. ${PROTOBUF_VERSION})" ELSE "No" ) + caffe_status(" lmdb : " LMDB_FOUND THEN "Yes (ver. ${LMDB_VERSION})" ELSE "No") + caffe_status(" Snappy : " SNAPPY_FOUND THEN "Yes (ver. ${Snappy_VERSION})" ELSE "No" ) + caffe_status(" LevelDB : " LEVELDB_FOUND THEN "Yes (ver. ${LEVELDB_VERSION})" ELSE "No") + caffe_status(" OpenCV : Yes (ver. ${OpenCV_VERSION})") + caffe_status(" CUDA : " HAVE_CUDA THEN "Yes (ver. ${CUDA_VERSION})" ELSE "No" ) + caffe_status("") + if(HAVE_CUDA) + caffe_status("NVIDIA CUDA:") + caffe_status(" Target GPU(s) : ${CUDA_ARCH_NAME}" ) + caffe_status(" GPU arch(s) : ${NVCC_FLAGS_EXTRA_readable}") + if(USE_CUDNN) + caffe_status(" cuDNN : " HAVE_CUDNN THEN "Yes" ELSE "Not found") + else() + caffe_status(" cuDNN : Disabled") + endif() + caffe_status("") + endif() + if(HAVE_PYTHON) + caffe_status("Python:") + caffe_status(" Interpreter :" PYTHON_EXECUTABLE THEN "${PYTHON_EXECUTABLE} (ver. 
${PYTHON_VERSION_STRING})" ELSE "No") + caffe_status(" Libraries :" PYTHONLIBS_FOUND THEN "${PYTHON_LIBRARIES} (ver ${PYTHONLIBS_VERSION_STRING})" ELSE "No") + caffe_status(" NumPy :" NUMPY_FOUND THEN "${NUMPY_INCLUDE_DIR} (ver ${NUMPY_VERSION})" ELSE "No") + caffe_status("") + endif() + if(BUILD_matlab) + caffe_status("Matlab:") + caffe_status(" Matlab :" HAVE_MATLAB THEN "Yes (${Matlab_mex}, ${Matlab_mexext})" ELSE "No") + caffe_status(" Octave :" Octave_compiler THEN "Yes (${Octave_compiler})" ELSE "No") + if(HAVE_MATLAB AND Octave_compiler) + caffe_status(" Build mex using : ${Matlab_build_mex_using}") + endif() + caffe_status("") + endif() + if(BUILD_docs) + caffe_status("Documentation:") + caffe_status(" Doxygen :" DOXYGEN_FOUND THEN "${DOXYGEN_EXECUTABLE} (${DOXYGEN_VERSION})" ELSE "No") + caffe_status(" config_file : ${DOXYGEN_config_file}") + + caffe_status("") + endif() + caffe_status("Install:") + caffe_status(" Install path : ${CMAKE_INSTALL_PREFIX}") + caffe_status("") +endfunction() + diff --git a/caffe-crfrnn/cmake/Targets.cmake b/caffe-crfrnn/cmake/Targets.cmake new file mode 100644 index 00000000..4fc9456e --- /dev/null +++ b/caffe-crfrnn/cmake/Targets.cmake @@ -0,0 +1,173 @@ +################################################################################################ +# Defines the global Caffe_LINK flag. This flag is required to prevent the linker from excluding +# some objects which are not addressed directly but are registered via static constructors +if(BUILD_SHARED_LIBS) + set(Caffe_LINK caffe) +else() + if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + set(Caffe_LINK -Wl,-force_load caffe) + elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + set(Caffe_LINK -Wl,--whole-archive caffe -Wl,--no-whole-archive) + endif() +endif() + +################################################################################################ +# Convenient command to set up source groups for IDEs that support this feature (VS, Xcode) +# Usage: +# caffe_source_group( 
GLOB[_RECURSE] ) +function(caffe_source_group group) + cmake_parse_arguments(CAFFE_SOURCE_GROUP "" "" "GLOB;GLOB_RECURSE" ${ARGN}) + if(CAFFE_SOURCE_GROUP_GLOB) + file(GLOB srcs1 ${CAFFE_SOURCE_GROUP_GLOB}) + source_group(${group} FILES ${srcs1}) + endif() + + if(CAFFE_SOURCE_GROUP_GLOB_RECURSE) + file(GLOB_RECURSE srcs2 ${CAFFE_SOURCE_GROUP_GLOB_RECURSE}) + source_group(${group} FILES ${srcs2}) + endif() +endfunction() + +################################################################################################ +# Collecting sources from globbing and appending to output list variable +# Usage: +# caffe_collect_sources( GLOB[_RECURSE] ) +function(caffe_collect_sources variable) + cmake_parse_arguments(CAFFE_COLLECT_SOURCES "" "" "GLOB;GLOB_RECURSE" ${ARGN}) + if(CAFFE_COLLECT_SOURCES_GLOB) + file(GLOB srcs1 ${CAFFE_COLLECT_SOURCES_GLOB}) + set(${variable} ${variable} ${srcs1}) + endif() + + if(CAFFE_COLLECT_SOURCES_GLOB_RECURSE) + file(GLOB_RECURSE srcs2 ${CAFFE_COLLECT_SOURCES_GLOB_RECURSE}) + set(${variable} ${variable} ${srcs2}) + endif() +endfunction() + +################################################################################################ +# Short command getting caffe sources (assuming standard Caffe code tree) +# Usage: +# caffe_pickup_caffe_sources() +function(caffe_pickup_caffe_sources root) + # put all files in source groups (visible as subfolder in many IDEs) + caffe_source_group("Include" GLOB "${root}/include/caffe/*.h*") + caffe_source_group("Include\\Util" GLOB "${root}/include/caffe/util/*.h*") + caffe_source_group("Include" GLOB "${PROJECT_BINARY_DIR}/caffe_config.h*") + caffe_source_group("Source" GLOB "${root}/src/caffe/*.cpp") + caffe_source_group("Source\\Util" GLOB "${root}/src/caffe/util/*.cpp") + caffe_source_group("Source\\Layers" GLOB "${root}/src/caffe/layers/*.cpp") + caffe_source_group("Source\\Cuda" GLOB "${root}/src/caffe/layers/*.cu") + caffe_source_group("Source\\Cuda" GLOB "${root}/src/caffe/util/*.cu") + 
caffe_source_group("Source\\Proto" GLOB "${root}/src/caffe/proto/*.proto") + + # source groups for test target + caffe_source_group("Include" GLOB "${root}/include/caffe/test/test_*.h*") + caffe_source_group("Source" GLOB "${root}/src/caffe/test/test_*.cpp") + caffe_source_group("Source\\Cuda" GLOB "${root}/src/caffe/test/test_*.cu") + + # collect files + file(GLOB test_hdrs ${root}/include/caffe/test/test_*.h*) + file(GLOB test_srcs ${root}/src/caffe/test/test_*.cpp) + file(GLOB_RECURSE hdrs ${root}/include/caffe/*.h*) + file(GLOB_RECURSE srcs ${root}/src/caffe/*.cpp) + list(REMOVE_ITEM hdrs ${test_hdrs}) + list(REMOVE_ITEM srcs ${test_srcs}) + + # add headers to make them visible in some IDEs (Qt, VS, Xcode) + list(APPEND srcs ${hdrs} ${PROJECT_BINARY_DIR}/caffe_config.h) + list(APPEND test_srcs ${test_hdrs}) + + # collect cuda files (test sources are kept in their own list, as for hdrs/srcs above) + file(GLOB test_cuda ${root}/src/caffe/test/test_*.cu) + file(GLOB_RECURSE cuda ${root}/src/caffe/*.cu) + list(REMOVE_ITEM cuda ${test_cuda}) + + # add proto to make them editable in IDEs too + file(GLOB_RECURSE proto_files ${root}/src/caffe/*.proto) + list(APPEND srcs ${proto_files}) + + # convert to absolute paths + caffe_convert_absolute_paths(srcs) + caffe_convert_absolute_paths(cuda) + caffe_convert_absolute_paths(test_srcs) + caffe_convert_absolute_paths(test_cuda) + + # propagate to parent scope + set(srcs ${srcs} PARENT_SCOPE) + set(cuda ${cuda} PARENT_SCOPE) + set(test_srcs ${test_srcs} PARENT_SCOPE) + set(test_cuda ${test_cuda} PARENT_SCOPE) +endfunction() + +################################################################################################ +# Short command for setting default target properties +# Usage: +# caffe_default_properties() +function(caffe_default_properties target) + set_target_properties(${target} PROPERTIES + DEBUG_POSTFIX ${Caffe_DEBUG_POSTFIX} + ARCHIVE_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}/lib" + LIBRARY_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY 
"${PROJECT_BINARY_DIR}/bin") + # make sure we build all external depepdencies first + if (DEFINED external_project_dependencies) + add_dependencies(${target} ${external_project_dependencies}) + endif() +endfunction() + +################################################################################################ +# Short command for setting runtime directory for build target +# Usage: +# caffe_set_runtime_directory( ) +function(caffe_set_runtime_directory target dir) + set_target_properties(${target} PROPERTIES + RUNTIME_OUTPUT_DIRECTORY "${dir}") +endfunction() + +################################################################################################ +# Short command for setting solution folder property for target +# Usage: +# caffe_set_solution_folder( ) +function(caffe_set_solution_folder target folder) + if(USE_PROJECT_FOLDERS) + set_target_properties(${target} PROPERTIES FOLDER "${folder}") + endif() +endfunction() + +################################################################################################ +# Reads lines from input file, prepends source directory to each line and writes to output file +# Usage: +# caffe_configure_testdatafile() +function(caffe_configure_testdatafile file) + file(STRINGS ${file} __lines) + set(result "") + foreach(line ${__lines}) + set(result "${result}${PROJECT_SOURCE_DIR}/${line}\n") + endforeach() + file(WRITE ${file}.gen.cmake ${result}) +endfunction() + +################################################################################################ +# Filter out all files that are not included in selected list +# Usage: +# caffe_leave_only_selected_tests( ) +function(caffe_leave_only_selected_tests file_list) + if(NOT ARGN) + return() # blank list means leave all + endif() + string(REPLACE "," ";" __selected ${ARGN}) + list(APPEND __selected caffe_main) + + set(result "") + foreach(f ${${file_list}}) + get_filename_component(name ${f} NAME_WE) + string(REGEX REPLACE "^test_" "" name ${name}) + 
list(FIND __selected ${name} __index) + if(NOT __index EQUAL -1) + list(APPEND result ${f}) + endif() + endforeach() + set(${file_list} ${result} PARENT_SCOPE) +endfunction() + diff --git a/caffe-crfrnn/cmake/Templates/CaffeConfig.cmake.in b/caffe-crfrnn/cmake/Templates/CaffeConfig.cmake.in new file mode 100644 index 00000000..8f23742e --- /dev/null +++ b/caffe-crfrnn/cmake/Templates/CaffeConfig.cmake.in @@ -0,0 +1,58 @@ +# Config file for the Caffe package. +# +# Note: +# Caffe and this config file depends on opencv, +# so put `find_package(OpenCV)` before searching Caffe +# via `find_package(Caffe)`. All other lib/includes +# dependencies are hard coded in the file +# +# After successful configuration the following variables +# will be defined: +# +# Caffe_INCLUDE_DIRS - Caffe include directories +# Caffe_LIBRARIES - libraries to link against +# Caffe_DEFINITIONS - a list of definitions to pass to compiler +# +# Caffe_HAVE_CUDA - signals about CUDA support +# Caffe_HAVE_CUDNN - signals about cuDNN support + + +# OpenCV dependency + +if(NOT OpenCV_FOUND) + set(Caffe_OpenCV_CONFIG_PATH "@OpenCV_CONFIG_PATH@") + if(Caffe_OpenCV_CONFIG_PATH) + get_filename_component(Caffe_OpenCV_CONFIG_PATH ${Caffe_OpenCV_CONFIG_PATH} ABSOLUTE) + + if(EXISTS ${Caffe_OpenCV_CONFIG_PATH} AND NOT TARGET opencv_core) + message(STATUS "Caffe: using OpenCV config from ${Caffe_OpenCV_CONFIG_PATH}") + include(${Caffe_OpenCV_CONFIG_PATH}/OpenCVModules.cmake) + endif() + + else() + find_package(OpenCV REQUIRED) + endif() + unset(Caffe_OpenCV_CONFIG_PATH) +endif() + +# Compute paths +get_filename_component(Caffe_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) +set(Caffe_INCLUDE_DIRS "@Caffe_INCLUDE_DIRS@") + +@Caffe_INSTALL_INCLUDE_DIR_APPEND_COMMAND@ + +# Our library dependencies +if(NOT TARGET caffe AND NOT caffe_BINARY_DIR) + include("${Caffe_CMAKE_DIR}/CaffeTargets.cmake") +endif() + +# List of IMPORTED libs created by CaffeTargets.cmake +set(Caffe_LIBRARIES caffe) + +# Definitions 
+set(Caffe_DEFINITIONS "@Caffe_DEFINITIONS@") + +# Cuda support variables +set(Caffe_CPU_ONLY @CPU_ONLY@) +set(Caffe_HAVE_CUDA @HAVE_CUDA@) +set(Caffe_HAVE_CUDNN @HAVE_CUDNN@) diff --git a/caffe-crfrnn/cmake/Templates/CaffeConfigVersion.cmake.in b/caffe-crfrnn/cmake/Templates/CaffeConfigVersion.cmake.in new file mode 100644 index 00000000..19f85309 --- /dev/null +++ b/caffe-crfrnn/cmake/Templates/CaffeConfigVersion.cmake.in @@ -0,0 +1,11 @@ +set(PACKAGE_VERSION "@Caffe_VERSION@") + +# Check whether the requested PACKAGE_FIND_VERSION is compatible +if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_COMPATIBLE FALSE) +else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + if ("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_EXACT TRUE) + endif() +endif() diff --git a/caffe-crfrnn/cmake/Templates/caffe_config.h.in b/caffe-crfrnn/cmake/Templates/caffe_config.h.in new file mode 100644 index 00000000..6039e8f6 --- /dev/null +++ b/caffe-crfrnn/cmake/Templates/caffe_config.h.in @@ -0,0 +1,32 @@ +/* Sources directory */ +#define SOURCE_FOLDER "${PROJECT_SOURCE_DIR}" + +/* Binaries directory */ +#define BINARY_FOLDER "${PROJECT_BINARY_DIR}" + +/* NVIDIA CUDA */ +#cmakedefine HAVE_CUDA + +/* NVIDIA cuDNN */ +#cmakedefine HAVE_CUDNN +#cmakedefine USE_CUDNN + +/* CPU-only build */ +#cmakedefine CPU_ONLY + +/* Test device */ +#define CUDA_TEST_DEVICE ${CUDA_TEST_DEVICE} + +/* Temporary (TODO: remove) */ +#if 1 + #define CMAKE_SOURCE_DIR SOURCE_FOLDER "/src/" + #define EXAMPLES_SOURCE_DIR BINARY_FOLDER "/examples/" + #define CMAKE_EXT ".gen.cmake" +#else + #define CMAKE_SOURCE_DIR "src/" + #define EXAMPLES_SOURCE_DIR "examples/" + #define CMAKE_EXT "" +#endif + +/* Matlab */ +#cmakedefine HAVE_MATLAB diff --git a/caffe-crfrnn/cmake/Utils.cmake b/caffe-crfrnn/cmake/Utils.cmake new file mode 100644 index 00000000..a1bde1ae --- /dev/null +++ b/caffe-crfrnn/cmake/Utils.cmake @@ -0,0 +1,381 @@ 
+################################################################################################ +# Command alias for debugging messages +# Usage: +# dmsg() +function(dmsg) + message(STATUS ${ARGN}) +endfunction() + +################################################################################################ +# Removes duplicates from list(s) +# Usage: +# caffe_list_unique( [] [...]) +macro(caffe_list_unique) + foreach(__lst ${ARGN}) + if(${__lst}) + list(REMOVE_DUPLICATES ${__lst}) + endif() + endforeach() +endmacro() + +################################################################################################ +# Clears variables from list +# Usage: +# caffe_clear_vars() +macro(caffe_clear_vars) + foreach(_var ${ARGN}) + unset(${_var}) + endforeach() +endmacro() + +################################################################################################ +# Removes duplicates from string +# Usage: +# caffe_string_unique() +function(caffe_string_unique __string) + if(${__string}) + set(__list ${${__string}}) + separate_arguments(__list) + list(REMOVE_DUPLICATES __list) + foreach(__e ${__list}) + set(__str "${__str} ${__e}") + endforeach() + set(${__string} ${__str} PARENT_SCOPE) + endif() +endfunction() + +################################################################################################ +# Prints list element per line +# Usage: +# caffe_print_list() +function(caffe_print_list) + foreach(e ${ARGN}) + message(STATUS ${e}) + endforeach() +endfunction() + +################################################################################################ +# Function merging lists of compiler flags to single string. +# Usage: +# caffe_merge_flag_lists(out_variable [] [] ...) 
+function(caffe_merge_flag_lists out_var) + set(__result "") + foreach(__list ${ARGN}) + foreach(__flag ${${__list}}) + string(STRIP ${__flag} __flag) + set(__result "${__result} ${__flag}") + endforeach() + endforeach() + string(STRIP ${__result} __result) + set(${out_var} ${__result} PARENT_SCOPE) +endfunction() + +################################################################################################ +# Converts all paths in list to absolute +# Usage: +# caffe_convert_absolute_paths() +function(caffe_convert_absolute_paths variable) + set(__list "") + foreach(__s ${${variable}}) + get_filename_component(__abspath ${__s} ABSOLUTE) + list(APPEND __list ${__abspath}) + endforeach() + set(${variable} ${__list} PARENT_SCOPE) +endfunction() + +################################################################################################ +# Reads set of version defines from the header file +# Usage: +# caffe_parse_header( ..) +macro(caffe_parse_header FILENAME FILE_VAR) + set(vars_regex "") + set(__parent_scope OFF) + set(__add_cache OFF) + foreach(name ${ARGN}) + if("${name}" STREQUAL "PARENT_SCOPE") + set(__parent_scope ON) + elseif("${name}" STREQUAL "CACHE") + set(__add_cache ON) + elseif(vars_regex) + set(vars_regex "${vars_regex}|${name}") + else() + set(vars_regex "${name}") + endif() + endforeach() + if(EXISTS "${FILENAME}") + file(STRINGS "${FILENAME}" ${FILE_VAR} REGEX "#define[ \t]+(${vars_regex})[ \t]+[0-9]+" ) + else() + unset(${FILE_VAR}) + endif() + foreach(name ${ARGN}) + if(NOT "${name}" STREQUAL "PARENT_SCOPE" AND NOT "${name}" STREQUAL "CACHE") + if(${FILE_VAR}) + if(${FILE_VAR} MATCHES ".+[ \t]${name}[ \t]+([0-9]+).*") + string(REGEX REPLACE ".+[ \t]${name}[ \t]+([0-9]+).*" "\\1" ${name} "${${FILE_VAR}}") + else() + set(${name} "") + endif() + if(__add_cache) + set(${name} ${${name}} CACHE INTERNAL "${name} parsed from ${FILENAME}" FORCE) + elseif(__parent_scope) + set(${name} "${${name}}" PARENT_SCOPE) + endif() + else() + unset(${name} 
CACHE) + endif() + endif() + endforeach() +endmacro() + +################################################################################################ +# Reads single version define from the header file and parses it +# Usage: +# caffe_parse_header_single_define( ) +function(caffe_parse_header_single_define LIBNAME HDR_PATH VARNAME) + set(${LIBNAME}_H "") + if(EXISTS "${HDR_PATH}") + file(STRINGS "${HDR_PATH}" ${LIBNAME}_H REGEX "^#define[ \t]+${VARNAME}[ \t]+\"[^\"]*\".*$" LIMIT_COUNT 1) + endif() + + if(${LIBNAME}_H) + string(REGEX REPLACE "^.*[ \t]${VARNAME}[ \t]+\"([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MAJOR "${${LIBNAME}_H}") + string(REGEX REPLACE "^.*[ \t]${VARNAME}[ \t]+\"[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_MINOR "${${LIBNAME}_H}") + string(REGEX REPLACE "^.*[ \t]${VARNAME}[ \t]+\"[0-9]+\\.[0-9]+\\.([0-9]+).*$" "\\1" ${LIBNAME}_VERSION_PATCH "${${LIBNAME}_H}") + set(${LIBNAME}_VERSION_MAJOR ${${LIBNAME}_VERSION_MAJOR} ${ARGN} PARENT_SCOPE) + set(${LIBNAME}_VERSION_MINOR ${${LIBNAME}_VERSION_MINOR} ${ARGN} PARENT_SCOPE) + set(${LIBNAME}_VERSION_PATCH ${${LIBNAME}_VERSION_PATCH} ${ARGN} PARENT_SCOPE) + set(${LIBNAME}_VERSION_STRING "${${LIBNAME}_VERSION_MAJOR}.${${LIBNAME}_VERSION_MINOR}.${${LIBNAME}_VERSION_PATCH}" PARENT_SCOPE) + + # append a TWEAK version if it exists: + set(${LIBNAME}_VERSION_TWEAK "") + if("${${LIBNAME}_H}" MATCHES "^.*[ \t]${VARNAME}[ \t]+\"[0-9]+\\.[0-9]+\\.[0-9]+\\.([0-9]+).*$") + set(${LIBNAME}_VERSION_TWEAK "${CMAKE_MATCH_1}" ${ARGN} PARENT_SCOPE) + endif() + if(${LIBNAME}_VERSION_TWEAK) + set(${LIBNAME}_VERSION_STRING "${${LIBNAME}_VERSION_STRING}.${${LIBNAME}_VERSION_TWEAK}" ${ARGN} PARENT_SCOPE) + else() + set(${LIBNAME}_VERSION_STRING "${${LIBNAME}_VERSION_STRING}" ${ARGN} PARENT_SCOPE) + endif() + endif() +endfunction() + +######################################################################################################## +# An option that the user can select. 
Can accept condition to control when option is available for user. +# Usage: +# caffe_option( "doc string" [IF ]) +function(caffe_option variable description value) + set(__value ${value}) + set(__condition "") + set(__varname "__value") + foreach(arg ${ARGN}) + if(arg STREQUAL "IF" OR arg STREQUAL "if") + set(__varname "__condition") + else() + list(APPEND ${__varname} ${arg}) + endif() + endforeach() + unset(__varname) + if("${__condition}" STREQUAL "") + set(__condition 2 GREATER 1) + endif() + + if(${__condition}) + if("${__value}" MATCHES ";") + if(${__value}) + option(${variable} "${description}" ON) + else() + option(${variable} "${description}" OFF) + endif() + elseif(DEFINED ${__value}) + if(${__value}) + option(${variable} "${description}" ON) + else() + option(${variable} "${description}" OFF) + endif() + else() + option(${variable} "${description}" ${__value}) + endif() + else() + unset(${variable} CACHE) + endif() +endfunction() + +################################################################################################ +# Utility macro for comparing two lists. Used for CMake debugging purposes +# Usage: +# caffe_compare_lists( [description]) +function(caffe_compare_lists list1 list2 desc) + set(__list1 ${${list1}}) + set(__list2 ${${list2}}) + list(SORT __list1) + list(SORT __list2) + list(LENGTH __list1 __len1) + list(LENGTH __list2 __len2) + + if(NOT ${__len1} EQUAL ${__len2}) + message(FATAL_ERROR "Lists are not equal. ${__len1} != ${__len2}. ${desc}") + endif() + + foreach(__i RANGE 1 ${__len1}) + math(EXPR __index "${__i}- 1") + list(GET __list1 ${__index} __item1) + list(GET __list2 ${__index} __item2) + if(NOT ${__item1} STREQUAL ${__item2}) + message(FATAL_ERROR "Lists are not equal. Differ at element ${__index}. 
${desc}") + endif() + endforeach() +endfunction() + +################################################################################################ +# Command for disabling warnings for different platforms (see below for gcc and VisualStudio) +# Usage: +# caffe_warnings_disable( -Wshadow /wd4996 ..,) +macro(caffe_warnings_disable) + set(_flag_vars "") + set(_msvc_warnings "") + set(_gxx_warnings "") + + foreach(arg ${ARGN}) + if(arg MATCHES "^CMAKE_") + list(APPEND _flag_vars ${arg}) + elseif(arg MATCHES "^/wd") + list(APPEND _msvc_warnings ${arg}) + elseif(arg MATCHES "^-W") + list(APPEND _gxx_warnings ${arg}) + endif() + endforeach() + + if(NOT _flag_vars) + set(_flag_vars CMAKE_C_FLAGS CMAKE_CXX_FLAGS) + endif() + + if(MSVC AND _msvc_warnings) + foreach(var ${_flag_vars}) + foreach(warning ${_msvc_warnings}) + set(${var} "${${var}} ${warning}") + endforeach() + endforeach() + elseif((CMAKE_COMPILER_IS_GNUCXX OR CMAKE_COMPILER_IS_CLANGXX) AND _gxx_warnings) + foreach(var ${_flag_vars}) + foreach(warning ${_gxx_warnings}) + if(NOT warning MATCHES "^-Wno-") + string(REPLACE "${warning}" "" ${var} "${${var}}") + string(REPLACE "-W" "-Wno-" warning "${warning}") + endif() + set(${var} "${${var}} ${warning}") + endforeach() + endforeach() + endif() + caffe_clear_vars(_flag_vars _msvc_warnings _gxx_warnings) +endmacro() + +################################################################################################ +# Helper function get current definitions +# Usage: +# caffe_get_current_definitions() +function(caffe_get_current_definitions definitions_var) + get_property(current_definitions DIRECTORY PROPERTY COMPILE_DEFINITIONS) + set(result "") + + foreach(d ${current_definitions}) + list(APPEND result -D${d}) + endforeach() + + caffe_list_unique(result) + set(${definitions_var} ${result} PARENT_SCOPE) +endfunction() + +################################################################################################ +# Helper function get current 
includes/definitions +# Usage: +# caffe_get_current_cflags() +function(caffe_get_current_cflags cflags_var) + get_property(current_includes DIRECTORY PROPERTY INCLUDE_DIRECTORIES) + caffe_convert_absolute_paths(current_includes) + caffe_get_current_definitions(cflags) + + foreach(i ${current_includes}) + list(APPEND cflags "-I${i}") + endforeach() + + caffe_list_unique(cflags) + set(${cflags_var} ${cflags} PARENT_SCOPE) +endfunction() + +################################################################################################ +# Helper function to parse current linker libs into link directories, libflags and osx frameworks +# Usage: +# caffe_parse_linker_libs( ) +function(caffe_parse_linker_libs Caffe_LINKER_LIBS_variable folders_var flags_var frameworks_var) + + set(__unspec "") + set(__debug "") + set(__optimized "") + set(__framework "") + set(__varname "__unspec") + + # split libs into debug, optimized, unspecified and frameworks + foreach(list_elem ${${Caffe_LINKER_LIBS_variable}}) + if(list_elem STREQUAL "debug") + set(__varname "__debug") + elseif(list_elem STREQUAL "optimized") + set(__varname "__optimized") + elseif(list_elem MATCHES "^-framework[ \t]+([^ \t].*)") + list(APPEND __framework -framework ${CMAKE_MATCH_1}) + else() + list(APPEND ${__varname} ${list_elem}) + set(__varname "__unspec") + endif() + endforeach() + + # attach debug or optimized libs to unspecified according to current configuration + if(CMAKE_BUILD_TYPE MATCHES "Debug") + set(__libs ${__unspec} ${__debug}) + else() + set(__libs ${__unspec} ${__optimized}) + endif() + + set(libflags "") + set(folders "") + + # convert linker libraries list to link flags + foreach(lib ${__libs}) + if(TARGET ${lib}) + list(APPEND folders $) + list(APPEND libflags -l${lib}) + elseif(lib MATCHES "^-l.*") + list(APPEND libflags ${lib}) + elseif(IS_ABSOLUTE ${lib}) + get_filename_component(name_we ${lib} NAME_WE) + get_filename_component(folder ${lib} PATH) + + string(REGEX MATCH "^lib(.*)" __match 
${name_we}) + list(APPEND libflags -l${CMAKE_MATCH_1}) + list(APPEND folders ${folder}) + else() + message(FATAL_ERROR "Logic error. Need to update cmake script") + endif() + endforeach() + + caffe_list_unique(libflags folders) + + set(${folders_var} ${folders} PARENT_SCOPE) + set(${flags_var} ${libflags} PARENT_SCOPE) + set(${frameworks_var} ${__framework} PARENT_SCOPE) +endfunction() + +################################################################################################ +# Helper function to detect Darwin version, i.e. 10.8, 10.9, 10.10, .... +# Usage: +# caffe_detect_darwin_version() +function(caffe_detect_darwin_version output_var) + if(APPLE) + execute_process(COMMAND /usr/bin/sw_vers -productVersion + RESULT_VARIABLE __sw_vers OUTPUT_VARIABLE __sw_vers_out + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + + set(${output_var} ${__sw_vers_out} PARENT_SCOPE) + else() + set(${output_var} "" PARENT_SCOPE) + endif() +endfunction() diff --git a/caffe-crfrnn/cmake/lint.cmake b/caffe-crfrnn/cmake/lint.cmake new file mode 100644 index 00000000..70a00657 --- /dev/null +++ b/caffe-crfrnn/cmake/lint.cmake @@ -0,0 +1,50 @@ + +set(CMAKE_SOURCE_DIR ..) 
+set(LINT_COMMAND ${CMAKE_SOURCE_DIR}/scripts/cpp_lint.py) +set(SRC_FILE_EXTENSIONS h hpp hu c cpp cu cc) +set(EXCLUDE_FILE_EXTENSTIONS pb.h pb.cc) +set(LINT_DIRS include src/caffe examples tools python matlab) + +cmake_policy(SET CMP0009 NEW) # suppress cmake warning + +# find all files of interest +foreach(ext ${SRC_FILE_EXTENSIONS}) + foreach(dir ${LINT_DIRS}) + file(GLOB_RECURSE FOUND_FILES ${CMAKE_SOURCE_DIR}/${dir}/*.${ext}) + set(LINT_SOURCES ${LINT_SOURCES} ${FOUND_FILES}) + endforeach() +endforeach() + +# find all files that should be excluded +foreach(ext ${EXCLUDE_FILE_EXTENSTIONS}) + file(GLOB_RECURSE FOUND_FILES ${CMAKE_SOURCE_DIR}/*.${ext}) + set(EXCLUDED_FILES ${EXCLUDED_FILES} ${FOUND_FILES}) +endforeach() + +# exclude generated pb files +list(REMOVE_ITEM LINT_SOURCES ${EXCLUDED_FILES}) + +execute_process( + COMMAND ${LINT_COMMAND} ${LINT_SOURCES} + ERROR_VARIABLE LINT_OUTPUT + ERROR_STRIP_TRAILING_WHITESPACE +) + +string(REPLACE "\n" ";" LINT_OUTPUT ${LINT_OUTPUT}) + +list(GET LINT_OUTPUT -1 LINT_RESULT) +list(REMOVE_AT LINT_OUTPUT -1) +string(REPLACE " " ";" LINT_RESULT ${LINT_RESULT}) +list(GET LINT_RESULT -1 NUM_ERRORS) +if(NUM_ERRORS GREATER 0) + foreach(msg ${LINT_OUTPUT}) + string(FIND ${msg} "Done" result) + if(result LESS 0) + message(STATUS ${msg}) + endif() + endforeach() + message(FATAL_ERROR "Lint found ${NUM_ERRORS} errors!") +else() + message(STATUS "Lint did not find any errors!") +endif() + + diff --git a/caffe-crfrnn/data/cifar10/get_cifar10.sh b/caffe-crfrnn/data/cifar10/get_cifar10.sh new file mode 100755 index 00000000..623c8485 --- /dev/null +++ b/caffe-crfrnn/data/cifar10/get_cifar10.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env sh +# This script downloads the CIFAR10 (binary version) data and unzips it. + +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +cd "$DIR" + +echo "Downloading..." + +wget --no-check-certificate http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz + +echo "Unzipping..." 
+ +tar -xf cifar-10-binary.tar.gz && rm -f cifar-10-binary.tar.gz +mv cifar-10-batches-bin/* . && rm -rf cifar-10-batches-bin + +# Creation is split out because leveldb sometimes causes segfault +# and needs to be re-created. + +echo "Done." diff --git a/caffe-crfrnn/data/ilsvrc12/get_ilsvrc_aux.sh b/caffe-crfrnn/data/ilsvrc12/get_ilsvrc_aux.sh new file mode 100755 index 00000000..b9b85d21 --- /dev/null +++ b/caffe-crfrnn/data/ilsvrc12/get_ilsvrc_aux.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env sh +# +# N.B. This does not download the ilsvrc12 data set, as it is gargantuan. +# This script downloads the imagenet example auxiliary files including: +# - the ilsvrc12 image mean, binaryproto +# - synset ids and words +# - Python pickle-format data of ImageNet graph structure and relative infogain +# - the training splits with labels + +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +cd "$DIR" + +echo "Downloading..." + +wget http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz + +echo "Unzipping..." + +tar -xf caffe_ilsvrc12.tar.gz && rm -f caffe_ilsvrc12.tar.gz + +echo "Done." diff --git a/caffe-crfrnn/data/mnist/get_mnist.sh b/caffe-crfrnn/data/mnist/get_mnist.sh new file mode 100755 index 00000000..8eb6aeed --- /dev/null +++ b/caffe-crfrnn/data/mnist/get_mnist.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env sh +# This script downloads the MNIST data and unzips it. + +DIR="$( cd "$(dirname "$0")" ; pwd -P )" +cd "$DIR" + +echo "Downloading..." + +wget --no-check-certificate http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz +wget --no-check-certificate http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz +wget --no-check-certificate http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz +wget --no-check-certificate http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz + +echo "Unzipping..." 
+ +gunzip train-images-idx3-ubyte.gz +gunzip train-labels-idx1-ubyte.gz +gunzip t10k-images-idx3-ubyte.gz +gunzip t10k-labels-idx1-ubyte.gz + +# Creation is split out because leveldb sometimes causes segfault +# and needs to be re-created. + +echo "Done." diff --git a/caffe-crfrnn/docs/CMakeLists.txt b/caffe-crfrnn/docs/CMakeLists.txt new file mode 100644 index 00000000..ae47e461 --- /dev/null +++ b/caffe-crfrnn/docs/CMakeLists.txt @@ -0,0 +1,106 @@ +# Building docs script +# Requirements: +# sudo apt-get install doxygen texlive ruby-dev +# sudo gem install jekyll execjs therubyracer + +if(NOT BUILD_docs OR NOT DOXYGEN_FOUND) + return() +endif() + +################################################################################################# +# Gather docs from /examples/**/readme.md +function(gather_readmes_as_prebuild_cmd target gathered_dir root) + set(full_gathered_dir ${root}/${gathered_dir}) + + file(GLOB_RECURSE readmes ${root}/examples/readme.md ${root}/examples/README.md) + foreach(file ${readmes}) + # Only use file if it is to be included in docs. + file(STRINGS ${file} file_lines REGEX "include_in_docs: true") + + if(file_lines) + # Since everything is called readme.md, rename it by its dirname. + file(RELATIVE_PATH file ${root} ${file}) + get_filename_component(folder ${file} PATH) + set(new_filename ${full_gathered_dir}/${folder}.md) + + # folder value might be like /readme.md. That's why make directory. + get_filename_component(new_folder ${new_filename} PATH) + add_custom_command(TARGET ${target} PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory ${new_folder} + COMMAND ln -sf ${root}/${file} ${new_filename} + COMMENT "Creating symlink ${new_filename} -> ${root}/${file}" + WORKING_DIRECTORY ${root} VERBATIM) + endif() + endforeach() +endfunction() + +################################################################################################ +# Gather docs from examples/*.ipynb and add YAML front-matter. 
+function(gather_notebooks_as_prebuild_cmd target gathered_dir root) + set(full_gathered_dir ${root}/${gathered_dir}) + + if(NOT PYTHON_EXECUTABLE) + message(STATUS "Python interpreter is not found. Can't include *.ipynb files in docs. Skipping...") + return() + endif() + + file(GLOB_RECURSE notebooks ${root}/examples/*.ipynb) + foreach(file ${notebooks}) + file(RELATIVE_PATH file ${root} ${file}) + set(new_filename ${full_gathered_dir}/${file}) + + get_filename_component(new_folder ${new_filename} PATH) + add_custom_command(TARGET ${target} PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory ${new_folder} + COMMAND ${PYTHON_EXECUTABLE} scripts/copy_notebook.py ${file} ${new_filename} + COMMENT "Copying notebook ${file} to ${new_filename}" + WORKING_DIRECTORY ${root} VERBATIM) + endforeach() + + set(${outputs_var} ${outputs} PARENT_SCOPE) +endfunction() + +################################################################################################ +########################## [ Non macro part ] ################################################## + +# Gathering is done at each 'make docs' +file(REMOVE_RECURSE ${PROJECT_SOURCE_DIR}/docs/gathered) + +# Doxygen config file path +set(DOXYGEN_config_file ${PROJECT_SOURCE_DIR}/.Doxyfile CACHE FILEPATH "Doxygen config file") + +# Adding docs target +add_custom_target(docs COMMAND ${DOXYGEN_EXECUTABLE} ${DOXYGEN_config_file} + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + COMMENT "Launching doxygen..." 
VERBATIM) + +# Gathering examples into docs subfolder +gather_notebooks_as_prebuild_cmd(docs docs/gathered ${PROJECT_SOURCE_DIR}) +gather_readmes_as_prebuild_cmd(docs docs/gathered ${PROJECT_SOURCE_DIR}) + +# Auto detect output directory +file(STRINGS ${DOXYGEN_config_file} config_line REGEX "OUTPUT_DIRECTORY[ \t]+=[^=].*") +if(config_line) + string(REGEX MATCH "OUTPUT_DIRECTORY[ \t]+=([^=].*)" __ver_check "${config_line}") + string(STRIP ${CMAKE_MATCH_1} output_dir) + message(STATUS "Detected Doxygen OUTPUT_DIRECTORY: ${output_dir}") +else() + set(output_dir ./doxygen/) + message(STATUS "Can't find OUTPUT_DIRECTORY in doxygen config file. Try to use default: ${output_dir}") +endif() + +if(NOT IS_ABSOLUTE ${output_dir}) + set(output_dir ${PROJECT_SOURCE_DIR}/${output_dir}) + get_filename_component(output_dir ${output_dir} ABSOLUTE) +endif() + +# creates symlink in docs subfolder to code documentation built by doxygen +add_custom_command(TARGET docs POST_BUILD VERBATIM + COMMAND ln -sfn "${output_dir}/html" doxygen + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/docs + COMMENT "Creating symlink ${PROJECT_SOURCE_DIR}/docs/doxygen -> ${output_dir}/html") + +# for quick launch of jekyll +add_custom_target(jekyll COMMAND jekyll serve -w -s . -d _site --port=4000 + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/docs + COMMENT "Launching jekyll..." VERBATIM) diff --git a/caffe-crfrnn/docs/CNAME b/caffe-crfrnn/docs/CNAME new file mode 100644 index 00000000..eee1ae26 --- /dev/null +++ b/caffe-crfrnn/docs/CNAME @@ -0,0 +1 @@ +caffe.berkeleyvision.org diff --git a/caffe-crfrnn/docs/README.md b/caffe-crfrnn/docs/README.md new file mode 100644 index 00000000..8f1781e3 --- /dev/null +++ b/caffe-crfrnn/docs/README.md @@ -0,0 +1,5 @@ +# Caffe Documentation + +To generate the documentation, run `$CAFFE_ROOT/scripts/build_docs.sh`. + +To push your changes to the documentation to the gh-pages branch of your or the BVLC repo, run `$CAFFE_ROOT/scripts/deploy_docs.sh `. 
diff --git a/caffe-crfrnn/docs/_config.yml b/caffe-crfrnn/docs/_config.yml new file mode 100644 index 00000000..95aec12b --- /dev/null +++ b/caffe-crfrnn/docs/_config.yml @@ -0,0 +1,7 @@ +defaults: + - + scope: + path: "" # an empty string here means all files in the project + values: + layout: "default" + diff --git a/caffe-crfrnn/docs/_layouts/default.html b/caffe-crfrnn/docs/_layouts/default.html new file mode 100644 index 00000000..73c6d587 --- /dev/null +++ b/caffe-crfrnn/docs/_layouts/default.html @@ -0,0 +1,52 @@ + + + + + + + + + Caffe {% if page contains 'title' %}| {{ page.title }}{% endif %} + + + + + + + + + + + +
+
+

Caffe

+

+ Deep learning framework developed by Yangqing Jia / BVLC +

+ +
+
+ + {{ content }} + +
+
+ + diff --git a/caffe-crfrnn/docs/development.md b/caffe-crfrnn/docs/development.md new file mode 100644 index 00000000..dfed3308 --- /dev/null +++ b/caffe-crfrnn/docs/development.md @@ -0,0 +1,125 @@ +--- +title: Developing and Contributing +--- +# Development + +Caffe is developed with active participation of the community.
+The [BVLC](http://bvlc.eecs.berkeley.edu/) maintainers welcome all contributions! + +The exact details of contributions are recorded by versioning and cited in our [acknowledgements](http://caffe.berkeleyvision.org/#acknowledgements). +This method is impartial and always up-to-date. + +## License + +Caffe is licensed under the terms in [LICENSE](https://github.com/BVLC/caffe/blob/master/LICENSE). By contributing to the project, you agree to the license and copyright terms therein and release your contribution under these terms. + +## Copyright + +Caffe uses a shared copyright model: each contributor holds copyright over their contributions to Caffe. The project versioning records all such contribution and copyright details. + +If a contributor wants to further mark their specific copyright on a particular contribution, they should indicate their copyright solely in the commit message of the change when it is committed. Do not include copyright notices in files for this purpose. + +### Documentation + +This website, written with [Jekyll](http://jekyllrb.com/), functions as the official Caffe documentation -- simply run `scripts/build_docs.sh` and view the website at `http://0.0.0.0:4000`. + +We prefer tutorials and examples to be documented close to where they live, in `readme.md` files. +The `build_docs.sh` script gathers all `examples/**/readme.md` and `examples/*.ipynb` files, and makes a table of contents. +To be included in the docs, the readme files must be annotated with [YAML front-matter](http://jekyllrb.com/docs/frontmatter/), including the flag `include_in_docs: true`. +Similarly for IPython notebooks: simply include `"include_in_docs": true` in the `"metadata"` JSON field. + +Other docs, such as installation guides, are written in the `docs` directory and manually linked to from the `index.md` page. + +We strive to provide lots of usage examples, and to document all code in docstrings. +We absolutely appreciate any contribution to this effort! 
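To check which example readmes would be picked up by the docs build, a small shell sketch along these lines may help. It mirrors the plain-text match on `include_in_docs: true` described above; the helper names `has_docs_flag` and `list_docs_readmes` are hypothetical and not part of Caffe's scripts.

```shell
#!/usr/bin/env sh
# Sketch only: list the example readmes that opt in to the docs build.
# has_docs_flag and list_docs_readmes are hypothetical helpers, not Caffe tools.

# Succeeds if the file contains the opt-in flag anywhere in its text.
has_docs_flag() {
  grep -q 'include_in_docs: true' "$1"
}

# Print every readme.md / README.md under the given root that carries the flag.
list_docs_readmes() {
  find "$1" -type f \( -name 'readme.md' -o -name 'README.md' \) |
    while IFS= read -r f; do
      if has_docs_flag "$f"; then
        echo "$f"
      fi
    done
}
```

From the repository root, `list_docs_readmes examples` would print the readmes that carry the flag; a readme missing from that output most likely lacks the front-matter line.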
+ +### The release cycle + +- The `dev` branch receives all new development, including community contributions. +We aim to keep it in a functional state, but large changes do occur, and things do get broken every now and then. +Use only if you want the "bleeding edge". +- BVLC maintainers will periodically update the `master` branch with changes from `dev`, giving it a release tag ([releases so far](https://github.com/BVLC/caffe/releases)). +Use this if you want more stability. + +### Issues & Pull Request Protocol + +Use Github Issues to report [bugs], propose features, and ask development [questions]. +Large-scale development work is guided by [milestones], which are sets of Issues selected for concurrent release (integration from `dev` to `master`). + +Please note that since the core developers are largely researchers, we may work on a feature in isolation for some time before releasing it to the community, so as to claim honest academic contribution. +We do release things as soon as a reasonable technical report may be written, and we still aim to inform the community of ongoing development through Github Issues. + +When you are ready to start developing your feature or fixing a bug, follow this protocol: + +- Develop in [feature branches] with descriptive names. + - For new development branch off `dev`. + - For documentation and fixes for `master` branch off `master`. +- Bring your work up-to-date by [rebasing] onto the latest `dev` / `master`. +(Polish your changes by [interactive rebase], if you'd like.) +- [Pull request] your contribution to `BVLC/caffe`'s `dev` / `master` branch for discussion and review. + - Make PRs *as soon as development begins*, to let discussion guide development. + - A PR is only ready for merge review when it is a fast-forward merge, and all code is documented, linted, and tested -- that means your PR must include tests! +- When the PR satisfies the above properties, use comments to request maintainer review. 
+ +Below is a poetic presentation of the protocol in code form. + +#### [Shelhamer's](https://github.com/shelhamer) “life of a branch in four acts” + +Make the `feature` branch off of the latest `bvlc/dev` +``` +git checkout dev +git pull upstream dev +git checkout -b feature +# do your work, make commits +``` + +Prepare to merge by rebasing your branch on the latest `bvlc/dev` +``` +# make sure dev is fresh +git checkout dev +git pull upstream dev +# rebase your branch on the tip of dev +git checkout feature +git rebase dev +``` + +Push your branch to pull request it into `dev` +``` +git push origin feature +# ...make pull request to dev... +``` + +Now make a pull request! You can do this from the command line (`git pull-request -b dev`) if you install [hub](https://github.com/github/hub). + +The pull request of `feature` into `dev` will be a clean merge. Applause. + +[bugs]: https://github.com/BVLC/caffe/issues?labels=bug&page=1&state=open +[questions]: https://github.com/BVLC/caffe/issues?labels=question&page=1&state=open +[milestones]: https://github.com/BVLC/caffe/issues?milestone=1 +[Pull request]: https://help.github.com/articles/using-pull-requests +[interactive rebase]: https://help.github.com/articles/interactive-rebase +[rebasing]: http://git-scm.com/book/en/Git-Branching-Rebasing +[feature branches]: https://www.atlassian.com/git/workflows#!workflow-feature-branch + +### Testing + +Run `make runtest` to check the project tests. New code requires new tests. Pull requests that fail tests will not be accepted. + +The `googletest` framework we use provides many additional options, which you can access by running the test binaries directly. 
One of the more useful options is `--gtest_filter`, which allows you to filter tests by name: + + # run all tests with CPU in the name + build/test/test_all.testbin --gtest_filter='*CPU*' + + # run all tests without GPU in the name (note the leading minus sign) + build/test/test_all.testbin --gtest_filter=-'*GPU*' + +To get a list of all options `googletest` provides, simply pass the `--help` flag: + + build/test/test_all.testbin --help + +### Style + +- Follow [Google C++ style](http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml) and [Google python style](http://google-styleguide.googlecode.com/svn/trunk/pyguide.html) + [PEP 8](http://legacy.python.org/dev/peps/pep-0008/). +- Wrap lines at 80 chars. +- Remember that “a foolish consistency is the hobgoblin of little minds,” so use your best judgement to write the clearest code for your particular case. +- **Run `make lint` to check C++ code.** diff --git a/caffe-crfrnn/docs/index.md b/caffe-crfrnn/docs/index.md new file mode 100644 index 00000000..e90b06b4 --- /dev/null +++ b/caffe-crfrnn/docs/index.md @@ -0,0 +1,102 @@ +--- +title: Deep Learning Framework +--- + +# Caffe + +Caffe is a deep learning framework developed with cleanliness, readability, and speed in mind. +It was created by [Yangqing Jia](http://daggerfs.com) during his PhD at UC Berkeley, and is in active development by the Berkeley Vision and Learning Center ([BVLC](http://bvlc.eecs.berkeley.edu)) and by community contributors. +Caffe is released under the [BSD 2-Clause license](https://github.com/BVLC/caffe/blob/master/LICENSE). + +Check out our web image classification [demo](http://demo.caffe.berkeleyvision.org)! + +## Why use Caffe? + +**Clean architecture** enables rapid deployment. +Networks are specified in simple config files, with no hard-coded parameters in the code. +Switching between CPU and GPU is as simple as setting a flag -- so models can be trained on a GPU machine, and then used on commodity clusters. 
+ +**Readable & modifiable implementation** fosters active development. +In Caffe's first six months, it has been forked by over 300 developers on Github, and many have pushed significant changes. + +**Speed** makes Caffe perfect for industry use. +Caffe can process over **40M images per day** with a single NVIDIA K40 or Titan GPU\*. +That's 5 ms/image in training, and 2 ms/image in test. +We believe that Caffe is the fastest CNN implementation available. + +**Community**: Caffe already powers academic research projects, startup prototypes, and even large-scale industrial applications in vision, speech, and multimedia. +There is an active discussion and support community on [Github](https://github.com/BVLC/caffe/issues). + +

+\* When files are properly cached, and using the ILSVRC2012-winning [SuperVision](http://www.image-net.org/challenges/LSVRC/2012/supervision.pdf) model. +Consult performance [details](/performance_hardware.html). +

+ +## Documentation + +- [DIY Deep Learning for Vision with Caffe](https://docs.google.com/presentation/d/1UeKXVgRvvxg9OUdh_UiC5G71UMscNPlvArsWER41PsU/edit#slide=id.p)
+Caffe tutorial slides. +- [ACM MM paper](http://ucb-icsi-vision-group.github.io/caffe-paper/caffe.pdf)
+A 4-page report for the ACM Multimedia Open Source competition. +- [Caffe Tutorial](/tutorial)
+DIY deep learning with this hands-on tutorial to Caffe. +- [Installation instructions](/installation.html)
+Tested on Ubuntu, Red Hat, OS X. +* [Model Zoo](/model_zoo.html)
+BVLC suggests a standard distribution format for Caffe models, and provides trained models. +* [Developing & Contributing](/development.html)
+Guidelines for development and contributing to Caffe. +* [API Documentation](/doxygen/)
+Developer documentation automagically generated from code comments. + +### Examples + +{% assign examples = site.pages | where:'category','example' | sort: 'priority' %} +{% for page in examples %} +-
{{page.title}}
{{page.description}}
+{% endfor %} + +### Notebook examples + +{% assign notebooks = site.pages | where:'category','notebook' | sort: 'priority' %} +{% for page in notebooks %} +-
{{page.title}}
{{page.description}}
+{% endfor %} + +## Citing Caffe + +Please cite Caffe in your publications if it helps your research: + + @misc{Jia13caffe, + Author = {Yangqing Jia}, + Title = { {Caffe}: An Open Source Convolutional Architecture for Fast Feature Embedding}, + Year = {2013}, + Howpublished = {\url{http://caffe.berkeleyvision.org/}} + } + +If you do publish a paper where Caffe helped your research, we encourage you to update the [publications wiki](https://github.com/BVLC/caffe/wiki/Publications). +Citations are also tracked automatically by [Google Scholar](http://scholar.google.com/scholar?oi=bibs&hl=en&cites=17333247995453974016). + +## Acknowledgements + +Yangqing would like to thank the NVIDIA Academic program for providing GPUs, [Oriol Vinyals](http://www1.icsi.berkeley.edu/~vinyals/) for discussions along the journey, and BVLC PI [Trevor Darrell](http://www.eecs.berkeley.edu/~trevor/) for guidance. + +A core set of BVLC members have contributed much new functionality and many fixes since the original release (alphabetical by first name): +[Eric Tzeng](https://github.com/erictzeng), [Evan Shelhamer](http://imaginarynumber.net/), [Jeff Donahue](http://jeffdonahue.com/), [Jon Long](https://github.com/longjon), [Ross Girshick](http://www.cs.berkeley.edu/~rbg/), [Sergey Karayev](http://sergeykarayev.com/), [Sergio Guadarrama](http://www.eecs.berkeley.edu/~sguada/). + +Additionally, the open-source community plays a large and growing role in Caffe's development. +Check out the Github [project pulse](https://github.com/BVLC/caffe/pulse) for recent activity, and the [contributors](https://github.com/BVLC/caffe/graphs/contributors) for a sorted list. + +We sincerely appreciate your interest and contributions! +If you'd like to contribute, please read the [developing & contributing](development.html) guide. 
+ +## Contacting us + +All questions about usage, installation, code, and applications should be searched for and asked on the [caffe-users mailing list](https://groups.google.com/forum/#!forum/caffe-users). + +All development discussion should be carried out at [GitHub Issues](https://github.com/BVLC/caffe/issues). + +If you have a proposal that may not be suited for public discussion *and an ability to act on it*, please email us [directly](mailto:caffe-dev@googlegroups.com). +Requests for features, explanations, or personal help will be ignored; post such matters publicly as issues. + +The core Caffe developers may be able to provide [consulting services](mailto:caffe-coldpress@googlegroups.com) for appropriate projects. diff --git a/caffe-crfrnn/docs/installation.md b/caffe-crfrnn/docs/installation.md new file mode 100644 index 00000000..c667cd8c --- /dev/null +++ b/caffe-crfrnn/docs/installation.md @@ -0,0 +1,302 @@ +--- +title: Installation +--- + +# Installation + +Prior to installing, it is best to read through this guide and take note of the details for your platform. +We have installed Caffe on Ubuntu 14.04, Ubuntu 12.04, OS X 10.9, and OS X 10.8. + +- [Prerequisites](#prerequisites) +- [Compilation](#compilation) +- [Hardware questions](#hardware_questions) + +## Prerequisites + +Caffe depends on several software packages. + +* [CUDA](https://developer.nvidia.com/cuda-zone) library version 6.5 (recommended), 6.0, 5.5, or 5.0 and the latest driver version for CUDA 6 or 319.* for CUDA 5 (and NOT 331.*) +* [BLAS](http://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms) (provided via ATLAS, MKL, or OpenBLAS). +* [OpenCV](http://opencv.org/). +* [Boost](http://www.boost.org/) (>= 1.55, although only 1.55 and 1.56 are tested) +* `glog`, `gflags`, `protobuf`, `leveldb`, `snappy`, `hdf5`, `lmdb` +* For the Python wrapper + * `Python 2.7`, `numpy (>= 1.7)`, boost-provided `boost.python` +* For the MATLAB wrapper + * MATLAB with the `mex` compiler. 
+ +**cuDNN Caffe**: for fastest operation Caffe is accelerated by drop-in integration of [NVIDIA cuDNN](https://developer.nvidia.com/cudnn). To speed up your Caffe models, install cuDNN then uncomment the `USE_CUDNN := 1` flag in `Makefile.config` when installing Caffe. Acceleration is automatic. + +**CPU-only Caffe**: for cold-brewed CPU-only Caffe uncomment the `CPU_ONLY := 1` flag in `Makefile.config` to configure and build Caffe without CUDA. This is helpful for cloud or cluster deployment. + +### CUDA and BLAS + +Caffe requires the CUDA `nvcc` compiler to compile its GPU code and CUDA driver for GPU operation. +To install CUDA, go to the [NVIDIA CUDA website](https://developer.nvidia.com/cuda-downloads) and follow installation instructions there. Install the library and the latest standalone driver separately; the driver bundled with the library is usually out-of-date. **Warning!** The 331.* CUDA driver series has a critical performance issue: do not use it. + +For best performance, Caffe can be accelerated by [NVIDIA cuDNN](https://developer.nvidia.com/cudnn). Register for free at the cuDNN site, install it, then continue with these installation instructions. To compile with cuDNN, set the `USE_CUDNN := 1` flag in your `Makefile.config`. + +Caffe requires BLAS as the backend of its matrix and vector computations. +There are several implementations of this library. +The choice is yours: + +* [ATLAS](http://math-atlas.sourceforge.net/): free, open source, and so the default for Caffe. + + Ubuntu: `sudo apt-get install libatlas-base-dev` + + CentOS/RHEL/Fedora: `sudo yum install atlas-devel` + + OS X: already installed as the [Accelerate / vecLib Framework](https://developer.apple.com/library/mac/documentation/Darwin/Reference/ManPages/man7/Accelerate.7.html). +* [Intel MKL](http://software.intel.com/en-us/intel-mkl): commercial and optimized for Intel CPUs, with a free trial and [student](http://software.intel.com/en-us/intel-education-offerings) licenses. 
+ 1. Install MKL. + 2. Set `BLAS := mkl` in `Makefile.config` +* [OpenBLAS](http://www.openblas.net/): free and open source; this optimized and parallel BLAS could require more effort to install, although it might offer a speedup. + 1. Install OpenBLAS + 2. Set `BLAS := open` in `Makefile.config` + +### Python and/or MATLAB wrappers (optional) + +#### Python + +The main requirements are `numpy` and `boost.python` (provided by boost). `pandas` is useful too and needed for some examples. + +You can install the dependencies with + + for req in $(cat requirements.txt); do sudo pip install $req; done + +but we highly recommend first installing the [Anaconda](https://store.continuum.io/cshop/anaconda/) Python distribution, which provides most of the necessary packages, as well as the `hdf5` library dependency. + +For **Ubuntu**, if you use the default Python you will need to `sudo apt-get install` the `python-dev` package to have the Python headers for building the wrapper. + +For **Fedora**, if you use the default Python you will need to `sudo yum install` the `python-devel` package to have the Python headers for building the wrapper. + +For **OS X**, Anaconda is the preferred Python. If you decide against it, please use Homebrew -- but beware of potential linking errors! + +To import the `caffe` Python module after completing the installation, add the module directory to your `$PYTHONPATH` by `export PYTHONPATH=/path/to/caffe/python:$PYTHONPATH` or the like. You should not import the module in the `caffe/python/caffe` directory! + +*Caffe's Python interface works with Python 2.7. Python 3 or earlier Pythons are your own adventure.* + +#### MATLAB + +Install MATLAB, and make sure that its `mex` is in your `$PATH`. 
+ +*Caffe's MATLAB interface works with versions 2012b, 2013a/b, and 2014a.* + +### The rest of the dependencies + +#### Linux + +On **Ubuntu**, most of the dependencies can be installed with + + sudo apt-get install libprotobuf-dev libleveldb-dev libsnappy-dev libopencv-dev libboost-all-dev libhdf5-serial-dev + +and for **Ubuntu 14.04** the rest of the dependencies can be installed with + + sudo apt-get install libgflags-dev libgoogle-glog-dev liblmdb-dev protobuf-compiler + +Keep reading to find out how to manually build and install the Google flags library, Google logging library and LMDB on **Ubuntu 12.04**. + +On **CentOS / RHEL / Fedora**, most of the dependencies can be installed with + + sudo yum install protobuf-devel leveldb-devel snappy-devel opencv-devel boost-devel hdf5-devel + +The Google flags library, Google logging library and LMDB already made their ways into newer versions of **CentOS / RHEL / Fedora** so it is better to first attempt to install them using `yum` + + sudo yum install gflags-devel glog-devel lmdb-devel + +**Finally** in case you couldn't find those extra libraries mentioned above in your distribution's repositories, here are the instructions to follow for manually building and installing them on **Ubuntu 12.04 / CentOS / RHEL / Fedora** (or practically on any Linux distribution) + + # glog + wget https://google-glog.googlecode.com/files/glog-0.3.3.tar.gz + tar zxvf glog-0.3.3.tar.gz + cd glog-0.3.3 + ./configure + make && make install + # gflags + wget https://github.com/schuhschuh/gflags/archive/master.zip + unzip master.zip + cd gflags-master + mkdir build && cd build + export CXXFLAGS="-fPIC" && cmake .. && make VERBOSE=1 + make && make install + # lmdb + git clone git://gitorious.org/mdb/mdb.git + cd mdb/libraries/liblmdb + make && make install + +Note that glog does not compile with the most recent gflags version (2.1), so before that is resolved you will need to build with glog first. 
+ +#### OS X + +On **OS X**, we highly recommend using the [Homebrew](http://brew.sh/) package manager, and ideally starting from a clean install of the OS (or from a wiped `/usr/local`) to avoid conflicts. +In the following, we assume that you're using Anaconda Python and Homebrew. + +To install the OpenCV dependency, we'll need to provide an additional source for Homebrew: + + brew tap homebrew/science + +If using Anaconda Python, a modification is required to the OpenCV formula. +Do `brew edit opencv` and change the lines that look like the two lines below to exactly the two lines below. + + -DPYTHON_LIBRARY=#{py_prefix}/lib/libpython2.7.dylib + -DPYTHON_INCLUDE_DIR=#{py_prefix}/include/python2.7 + +**NOTE**: We find that everything compiles successfully if `$LD_LIBRARY_PATH` is not set at all, and `$DYLD_FALLBACK_LIBRARY_PATH` is set to provide CUDA, Python, and other relevant libraries (e.g. `/usr/local/cuda/lib:$HOME/anaconda/lib:/usr/local/lib:/usr/lib`). +In other `ENV` settings, things may not work as expected. + +**NOTE**: There is currently a conflict between boost 1.56 and CUDA in some configurations. Check the [conflict description](https://github.com/BVLC/caffe/issues/1193#issuecomment-57491906) and try downgrading to 1.55. + +#### 10.8-specific Instructions + +Simply run the following: + + brew install --build-from-source boost boost-python + brew install --with-python protobuf + for x in snappy leveldb gflags glog szip lmdb homebrew/science/opencv; do brew install $x; done + +Building boost from source is needed to link against your local Python (exceptions might be raised during some OS X installs, but **ignore** these and continue). If you do not need the Python wrapper, simply doing `brew install boost` is fine. + +**Note** that the HDF5 dependency is provided by Anaconda Python in this case. +If you're not using Anaconda, include `hdf5` in the list above. 
+ +#### 10.9-specific Instructions + +In OS X 10.9, clang++ is the default C++ compiler and uses `libc++` as the standard library. +However, NVIDIA CUDA (even version 6.0) currently links only with `libstdc++`. +This makes it necessary to change the compilation settings for each of the dependencies. + +We do this by modifying the Homebrew formulae before installing any packages. +Make sure that Homebrew doesn't install any software dependencies in the background; all packages must be linked to `libstdc++`. + +The prerequisite Homebrew formulae are + + boost snappy leveldb protobuf gflags glog szip lmdb homebrew/science/opencv + +For each of these formulas, `brew edit FORMULA`, and add the ENV definitions as shown: + + def install + # ADD THE FOLLOWING: + ENV.append "CXXFLAGS", "-stdlib=libstdc++" + ENV.append "CFLAGS", "-stdlib=libstdc++" + ENV.append "LDFLAGS", "-stdlib=libstdc++ -lstdc++" + # The following is necessary because libtool likes to strip LDFLAGS: + ENV["CXX"] = "/usr/bin/clang++ -stdlib=libstdc++" + ... + +To edit the formulae in turn, run + + for x in snappy leveldb protobuf gflags glog szip boost boost-python lmdb homebrew/science/opencv; do brew edit $x; done + +After this, run + + for x in snappy leveldb gflags glog szip lmdb homebrew/science/opencv; do brew uninstall $x; brew install --build-from-source --fresh -vd $x; done + brew uninstall protobuf; brew install --build-from-source --with-python --fresh -vd protobuf + brew install --build-from-source --fresh -vd boost boost-python + +**Note** that `brew install --build-from-source --fresh -vd boost` is fine if you do not need the Caffe Python wrapper. + +**Note** that the HDF5 dependency is provided by Anaconda Python in this case. +If you're not using Anaconda, include `hdf5` in the list above. 
+ +**Note** that in order to build the Caffe Python wrappers you must install `boost` and `boost-python`: + + brew install --build-from-source --fresh -vd boost boost-python + +**Note** that Homebrew maintains itself as a separate git repository and making the above `brew edit FORMULA` changes will change files in your local copy of homebrew's master branch. By default, this will prevent you from updating Homebrew using `brew update`, as you will get an error message like the following: + + $ brew update + error: Your local changes to the following files would be overwritten by merge: + Library/Formula/lmdb.rb + Please, commit your changes or stash them before you can merge. + Aborting + Error: Failure while executing: git pull -q origin refs/heads/master:refs/remotes/origin/master + +One solution is to commit your changes to a separate Homebrew branch, run `brew update`, and rebase your changes onto the updated master. You'll have to do this both for the main Homebrew repository in `/usr/local/` and the Homebrew science repository that contains OpenCV in `/usr/local/Library/Taps/homebrew/homebrew-science`, as follows: + + cd /usr/local + git checkout -b caffe + git add . + git commit -m "Update Caffe dependencies to use libstdc++" + cd /usr/local/Library/Taps/homebrew/homebrew-science + git checkout -b caffe + git add . + git commit -m "Update Caffe dependencies" + +Then, whenever you want to update homebrew, switch back to the master branches, do the update, rebase the caffe branches onto master and fix any conflicts: + + # Switch back to homebrew master branches + cd /usr/local + git checkout master + cd /usr/local/Library/Taps/homebrew/homebrew-science + git checkout master + + # Update homebrew; hopefully this works without errors! 
+ brew update + + # Switch back to the caffe branches with the formulae that you modified earlier + cd /usr/local + git rebase master caffe + # Fix any merge conflicts and commit to caffe branch + cd /usr/local/Library/Taps/homebrew/homebrew-science + git rebase master caffe + # Fix any merge conflicts and commit to caffe branch + + # Done! + +At this point, you should be running the latest Homebrew packages and your Caffe-related modifications will remain in place. + +#### Windows + +There is an unofficial Windows port of Caffe at [niuzhiheng/caffe:windows](https://github.com/niuzhiheng/caffe). Thanks [@niuzhiheng](https://github.com/niuzhiheng)! + +## Compilation + +Now that you have the prerequisites, edit your `Makefile.config` to change the paths for your setup (you should especially uncomment and set `BLAS_LIB` accordingly on distributions like **CentOS / RHEL / Fedora** where ATLAS is installed under `/usr/lib[64]/atlas`). +The defaults should work, but uncomment the relevant lines if using Anaconda Python. + + cp Makefile.config.example Makefile.config + # Adjust Makefile.config (for example, if using Anaconda Python) + make all + make test + make runtest + +To compile with cuDNN acceleration, you should uncomment the `USE_CUDNN := 1` switch in `Makefile.config`. + +If there is no GPU in your machine, you should switch to CPU-only Caffe by uncommenting `CPU_ONLY := 1` in `Makefile.config`. + +To compile the Python and MATLAB wrappers do `make pycaffe` and `make matcaffe` respectively. +Be sure to set your MATLAB and Python paths in `Makefile.config` first! + +*Distribution*: run `make distribute` to create a `distribute` directory with all the Caffe headers, compiled libraries, binaries, etc. needed for distribution to other machines. + +*Speed*: for a faster build, compile in parallel by doing `make all -j8` where 8 is the number of parallel threads for compilation (a good choice for the number of threads is the number of cores in your machine). 
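The parallel-build tip above can be sketched as a small shell snippet that picks the job count from the number of cores rather than hard-coding 8. This is only an illustration: `detect_jobs` is a hypothetical helper, and the snippet prints the `make` command instead of running it.

```shell
#!/usr/bin/env sh
# Sketch only: derive the -j value from the number of CPU cores.
# detect_jobs is a hypothetical helper, not part of the Caffe Makefile.
detect_jobs() {
  if command -v nproc >/dev/null 2>&1; then
    nproc                          # GNU coreutils, typical on Linux
  elif getconf _NPROCESSORS_ONLN >/dev/null 2>&1; then
    getconf _NPROCESSORS_ONLN      # available on Linux and OS X
  elif sysctl -n hw.ncpu >/dev/null 2>&1; then
    sysctl -n hw.ncpu              # OS X / BSD fallback
  else
    echo 1                         # unknown platform: build serially
  fi
}

JOBS="$(detect_jobs)"
# Printed rather than executed, so the snippet is safe to paste anywhere.
echo "make all -j${JOBS}"
```

Run the printed command from the Caffe root once `Makefile.config` is in place.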
+ +Now that you have installed Caffe, check out the [MNIST tutorial](gathered/examples/mnist.html) and the [reference ImageNet model tutorial](gathered/examples/imagenet.html). + +### Compilation using CMake (beta) + +In lieu of manually editing `Makefile.config` to tell Caffe where dependencies are located, Caffe also provides a CMake-based build system (currently in "beta"). +It requires CMake version >= 2.8.8. +The basic installation steps are as follows: + + mkdir build + cd build + cmake .. + make all + make runtest + +#### Ubuntu 12.04 + +Note that in Ubuntu 12.04, Aptitude will install CMake version 2.8.7 by default, which is not supported by Caffe's CMake build (requires at least 2.8.8). +As a workaround, if you are using Ubuntu 12.04 you can try the following steps to install (or upgrade to) CMake 2.8.9: + + sudo add-apt-repository ppa:ubuntu-sdk-team/ppa -y + sudo apt-get -y update + sudo apt-get install cmake + +## Hardware Questions + +**Laboratory Tested Hardware**: Berkeley Vision runs Caffe with K40s, K20s, and Titans including models at ImageNet/ILSVRC scale. We also run on GTX series cards and GPU-equipped MacBook Pros. We have not encountered any trouble in-house with devices with CUDA capability >= 3.0. All reported hardware issues thus far have been due to GPU configuration, overheating, and the like. + +**CUDA compute capability**: devices with compute capability <= 2.0 may have to reduce CUDA thread numbers and batch sizes due to hardware constraints. Your mileage may vary. + +Once installed, check your times against our [reference performance numbers](performance_hardware.html) to make sure everything is configured properly. + +Refer to the project's issue tracker for [hardware/compatibility](https://github.com/BVLC/caffe/issues?labels=hardware%2Fcompatibility&page=1&state=open). 
diff --git a/caffe-crfrnn/docs/model_zoo.md b/caffe-crfrnn/docs/model_zoo.md new file mode 100644 index 00000000..358bbb7f --- /dev/null +++ b/caffe-crfrnn/docs/model_zoo.md @@ -0,0 +1,55 @@ +--- +title: Model Zoo +--- +# Caffe Model Zoo + +Lots of people have used Caffe to train models of different architectures and applied to different problems, ranging from simple regression to AlexNet-alikes to Siamese networks for image similarity to speech applications. +To lower the friction of sharing these models, we introduce the model zoo framework: + +- A standard format for packaging Caffe model info. +- Tools to upload/download model info to/from Github Gists, and to download trained `.caffemodel` binaries. +- A central wiki page for sharing model info Gists. + +## Where to get trained models + +First of all, we provide some trained models out of the box. +Each one of these can be downloaded by running `scripts/download_model_binary.py ` where `` is specified below: + +- **BVLC Reference CaffeNet** in `models/bvlc_reference_caffenet`: AlexNet trained on ILSVRC 2012, with a minor variation from the version as described in the NIPS 2012 paper. (Trained by Jeff Donahue @jeffdonahue) +- **BVLC AlexNet** in `models/bvlc_alexnet`: AlexNet trained on ILSVRC 2012, almost exactly as described in NIPS 2012. (Trained by Evan Shelhamer @shelhamer) +- **BVLC Reference R-CNN ILSVRC-2013** in `models/bvlc_reference_rcnn_ilsvrc13`: pure Caffe implementation of [R-CNN](https://github.com/rbgirshick/rcnn). (Trained by Ross Girshick @rbgirshick) +- **BVLC GoogleNet** in `models/bvlc_googlenet`: GoogleNet trained on ILSVRC 2012, almost exactly as described in [GoogleNet](http://arxiv.org/abs/1409.4842). (Trained by Sergio Guadarrama @sguada) + +User-provided models are posted to a public-editable [wiki page](https://github.com/BVLC/caffe/wiki/Model-Zoo). 
+ +## Model info format + +A caffe model is distributed as a directory containing: + +- Solver/model prototxt(s) +- `readme.md` containing + - YAML frontmatter + - Caffe version used to train this model (tagged release or commit hash). + - [optional] file URL and SHA1 of the trained `.caffemodel`. + - [optional] github gist id. + - Information about what data the model was trained on, modeling choices, etc. + - License information. +- [optional] Other helpful scripts. + +## Hosting model info + +Github Gist is a good format for model info distribution because it can contain multiple files, is versionable, and has in-browser syntax highlighting and markdown rendering. + +- `scripts/upload_model_to_gist.sh `: uploads non-binary files in the model directory as a Github Gist and prints the Gist ID. If `gist_id` is already part of the `/readme.md` frontmatter, then updates existing Gist. + +Try doing `scripts/upload_model_to_gist.sh models/bvlc_alexnet` to test the uploading (don't forget to delete the uploaded gist afterward). + +Downloading model info is done just as easily with `scripts/download_model_from_gist.sh `. + +### Hosting trained models + +It is up to the user where to host the `.caffemodel` file. +We host our BVLC-provided models on our own server. +Dropbox also works fine (tip: make sure that `?dl=1` is appended to the end of the URL). + +- `scripts/download_model_binary.py `: downloads the `.caffemodel` from the URL specified in the `/readme.md` frontmatter and confirms SHA1. diff --git a/caffe-crfrnn/docs/performance_hardware.md b/caffe-crfrnn/docs/performance_hardware.md new file mode 100644 index 00000000..b35246fe --- /dev/null +++ b/caffe-crfrnn/docs/performance_hardware.md @@ -0,0 +1,73 @@ +--- +title: Performance and Hardware Configuration +--- + +# Performance and Hardware Configuration + +To measure performance on different NVIDIA GPUs we use CaffeNet, the Caffe reference ImageNet model. 
+ +For training, each time point is 20 iterations/minibatches of 256 images for 5,120 images total. For testing, a 50,000 image validation set is classified. + +**Acknowledgements**: BVLC members are very grateful to NVIDIA for providing several GPUs to conduct this research. + +## NVIDIA K40 + +Performance is best with ECC off and boost clock enabled. While ECC makes a negligible difference in speed, disabling it frees ~1 GB of GPU memory. + +Best settings with ECC off and maximum clock speed in standard Caffe: + +* Training is 26.5 secs / 20 iterations (5,120 images) +* Testing is 100 secs / validation set (50,000 images) + +Best settings with Caffe + [cuDNN acceleration](http://nvidia.com/cudnn): + +* Training is 19.2 secs / 20 iterations (5,120 images) +* Testing is 60.7 secs / validation set (50,000 images) + +Other settings: + +* ECC on, max speed: training 26.7 secs / 20 iterations, test 101 secs / validation set +* ECC on, default speed: training 31 secs / 20 iterations, test 117 secs / validation set +* ECC off, default speed: training 31 secs / 20 iterations, test 118 secs / validation set + +### K40 configuration tips + +For maximum K40 performance, turn off ECC and boost the clock speed (at your own risk). + +To turn off ECC, do + + sudo nvidia-smi -i 0 --ecc-config=0 # repeat with -i x for each GPU ID + +then reboot. + +Set the "persistence" mode of the GPU settings by + + sudo nvidia-smi -pm 1 + +and then set the clock speed with + + sudo nvidia-smi -i 0 -ac 3004,875 # repeat with -i x for each GPU ID + +but note that this configuration resets across driver reloading / rebooting. Include these commands in a boot script to initialize these settings. For a simple fix, add these commands to `/etc/rc.local` (on Ubuntu). + +## NVIDIA Titan + +Training: 26.26 secs / 20 iterations (5,120 images). +Testing: 100 secs / validation set (50,000 images). + +cuDNN Training: 20.25 secs / 20 iterations (5,120 images). 
+cuDNN Testing: 66.3 secs / validation set (50,000 images). + + +## NVIDIA K20 + +Training: 36.0 secs / 20 iterations (5,120 images). +Testing: 133 secs / validation set (50,000 images). + +## NVIDIA GTX 770 + +Training: 33.0 secs / 20 iterations (5,120 images). +Testing: 129 secs / validation set (50,000 images). + +cuDNN Training: 24.3 secs / 20 iterations (5,120 images). +cuDNN Testing: 104 secs / validation set (50,000 images). diff --git a/caffe-crfrnn/docs/stylesheets/pygment_trac.css b/caffe-crfrnn/docs/stylesheets/pygment_trac.css new file mode 100644 index 00000000..c6a6452d --- /dev/null +++ b/caffe-crfrnn/docs/stylesheets/pygment_trac.css @@ -0,0 +1,69 @@ +.highlight { background: #ffffff; } +.highlight .c { color: #999988; font-style: italic } /* Comment */ +.highlight .err { color: #a61717; background-color: #e3d2d2 } /* Error */ +.highlight .k { font-weight: bold } /* Keyword */ +.highlight .o { font-weight: bold } /* Operator */ +.highlight .cm { color: #999988; font-style: italic } /* Comment.Multiline */ +.highlight .cp { color: #999999; font-weight: bold } /* Comment.Preproc */ +.highlight .c1 { color: #999988; font-style: italic } /* Comment.Single */ +.highlight .cs { color: #999999; font-weight: bold; font-style: italic } /* Comment.Special */ +.highlight .gd { color: #000000; background-color: #ffdddd } /* Generic.Deleted */ +.highlight .gd .x { color: #000000; background-color: #ffaaaa } /* Generic.Deleted.Specific */ +.highlight .ge { font-style: italic } /* Generic.Emph */ +.highlight .gr { color: #aa0000 } /* Generic.Error */ +.highlight .gh { color: #999999 } /* Generic.Heading */ +.highlight .gi { color: #000000; background-color: #ddffdd } /* Generic.Inserted */ +.highlight .gi .x { color: #000000; background-color: #aaffaa } /* Generic.Inserted.Specific */ +.highlight .go { color: #888888 } /* Generic.Output */ +.highlight .gp { color: #555555 } /* Generic.Prompt */ +.highlight .gs { font-weight: bold } /* Generic.Strong */ +.highlight 
.gu { color: #800080; font-weight: bold; } /* Generic.Subheading */ +.highlight .gt { color: #aa0000 } /* Generic.Traceback */ +.highlight .kc { font-weight: bold } /* Keyword.Constant */ +.highlight .kd { font-weight: bold } /* Keyword.Declaration */ +.highlight .kn { font-weight: bold } /* Keyword.Namespace */ +.highlight .kp { font-weight: bold } /* Keyword.Pseudo */ +.highlight .kr { font-weight: bold } /* Keyword.Reserved */ +.highlight .kt { color: #445588; font-weight: bold } /* Keyword.Type */ +.highlight .m { color: #009999 } /* Literal.Number */ +.highlight .s { color: #d14 } /* Literal.String */ +.highlight .na { color: #008080 } /* Name.Attribute */ +.highlight .nb { color: #0086B3 } /* Name.Builtin */ +.highlight .nc { color: #445588; font-weight: bold } /* Name.Class */ +.highlight .no { color: #008080 } /* Name.Constant */ +.highlight .ni { color: #800080 } /* Name.Entity */ +.highlight .ne { color: #990000; font-weight: bold } /* Name.Exception */ +.highlight .nf { color: #990000; font-weight: bold } /* Name.Function */ +.highlight .nn { color: #555555 } /* Name.Namespace */ +.highlight .nt { color: #000080 } /* Name.Tag */ +.highlight .nv { color: #008080 } /* Name.Variable */ +.highlight .ow { font-weight: bold } /* Operator.Word */ +.highlight .w { color: #bbbbbb } /* Text.Whitespace */ +.highlight .mf { color: #009999 } /* Literal.Number.Float */ +.highlight .mh { color: #009999 } /* Literal.Number.Hex */ +.highlight .mi { color: #009999 } /* Literal.Number.Integer */ +.highlight .mo { color: #009999 } /* Literal.Number.Oct */ +.highlight .sb { color: #d14 } /* Literal.String.Backtick */ +.highlight .sc { color: #d14 } /* Literal.String.Char */ +.highlight .sd { color: #d14 } /* Literal.String.Doc */ +.highlight .s2 { color: #d14 } /* Literal.String.Double */ +.highlight .se { color: #d14 } /* Literal.String.Escape */ +.highlight .sh { color: #d14 } /* Literal.String.Heredoc */ +.highlight .si { color: #d14 } /* Literal.String.Interpol */ 
+.highlight .sx { color: #d14 } /* Literal.String.Other */ +.highlight .sr { color: #009926 } /* Literal.String.Regex */ +.highlight .s1 { color: #d14 } /* Literal.String.Single */ +.highlight .ss { color: #990073 } /* Literal.String.Symbol */ +.highlight .bp { color: #999999 } /* Name.Builtin.Pseudo */ +.highlight .vc { color: #008080 } /* Name.Variable.Class */ +.highlight .vg { color: #008080 } /* Name.Variable.Global */ +.highlight .vi { color: #008080 } /* Name.Variable.Instance */ +.highlight .il { color: #009999 } /* Literal.Number.Integer.Long */ + +.type-csharp .highlight .k { color: #0000FF } +.type-csharp .highlight .kt { color: #0000FF } +.type-csharp .highlight .nf { color: #000000; font-weight: normal } +.type-csharp .highlight .nc { color: #2B91AF } +.type-csharp .highlight .nn { color: #000000 } +.type-csharp .highlight .s { color: #A31515 } +.type-csharp .highlight .sc { color: #A31515 } diff --git a/caffe-crfrnn/docs/stylesheets/reset.css b/caffe-crfrnn/docs/stylesheets/reset.css new file mode 100644 index 00000000..6020b26f --- /dev/null +++ b/caffe-crfrnn/docs/stylesheets/reset.css @@ -0,0 +1,21 @@ +/* MeyerWeb Reset */ + +html, body, div, span, applet, object, iframe, +h1, h2, h3, h4, h5, h6, p, blockquote, pre, +a, abbr, acronym, address, big, cite, code, +del, dfn, em, img, ins, kbd, q, s, samp, +small, strike, strong, sub, sup, tt, var, +b, u, i, center, +dl, dt, dd, ol, ul, li, +fieldset, form, label, legend, +table, caption, tbody, tfoot, thead, tr, th, td, +article, aside, canvas, details, embed, +figure, figcaption, footer, header, hgroup, +menu, nav, output, ruby, section, summary, +time, mark, audio, video { + margin: 0; + padding: 0; + border: 0; + font: inherit; + vertical-align: baseline; +} diff --git a/caffe-crfrnn/docs/stylesheets/styles.css b/caffe-crfrnn/docs/stylesheets/styles.css new file mode 100644 index 00000000..2dbedb8a --- /dev/null +++ b/caffe-crfrnn/docs/stylesheets/styles.css @@ -0,0 +1,348 @@ +@import 
url(http://fonts.googleapis.com/css?family=PT+Serif|Open+Sans:600,400); + +body { + padding:10px 50px 0 0; + font-family: 'Open Sans', sans-serif; + font-size: 14px; + color: #232323; + background-color: #FBFAF7; + margin: 0; + line-height: 1.5rem; + -webkit-font-smoothing: antialiased; +} + +h1, h2, h3, h4, h5, h6 { + color:#232323; + margin:36px 0 10px; +} + +p, ul, ol, table, dl { + margin:0 0 22px; +} + +h1, h2, h3 { + font-family: 'PT Serif', serif; + line-height:1.3; + font-weight: normal; + display: block; + border-bottom: 1px solid #ccc; + padding-bottom: 5px; +} + +h1 { + font-size: 30px; +} + +h2 { + font-size: 24px; +} + +h3 { + font-size: 18px; +} + +h4, h5, h6 { + font-family: 'PT Serif', serif; + font-weight: 700; +} + +a { + color:#C30000; + text-decoration:none; +} + +a:hover { + text-decoration: underline; +} + +a small { + font-size: 12px; +} + +em { + font-style: italic; +} + +strong { + font-weight:700; +} + +ul { + padding-left: 25px; +} + +ol { + list-style: decimal; + padding-left: 20px; +} + +blockquote { + margin: 0; + padding: 0 0 0 20px; + font-style: italic; +} + +dl, dt, dd, dl p { + font-color: #444; +} + +dl dt { + font-weight: bold; +} + +dl dd { + padding-left: 20px; + font-style: italic; +} + +dl p { + padding-left: 20px; + font-style: italic; +} + +hr { + border:0; + background:#ccc; + height:1px; + margin:0 0 24px; +} + +/* Images */ + +img { + position: relative; + margin: 0 auto; + max-width: 650px; + padding: 5px; + margin: 10px 0 32px 0; + border: 1px solid #ccc; +} + +p img { + display: inline; + margin: 0; + padding: 0; + vertical-align: middle; + text-align: center; + border: none; +} + +/* Code blocks */ +code, pre { + font-family: monospace; + color:#000; + font-size:12px; + line-height: 14px; +} + +pre { + padding: 6px 12px; + background: #FDFEFB; + border-radius:4px; + border:1px solid #D7D8C8; + overflow: auto; + white-space: pre-wrap; + margin-bottom: 16px; +} + + +/* Tables */ +table { + width:100%; +} + +table { + 
border: 1px solid #ccc; + margin-bottom: 32px; + text-align: left; + } + +th { + font-family: 'Open Sans', sans-serif; + font-size: 18px; + font-weight: normal; + padding: 10px; + background: #232323; + color: #FDFEFB; + } + +td { + padding: 10px; + background: #ccc; + } + + +/* Wrapper */ +.wrapper { + width:960px; +} + + +/* Header */ + +header { + width:170px; + float:left; + position:fixed; + padding: 12px 25px 22px 50px; + margin: 24px 25px 0 0; +} + +p.header { + font-size: 14px; +} + +h1.header { + font-size: 30px; + font-weight: 300; + line-height: 1.3em; + margin-top: 0; +} + +a.name { + white-space: nowrap; +} + +header ul { + list-style:none; + padding:0; +} + +header li { + list-style-type: none; + width:132px; + height:15px; + margin-bottom: 12px; + line-height: 1em; + padding: 6px 6px 6px 7px; + background: #c30000; + border-radius:4px; + border:1px solid #555; +} + +header li:hover { + background: #dd0000; +} + +a.buttons { + color: #fff; + text-decoration: none; + font-weight: normal; + padding: 2px 2px 2px 22px; + height: 30px; +} + +a.github { + background: url(/images/GitHub-Mark-64px.png) no-repeat center left; + background-size: 15%; +} + +/* Section - for main page content */ + +section { + width:650px; + float:right; + padding-bottom:50px; +} + +p.footnote { + font-size: 12px; +} + + +/* Footer */ + +footer { + width:170px; + float:left; + position:fixed; + bottom:10px; + padding-left: 50px; +} + +@media print, screen and (max-width: 960px) { + + div.wrapper { + width:auto; + margin:0; + } + + header, section, footer { + float:none; + position:static; + width:auto; + } + + footer { + border-top: 1px solid #ccc; + margin:0 84px 0 50px; + padding:0; + } + + header { + padding-right:320px; + } + + section { + padding:20px 84px 20px 50px; + margin:0 0 20px; + } + + header a small { + display:inline; + } + + header ul { + position:absolute; + right:130px; + top:84px; + } +} + +@media print, screen and (max-width: 720px) { + body { + 
word-wrap:break-word; + } + + header { + padding:10px 20px 0; + margin-right: 0; + } + + section { + padding:10px 0 10px 20px; + margin:0 0 30px; + } + + footer { + margin: 0 0 0 30px; + } + + header ul, header p.view { + position:static; + } +} + +@media print, screen and (max-width: 480px) { + + header ul li.download { + display:none; + } + + footer { + margin: 0 0 0 20px; + } + + footer a{ + display:block; + } + +} + +@media print { + body { + padding:0.4in; + font-size:12pt; + color:#444; + } +} diff --git a/caffe-crfrnn/docs/tutorial/convolution.md b/caffe-crfrnn/docs/tutorial/convolution.md new file mode 100644 index 00000000..a02fe4ef --- /dev/null +++ b/caffe-crfrnn/docs/tutorial/convolution.md @@ -0,0 +1,13 @@ +--- +title: Convolution +--- +# Caffeinated Convolution + +The Caffe strategy for convolution is to reduce the problem to matrix-matrix multiplication. +This linear algebra computation is highly-tuned in BLAS libraries and efficiently computed on GPU devices. + +For more details read Yangqing's [Convolution in Caffe: a memo](https://github.com/Yangqing/caffe/wiki/Convolution-in-Caffe:-a-memo). + +As it turns out, this same reduction was independently explored in the context of conv. nets by + +> K. Chellapilla, S. Puri, P. Simard, et al. High performance convolutional neural networks for document processing. In Tenth International Workshop on Frontiers in Handwriting Recognition, 2006. diff --git a/caffe-crfrnn/docs/tutorial/data.md b/caffe-crfrnn/docs/tutorial/data.md new file mode 100644 index 00000000..40605f7c --- /dev/null +++ b/caffe-crfrnn/docs/tutorial/data.md @@ -0,0 +1,78 @@ +--- +title: Data +--- +# Data: Ins and Outs + +Data flows through Caffe as [Blobs](net_layer_blob.html#blob-storage-and-communication). +Data layers load input and save output by converting to and from Blob to other formats. +Common transformations like mean-subtraction and feature-scaling are done by data layer configuration. 
+New input types are supported by developing a new data layer -- the rest of the Net follows by the modularity of the Caffe layer catalogue. + +This data layer definition + + layers { + name: "mnist" + # DATA layer loads leveldb or lmdb storage DBs for high-throughput. + type: DATA + # the 1st top is the data itself: the name is only convention + top: "data" + # the 2nd top is the ground truth: the name is only convention + top: "label" + # the DATA layer configuration + data_param { + # path to the DB + source: "examples/mnist/mnist_train_lmdb" + # type of DB: LEVELDB or LMDB (LMDB supports concurrent reads) + backend: LMDB + # batch processing improves efficiency. + batch_size: 64 + } + # common data transformations + transform_param { + # feature scaling coefficient: this maps the [0, 255] MNIST data to [0, 1] + scale: 0.00390625 + } + } + +loads the MNIST digits. + +**Tops and Bottoms**: A data layer makes **top** blobs to output data to the model. +It does not have **bottom** blobs since it takes no input. + +**Data and Label**: a data layer has at least one top canonically named **data**. +For ground truth a second top can be defined that is canonically named **label**. +Both tops simply produce blobs and there is nothing inherently special about these names. +The (data, label) pairing is a convenience for classification models. + +**Transformations**: data preprocessing is parametrized by transformation messages within the data layer definition. + + layers { + name: "data" + type: DATA + [...] + transform_param { + scale: 0.1 + mean_file: "mean.binaryproto" + # for images in particular horizontal mirroring and random cropping + # can be done as simple data augmentations. 
+ mirror: 1 # 1 = on, 0 = off + # crop a `crop_size` x `crop_size` patch: + # - at random during training + # - from the center during testing + crop_size: 227 + } + } + +**Prefetching**: for throughput data layers fetch the next batch of data and prepare it in the background while the Net computes the current batch. + +**Multiple Inputs**: a Net can have multiple inputs of any number and type. Define as many data layers as needed giving each a unique name and top. Multiple inputs are useful for non-trivial ground truth: one data layer loads the actual data and the other data layer loads the ground truth in lock-step. In this arrangement both data and label can be any 4D array. Further applications of multiple inputs are found in multi-modal and sequence models. In these cases you may need to implement your own data preparation routines or a special data layer. + +*Improvements to data processing to add formats, generality, or helper utilities are welcome!* + +## Formats + +Refer to the layer catalogue of [data layers](layers.html#data-layers) for close-ups on each type of data Caffe understands. + +## Deployment Input + +For on-the-fly computation deployment Nets define their inputs by `input` fields: these Nets then accept direct assignment of data for online or interactive computation. diff --git a/caffe-crfrnn/docs/tutorial/fig/.gitignore b/caffe-crfrnn/docs/tutorial/fig/.gitignore new file mode 100644 index 00000000..e69de29b diff --git a/caffe-crfrnn/docs/tutorial/forward_backward.md b/caffe-crfrnn/docs/tutorial/forward_backward.md new file mode 100644 index 00000000..f58b9cac --- /dev/null +++ b/caffe-crfrnn/docs/tutorial/forward_backward.md @@ -0,0 +1,37 @@ +--- +title: Forward and Backward for Inference and Learning +--- +# Forward and Backward + +The forward and backward passes are the essential computations of a [Net](net_layer_blob.html). + +Forward and Backward + +Let's consider a simple logistic regression classifier. 
+ +The **forward** pass computes the output given the input for inference. +In the forward pass, Caffe composes the computation of each layer to compute the "function" represented by the model. +This pass goes from bottom to top. + +Forward pass + +The data $x$ is passed through an inner product layer for $g(x)$ then through a softmax for $h(g(x))$ and softmax loss to give $f_W(x)$. + +The **backward** pass computes the gradient given the loss for learning. +In the backward pass, Caffe reverse-composes the gradient of each layer to compute the gradient of the whole model by automatic differentiation. +This is back-propagation. +This pass goes from top to bottom. + +Backward pass + +The backward pass begins with the loss and computes the gradient with respect to the output $\frac{\partial f_W}{\partial h}$. The gradient with respect to the rest of the model is computed layer-by-layer through the chain rule. Layers with parameters, like the `INNER_PRODUCT` layer, compute the gradient with respect to their parameters $\frac{\partial f_W}{\partial W_{\text{ip}}}$ during the backward step. + +These computations follow immediately from defining the model: Caffe plans and carries out the forward and backward passes for you. + +- The `Net::Forward()` and `Net::Backward()` methods carry out the respective passes while `Layer::Forward()` and `Layer::Backward()` compute each step. +- Every layer type has `Forward_{cpu,gpu}()` and `Backward_{cpu,gpu}()` methods to compute its steps according to the mode of computation. A layer may only implement CPU or GPU mode due to constraints or convenience. + +The [Solver](solver.html) optimizes a model by first calling forward to yield the output and loss, then calling backward to generate the gradient of the model, and then incorporating the gradient into a weight update that attempts to minimize the loss. Division of labor between the Solver, Net, and Layer keeps Caffe modular and open to development. 
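As a rough illustration of the forward pass, backward pass, and solver update described above, here is a minimal pure-Python sketch of the logistic regression example (inner product, softmax, softmax loss). It does not use Caffe's API: the function names `forward`, `backward`, and `sgd_step`, the toy input, and the learning rate are all hypothetical and chosen only for this sketch.

```python
import math

def forward(W, b, x, label):
    """Forward pass: inner product g(x), softmax h(g(x)), then softmax loss."""
    # inner product layer: one score per class
    scores = [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i
              for row, b_i in zip(W, b)]
    # softmax, shifted by the max score for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    loss = -math.log(probs[label])
    return probs, loss

def backward(probs, x, label):
    """Backward pass: chain rule from the loss back to the parameters."""
    # gradient of the softmax loss w.r.t. the scores: probs - one_hot(label)
    d_scores = [p - (1.0 if i == label else 0.0) for i, p in enumerate(probs)]
    # gradient w.r.t. the inner product weights and biases
    dW = [[d * x_j for x_j in x] for d in d_scores]
    db = list(d_scores)
    return dW, db

def sgd_step(W, b, dW, db, lr):
    """Solver-style update: move the parameters against the gradient."""
    W = [[w - lr * g for w, g in zip(row, grow)] for row, grow in zip(W, dW)]
    b = [b_i - lr * g for b_i, g in zip(b, db)]
    return W, b

# toy problem (assumed values): 2 classes, 2 input features
W = [[0.0, 0.0], [0.0, 0.0]]
b = [0.0, 0.0]
x, label = [1.0, 2.0], 1

probs, loss_before = forward(W, b, x, label)
dW, db = backward(probs, x, label)
W, b = sgd_step(W, b, dW, db, lr=0.1)
_, loss_after = forward(W, b, x, label)
print(loss_before, loss_after)
```

One forward, backward, and update cycle should lower the loss on this single example. In Caffe itself these steps are carried out by the Net, its Layers, and the Solver rather than by hand-written functions like these.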
+ +For the details of the forward and backward steps of Caffe's layer types, refer to the [layer catalogue](layers.html). + diff --git a/caffe-crfrnn/docs/tutorial/index.md b/caffe-crfrnn/docs/tutorial/index.md new file mode 100644 index 00000000..7d4e77b1 --- /dev/null +++ b/caffe-crfrnn/docs/tutorial/index.md @@ -0,0 +1,51 @@ +--- +title: Caffe Tutorial +--- +# Caffe Tutorial + +Caffe is a deep learning framework and this tutorial explains its philosophy, architecture, and usage. +This is a practical guide and framework introduction, so the full frontier, context, and history of deep learning cannot be covered here. +While explanations will be given where possible, a background in machine learning and neural networks is helpful. + +## Philosophy + +In one sip, Caffe is brewed for + +- Expression: models and optimizations are defined as plaintext schemas instead of code. +- Speed: for research and industry alike speed is crucial for state-of-the-art models and massive data. +- Modularity: new tasks and settings require flexibility and extension. +- Openness: scientific and applied progress call for common code, reference models, and reproducibility. +- Community: academic research, startup prototypes, and industrial applications all share strength by joint discussion and development in a BSD-2 project. + +and these principles direct the project. + +## Tour + +- [Nets, Layers, and Blobs](net_layer_blob.html): the anatomy of a Caffe model. +- [Forward / Backward](forward_backward.html): the essential computations of layered compositional models. +- [Loss](loss.html): the task to be learned is defined by the loss. +- [Solver](solver.html): the solver coordinates model optimization. +- [Layer Catalogue](layers.html): the layer is the fundamental unit of modeling and computation -- Caffe's catalogue includes layers for state-of-the-art models. +- [Interfaces](interfaces.html): command line, Python, and MATLAB Caffe. 
+- [Data](data.html): how to caffeinate data for model input. + +For a closer look at a few details: + +- [Caffeinated Convolution](convolution.html): how Caffe computes convolutions. + +## Deeper Learning + +There are helpful references freely online for deep learning that complement our hands-on tutorial. +These cover introductory and advanced material, background and history, and the latest advances. + +The [Tutorial on Deep Learning for Vision](https://sites.google.com/site/deeplearningcvpr2014/) from CVPR '14 is a good companion tutorial for researchers. +Once you have the framework and practice foundations from the Caffe tutorial, explore the fundamental ideas and advanced research directions in the CVPR '14 tutorial. + +A broad introduction is given in the free online draft of [Neural Networks and Deep Learning](http://neuralnetworksanddeeplearning.com/index.html) by Michael Nielsen. In particular the chapters on using neural nets and how backpropagation works are helpful if you are new to the subject. + +These recent academic tutorials cover deep learning for researchers in machine learning and vision: + +- [Deep Learning Tutorial](http://www.cs.nyu.edu/~yann/talks/lecun-ranzato-icml2013.pdf) by Yann LeCun (NYU, Facebook) and Marc'Aurelio Ranzato (Facebook). ICML 2013 tutorial. +- [LISA Deep Learning Tutorial](http://deeplearning.net/tutorial/deeplearning.pdf) by the LISA Lab directed by Yoshua Bengio (U. Montréal). + +For an exposition of neural networks in circuits and code, check out [Understanding Neural Networks from a Programmer's Perspective](http://karpathy.github.io/neuralnets/) by Andrej Karpathy (Stanford). 
diff --git a/caffe-crfrnn/docs/tutorial/interfaces.md b/caffe-crfrnn/docs/tutorial/interfaces.md new file mode 100644 index 00000000..6b0ec347 --- /dev/null +++ b/caffe-crfrnn/docs/tutorial/interfaces.md @@ -0,0 +1,68 @@ +--- +title: Interfaces +--- +# Interfaces + +Caffe has command line, Python, and MATLAB interfaces for day-to-day usage, interfacing with research code, and rapid prototyping. While Caffe is a C++ library at heart and it exposes a modular interface for development, not every occasion calls for custom compilation. The cmdcaffe, pycaffe, and matcaffe interfaces are here for you. + +## Command Line + +The command line interface -- cmdcaffe -- is the `caffe` tool for model training, scoring, and diagnostics. Run `caffe` without any arguments for help. This tool and others are found in caffe/build/tools. (The following example calls require completing the LeNet / MNIST example first.) + +**Training**: `caffe train` learns models from scratch, resumes learning from saved snapshots, and fine-tunes models to new data and tasks. All training requires a solver configuration through the `-solver solver.prototxt` argument. Resuming requires the `-snapshot model_iter_1000.solverstate` argument to load the solver snapshot. Fine-tuning requires the `-weights model.caffemodel` argument for the model initialization. 
+ + # train LeNet + caffe train -solver examples/mnist/lenet_solver.prototxt + # train on GPU 2 + caffe train -solver examples/mnist/lenet_solver.prototxt -gpu 2 + # resume training from the half-way point snapshot + caffe train -solver examples/mnist/lenet_solver.prototxt -snapshot examples/mnist/lenet_iter_5000.solverstate + +For a full example of fine-tuning, see examples/finetuning_on_flickr_style, but the training call alone is + + # fine-tune CaffeNet model weights for style recognition + caffe train -solver examples/finetuning_on_flickr_style/solver.prototxt -weights models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel + +**Testing**: `caffe test` scores models by running them in the test phase and reports the net output as its score. The net architecture must be properly defined to output an accuracy measure or loss as its output. The per-batch score is reported and then the grand average is reported last. + + # + # score the learned LeNet model on the validation set as defined in the model architecture lenet_train_test.prototxt + caffe test -model examples/mnist/lenet_train_test.prototxt -weights examples/mnist/lenet_iter_10000 -gpu 0 -iterations 100 + +**Benchmarking**: `caffe time` benchmarks model execution layer-by-layer through timing and synchronization. This is useful to check system performance and measure relative execution times for models. + + # (These example calls require that you complete the LeNet / MNIST example first.) + # time LeNet training on CPU for 10 iterations + caffe time -model examples/mnist/lenet_train_test.prototxt -iterations 10 + # time a model architecture with the given weights on the first GPU for 10 iterations + # time LeNet training on GPU for the default 50 iterations + caffe time -model examples/mnist/lenet_train_test.prototxt -gpu 0 + +**Diagnostics**: `caffe device_query` reports GPU details for reference and checking device ordinals for running on a given device in multi-GPU machines. 
+ + # query the first device + caffe device_query -gpu 0 + +## Python + +The Python interface -- pycaffe -- is the `caffe` module and its scripts in caffe/python. `import caffe` to load models, do forward and backward, handle IO, visualize networks, and even instrument model solving. All model data, derivatives, and parameters are exposed for reading and writing. + +- `caffe.Net` is the central interface for loading, configuring, and running models. `caffe.Classifier` and `caffe.Detector` provide convenience interfaces for common tasks. +- `caffe.SGDSolver` exposes the solving interface. +- `caffe.io` handles input / output with preprocessing and protocol buffers. +- `caffe.draw` visualizes network architectures. +- Caffe blobs are exposed as numpy ndarrays for ease-of-use and efficiency. + +Tutorial IPython notebooks are found in caffe/examples: do `ipython notebook caffe/examples` to try them. For developer reference, docstrings can be found throughout the code. + +Compile pycaffe by `make pycaffe`. The module dir caffe/python/caffe should be installed in your PYTHONPATH for `import caffe`. + +## MATLAB + +The MATLAB interface -- matcaffe -- is the `caffe` mex and its helper m-files in caffe/matlab. Load models, do forward and backward, extract output and read-only model weights, and load the binaryproto format mean as a matrix. + +A MATLAB demo is in caffe/matlab/caffe/matcaffe_demo.m + +Note that MATLAB matrices and memory are in column-major layout counter to Caffe's row-major layout! Double-check your work accordingly. + +Compile matcaffe by `make matcaffe`. diff --git a/caffe-crfrnn/docs/tutorial/layers.md b/caffe-crfrnn/docs/tutorial/layers.md new file mode 100644 index 00000000..5f8f519c --- /dev/null +++ b/caffe-crfrnn/docs/tutorial/layers.md @@ -0,0 +1,468 @@ +--- +title: Layer Catalogue +--- +# Layers + +To create a Caffe model you need to define the model architecture in a protocol buffer definition file (prototxt). 
+ +Caffe layers and their parameters are defined in the protocol buffer definitions for the project in [caffe.proto](https://github.com/BVLC/caffe/blob/master/src/caffe/proto/caffe.proto). The latest definitions are in the [dev caffe.proto](https://github.com/BVLC/caffe/blob/dev/src/caffe/proto/caffe.proto). + +TODO complete list of layers linking to headings + +### Vision Layers + +* Header: `./include/caffe/vision_layers.hpp` + +Vision layers usually take *images* as input and produce other *images* as output. +A typical "image" in the real-world may have one color channel ($$c = 1$$), as in a grayscale image, or three color channels ($$c = 3$$) as in an RGB (red, green, blue) image. +But in this context, the distinguishing characteristic of an image is its spatial structure: usually an image has some non-trivial height $$h > 1$$ and width $$w > 1$$. +This 2D geometry naturally lends itself to certain decisions about how to process the input. +In particular, most of the vision layers work by applying a particular operation to some region of the input to produce a corresponding region of the output. +In contrast, other layers (with few exceptions) ignore the spatial structure of the input, effectively treating it as "one big vector" with dimension $$chw$$. 
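The convolution entry in this catalogue gives the output height as `h_o = (h_i + 2 * pad_h - kernel_h) / stride_h + 1`. The following is a minimal sketch of that arithmetic, not Caffe code: the function name `conv_output_size` is hypothetical, the division is assumed to be integer (floor) division, and the 227-pixel input in the usage line is an assumed value that the catalogue itself does not state.

```python
def conv_output_size(in_size, kernel, pad=0, stride=1):
    """Output extent along one spatial axis, following
    (in + 2 * pad - kernel) / stride + 1 with integer division (assumed)."""
    if in_size + 2 * pad < kernel:
        raise ValueError("kernel is larger than the padded input")
    return (in_size + 2 * pad - kernel) // stride + 1


# Settings of the "conv1" sample (kernel_size: 11, stride: 4, default pad 0),
# applied to an assumed 227-pixel input side.
h_o = conv_output_size(227, kernel=11, stride=4)
print(h_o)
```

The same helper can be applied to the width by passing `w_i`, `kernel_w`, `pad_w`, and `stride_w`.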
+ + +#### Convolution + +* LayerType: `CONVOLUTION` +* CPU implementation: `./src/caffe/layers/convolution_layer.cpp` +* CUDA GPU implementation: `./src/caffe/layers/convolution_layer.cu` +* Parameters (`ConvolutionParameter convolution_param`) + - Required + - `num_output` (`c_o`): the number of filters + - `kernel_size` (or `kernel_h` and `kernel_w`): specifies height and width of each filter + - Strongly Recommended + - `weight_filler` [default `type: 'constant' value: 0`] + - Optional + - `bias_term` [default `true`]: specifies whether to learn and apply a set of additive biases to the filter outputs + - `pad` (or `pad_h` and `pad_w`) [default 0]: specifies the number of pixels to (implicitly) add to each side of the input + - `stride` (or `stride_h` and `stride_w`) [default 1]: specifies the intervals at which to apply the filters to the input + - `group` (g) [default 1]: If g > 1, we restrict the connectivity of each filter to a subset of the input. Specifically, the input and output channels are separated into g groups, and the $$i$$th output group channels will be only connected to the $$i$$th input group channels. +* Input + - `n * c_i * h_i * w_i` +* Output + - `n * c_o * h_o * w_o`, where `h_o = (h_i + 2 * pad_h - kernel_h) / stride_h + 1` and `w_o` likewise. 
+* Sample (as seen in `./examples/imagenet/imagenet_train_val.prototxt`) + + layers { + name: "conv1" + type: CONVOLUTION + bottom: "data" + top: "conv1" + blobs_lr: 1 # learning rate multiplier for the filters + blobs_lr: 2 # learning rate multiplier for the biases + weight_decay: 1 # weight decay multiplier for the filters + weight_decay: 0 # weight decay multiplier for the biases + convolution_param { + num_output: 96 # learn 96 filters + kernel_size: 11 # each filter is 11x11 + stride: 4 # step 4 pixels between each filter application + weight_filler { + type: "gaussian" # initialize the filters from a Gaussian + std: 0.01 # distribution with stdev 0.01 (default mean: 0) + } + bias_filler { + type: "constant" # initialize the biases to zero (0) + value: 0 + } + } + } + +The `CONVOLUTION` layer convolves the input image with a set of learnable filters, each producing one feature map in the output image. + +#### Pooling + +* LayerType: `POOLING` +* CPU implementation: `./src/caffe/layers/pooling_layer.cpp` +* CUDA GPU implementation: `./src/caffe/layers/pooling_layer.cu` +* Parameters (`PoolingParameter pooling_param`) + - Required + - `kernel_size` (or `kernel_h` and `kernel_w`): specifies height and width of each filter + - Optional + - `pool` [default MAX]: the pooling method. Currently MAX, AVE, or STOCHASTIC + - `pad` (or `pad_h` and `pad_w`) [default 0]: specifies the number of pixels to (implicitly) add to each side of the input + - `stride` (or `stride_h` and `stride_w`) [default 1]: specifies the intervals at which to apply the filters to the input +* Input + - `n * c * h_i * w_i` +* Output + - `n * c * h_o * w_o`, where h_o and w_o are computed in the same way as convolution. 
+ * Sample (as seen in `./examples/imagenet/imagenet_train_val.prototxt`) + + layers { + name: "pool1" + type: POOLING + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 # pool over a 3x3 region + stride: 2 # step two pixels (in the bottom blob) between pooling regions + } + } + + #### Local Response Normalization (LRN) + + * LayerType: `LRN` + * CPU Implementation: `./src/caffe/layers/lrn_layer.cpp` + * CUDA GPU Implementation: `./src/caffe/layers/lrn_layer.cu` + * Parameters (`LRNParameter lrn_param`) + - Optional + - `local_size` [default 5]: the number of channels to sum over (for cross channel LRN) or the side length of the square region to sum over (for within channel LRN) + - `alpha` [default 1]: the scaling parameter (see below) + - `beta` [default 0.75]: the exponent (see below) + - `norm_region` [default `ACROSS_CHANNELS`]: whether to sum over adjacent channels (`ACROSS_CHANNELS`) or nearby spatial locations (`WITHIN_CHANNEL`) + + The local response normalization layer performs a kind of "lateral inhibition" by normalizing over local input regions. In `ACROSS_CHANNELS` mode, the local regions extend across nearby channels, but have no spatial extent (i.e., they have shape `local_size x 1 x 1`). In `WITHIN_CHANNEL` mode, the local regions extend spatially, but are in separate channels (i.e., they have shape `1 x local_size x local_size`). Each input value is divided by $$(1 + (\alpha/n) \sum_i x_i^2)^\beta$$, where $$n$$ is the size of each local region, and the sum is taken over the region centered at that value (zero padding is added where necessary). + + #### im2col + + `IM2COL` is a helper for doing the image-to-column transformation that you most likely do not need to know about. This is used in Caffe's original convolution to do matrix multiplication by laying out all patches into a matrix. + + ### Loss Layers + + Loss drives learning by comparing an output to a target and assigning cost to minimize.
The loss itself is computed by the forward pass and the gradient w.r.t. the loss is computed by the backward pass. + + #### Softmax + + * LayerType: `SOFTMAX_LOSS` + + The softmax loss layer computes the multinomial logistic loss of the softmax of its inputs. It's conceptually identical to a softmax layer followed by a multinomial logistic loss layer, but provides a more numerically stable gradient. + + #### Sum-of-Squares / Euclidean + + * LayerType: `EUCLIDEAN_LOSS` + + The Euclidean loss layer computes the sum of squares of differences of its two inputs, $$\frac 1 {2N} \sum_{i=1}^N \| x^1_i - x^2_i \|_2^2$$. + + #### Hinge / Margin + + * LayerType: `HINGE_LOSS` + * CPU implementation: `./src/caffe/layers/hinge_loss_layer.cpp` + * CUDA GPU implementation: none yet + * Parameters (`HingeLossParameter hinge_loss_param`) + - Optional + - `norm` [default L1]: the norm used. Currently L1, L2 + * Inputs + - `n * c * h * w` Predictions + - `n * 1 * 1 * 1` Labels + * Output + - `1 * 1 * 1 * 1` Computed Loss + * Samples + + # L1 Norm + layers { + name: "loss" + type: HINGE_LOSS + bottom: "pred" + bottom: "label" + } + + # L2 Norm + layers { + name: "loss" + type: HINGE_LOSS + bottom: "pred" + bottom: "label" + top: "loss" + hinge_loss_param { + norm: L2 + } + } + + The hinge loss layer computes a one-vs-all hinge or squared hinge loss. + + #### Sigmoid Cross-Entropy + + `SIGMOID_CROSS_ENTROPY_LOSS` + + #### Infogain + + `INFOGAIN_LOSS` + + #### Accuracy and Top-k + + `ACCURACY` scores the output as the accuracy of output with respect to target -- it is not actually a loss and has no backward step. + + ### Activation / Neuron Layers + + In general, activation / neuron layers are element-wise operators, taking one bottom blob and producing one top blob of the same size.
In the layers below, we will ignore the input and output sizes as they are identical: + + * Input + - n * c * h * w + * Output + - n * c * h * w + + #### ReLU / Rectified-Linear and Leaky-ReLU + + * LayerType: `RELU` + * CPU implementation: `./src/caffe/layers/relu_layer.cpp` + * CUDA GPU implementation: `./src/caffe/layers/relu_layer.cu` + * Parameters (`ReLUParameter relu_param`) + - Optional + - `negative_slope` [default 0]: specifies whether to leak the negative part by multiplying it with the slope value rather than setting it to 0. + * Sample (as seen in `./examples/imagenet/imagenet_train_val.prototxt`) + + layers { + name: "relu1" + type: RELU + bottom: "conv1" + top: "conv1" + } + + Given an input value x, the `RELU` layer computes the output as x if x > 0 and negative_slope * x if x <= 0. When the negative slope parameter is not set, it is equivalent to the standard ReLU function of taking max(x, 0). It also supports in-place computation, meaning that the bottom and the top blob can be the same to reduce memory consumption. + + #### Sigmoid + + * LayerType: `SIGMOID` + * CPU implementation: `./src/caffe/layers/sigmoid_layer.cpp` + * CUDA GPU implementation: `./src/caffe/layers/sigmoid_layer.cu` + * Sample (as seen in `./examples/mnist/mnist_autoencoder.prototxt`) + + layers { + name: "encode1neuron" + bottom: "encode1" + top: "encode1neuron" + type: SIGMOID + } + + The `SIGMOID` layer computes the output as sigmoid(x) for each input element x. + + #### TanH / Hyperbolic Tangent + + * LayerType: `TANH` + * CPU implementation: `./src/caffe/layers/tanh_layer.cpp` + * CUDA GPU implementation: `./src/caffe/layers/tanh_layer.cu` + * Sample + + layers { + name: "layer" + bottom: "in" + top: "out" + type: TANH + } + + The `TANH` layer computes the output as tanh(x) for each input element x.
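The element-wise definitions of the `RELU`, `SIGMOID`, and `TANH` layers above can be sketched in plain Python as follows. This is an illustration of the formulas only, under the assumption that each function is applied independently to every element of the bottom blob; the function names are hypothetical and are not part of the Caffe API.

```python
import math


def relu(x, negative_slope=0.0):
    """x if x > 0, otherwise negative_slope * x (leaky ReLU when the slope is non-zero)."""
    return x if x > 0 else negative_slope * x


def sigmoid(x):
    """Logistic sigmoid 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-x))


def tanh(x):
    """Hyperbolic tangent."""
    return math.tanh(x)


# Apply each operator element-wise to a small list standing in for a blob.
bottom = [-2.0, 0.0, 3.0]
top_relu = [relu(x) for x in bottom]
top_leaky = [relu(x, negative_slope=0.1) for x in bottom]
top_sigmoid = [sigmoid(x) for x in bottom]
top_tanh = [tanh(x) for x in bottom]
```

Because the output has the same shape as the input and each element depends only on its own input value, overwriting the input in place, as the `relu1` sample does by naming the same blob as `bottom` and `top`, does not change the result.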
+ +#### Absolute Value + +* LayerType: `ABSVAL` +* CPU implementation: `./src/caffe/layers/absval_layer.cpp` +* CUDA GPU implementation: `./src/caffe/layers/absval_layer.cu` +* Sample + + layers { + name: "layer" + bottom: "in" + top: "out" + type: ABSVAL + } + +The `ABSVAL` layer computes the output as abs(x) for each input element x. + +#### Power + +* LayerType: `POWER` +* CPU implementation: `./src/caffe/layers/power_layer.cpp` +* CUDA GPU implementation: `./src/caffe/layers/power_layer.cu` +* Parameters (`PowerParameter power_param`) + - Optional + - `power` [default 1] + - `scale` [default 1] + - `shift` [default 0] +* Sample + + layers { + name: "layer" + bottom: "in" + top: "out" + type: POWER + power_param { + power: 1 + scale: 1 + shift: 0 + } + } + +The `POWER` layer computes the output as (shift + scale * x) ^ power for each input element x. + +#### BNLL + +* LayerType: `BNLL` +* CPU implementation: `./src/caffe/layers/bnll_layer.cpp` +* CUDA GPU implementation: `./src/caffe/layers/bnll_layer.cu` +* Sample + + layers { + name: "layer" + bottom: "in" + top: "out" + type: BNLL + } + +The `BNLL` (binomial normal log likelihood) layer computes the output as log(1 + exp(x)) for each input element x. + + +### Data Layers + +Data enters Caffe through data layers: they lie at the bottom of nets. Data can come from efficient databases (LevelDB or LMDB), directly from memory, or, when efficiency is not critical, from files on disk in HDF5 or common image formats. + +Common input preprocessing (mean subtraction, scaling, random cropping, and mirroring) is available by specifying `TransformationParameter`s. 
+ +#### Database + +* LayerType: `DATA` +* Parameters + - Required + - `source`: the name of the directory containing the database + - `batch_size`: the number of inputs to process at one time + - Optional + - `rand_skip`: skip up to this number of inputs at the beginning; useful for asynchronous sgd + - `backend` [default `LEVELDB`]: choose whether to use a `LEVELDB` or `LMDB` + + + +#### In-Memory + +* LayerType: `MEMORY_DATA` +* Parameters + - Required + - `batch_size`, `channels`, `height`, `width`: specify the size of input chunks to read from memory + +The memory data layer reads data directly from memory, without copying it. In order to use it, one must call `MemoryDataLayer::Reset` (from C++) or `Net.set_input_arrays` (from Python) in order to specify a source of contiguous data (as 4D row major array), which is read one batch-sized chunk at a time. + +#### HDF5 Input + +* LayerType: `HDF5_DATA` +* Parameters + - Required + - `source`: the name of the file to read from + - `batch_size` + +#### HDF5 Output + +* LayerType: `HDF5_OUTPUT` +* Parameters + - Required + - `file_name`: name of file to write to + +The HDF5 output layer performs the opposite function of the other layers in this section: it writes its input blobs to disk. + +#### Images + +* LayerType: `IMAGE_DATA` +* Parameters + - Required + - `source`: name of a text file, with each line giving an image filename and label + - `batch_size`: number of images to batch together + - Optional + - `rand_skip` + - `shuffle` [default false] + - `new_height`, `new_width`: if provided, resize all images to this size + +#### Windows + +`WINDOW_DATA` + +#### Dummy + +`DUMMY_DATA` is for development and debugging. See `DummyDataParameter`. 
+ +### Common Layers + +#### Inner Product + +* LayerType: `INNER_PRODUCT` +* CPU implementation: `./src/caffe/layers/inner_product_layer.cpp` +* CUDA GPU implementation: `./src/caffe/layers/inner_product_layer.cu` +* Parameters (`InnerProductParameter inner_product_param`) + - Required + - `num_output` (`c_o`): the number of filters + - Strongly recommended + - `weight_filler` [default `type: 'constant' value: 0`] + - Optional + - `bias_filler` [default `type: 'constant' value: 0`] + - `bias_term` [default `true`]: specifies whether to learn and apply a set of additive biases to the filter outputs +* Input + - `n * c_i * h_i * w_i` +* Output + - `n * c_o * 1 * 1` +* Sample + + layers { + name: "fc8" + type: INNER_PRODUCT + blobs_lr: 1 # learning rate multiplier for the filters + blobs_lr: 2 # learning rate multiplier for the biases + weight_decay: 1 # weight decay multiplier for the filters + weight_decay: 0 # weight decay multiplier for the biases + inner_product_param { + num_output: 1000 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } + bottom: "fc7" + top: "fc8" + } + +The `INNER_PRODUCT` layer (also usually referred to as the fully connected layer) treats the input as a simple vector and produces an output in the form of a single vector (with the blob's height and width set to 1). + +#### Splitting + +The `SPLIT` layer is a utility layer that splits an input blob to multiple output blobs. This is used when a blob is fed into multiple output layers. + +#### Flattening + +The `FLATTEN` layer is a utility layer that flattens an input of shape `n * c * h * w` to a simple vector output of shape `n * (c*h*w) * 1 * 1`. 
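As a small illustration of the `FLATTEN` shape rule above, the sketch below computes the flattened shape and the position of one element inside the flattened vector. It assumes the row-major blob layout described in the blob tutorial of this documentation; the helper names `flatten_shape` and `flat_index` are hypothetical and do not exist in Caffe.

```python
def flatten_shape(n, c, h, w):
    """Shape produced by flattening an n * c * h * w blob."""
    return (n, c * h * w, 1, 1)


def flat_index(k, i, j, c, h, w):
    """Position of element (channel k, row i, column j) within one flattened
    item, assuming row-major order (the last index changes fastest)."""
    if not (0 <= k < c and 0 <= i < h and 0 <= j < w):
        raise IndexError("index out of range")
    return (k * h + i) * w + j


shape = flatten_shape(64, 3, 2, 2)
last = flat_index(2, 1, 1, c=3, h=2, w=2)
print(shape, last)
```

Under that layout assumption, flattening only reinterprets the shape; the order of the values in memory stays the same.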
+ +#### Concatenation + +* LayerType: `CONCAT` +* CPU implementation: `./src/caffe/layers/concat_layer.cpp` +* CUDA GPU implementation: `./src/caffe/layers/concat_layer.cu` +* Parameters (`ConcatParameter concat_param`) + - Optional + - `concat_dim` [default 1]: 0 for concatenation along num and 1 for channels. +* Input + - `n_i * c_i * h * w` for each input blob i from 1 to K. +* Output + - if `concat_dim = 0`: `(n_1 + n_2 + ... + n_K) * c_1 * h * w`, and all input `c_i` should be the same. + - if `concat_dim = 1`: `n_1 * (c_1 + c_2 + ... + c_K) * h * w`, and all input `n_i` should be the same. +* Sample + + layers { + name: "concat" + bottom: "in1" + bottom: "in2" + top: "out" + type: CONCAT + concat_param { + concat_dim: 1 + } + } + +The `CONCAT` layer is a utility layer that concatenates its multiple input blobs to one single output blob. Currently, the layer supports concatenation along num or channels only. + +#### Slicing + +The `SLICE` layer is a utility layer that slices an input layer to multiple output layers along a given dimension (currently num or channel only) with given slice indices. + +#### Elementwise Operations + +`ELTWISE` + +#### Argmax + +`ARGMAX` + +#### Softmax + +`SOFTMAX` + +#### Mean-Variance Normalization + +`MVN` diff --git a/caffe-crfrnn/docs/tutorial/loss.md b/caffe-crfrnn/docs/tutorial/loss.md new file mode 100644 index 00000000..aac56177 --- /dev/null +++ b/caffe-crfrnn/docs/tutorial/loss.md @@ -0,0 +1,51 @@ +--- +title: Loss +--- +# Loss + +In Caffe, as in most of machine learning, learning is driven by a **loss** function (also known as an **error**, **cost**, or **objective** function). +A loss function specifies the goal of learning by mapping parameter settings (i.e., the current network weights) to a scalar value specifying the "badness" of these parameter settings. +Hence, the goal of learning is to find a setting of the weights that *minimizes* the loss function. 
+ + The loss in Caffe is computed by the Forward pass of the network. + Each layer takes a set of input (`bottom`) blobs and produces a set of output (`top`) blobs. + Some of these layers' outputs may be used in the loss function. + A typical choice of loss function for one-versus-all classification tasks is the `SOFTMAX_LOSS` function, used in a network definition as follows, for example: + + layers { + name: "loss" + type: SOFTMAX_LOSS + bottom: "pred" + bottom: "label" + top: "loss" + } + + In a `SOFTMAX_LOSS` function, the `top` blob is a scalar (dimensions $$1 \times 1 \times 1 \times 1$$) which averages the loss (computed from predicted labels `pred` and actual labels `label`) over the entire mini-batch. + + ### Loss weights + + For nets with multiple layers producing a loss (e.g., a network that both classifies the input using a `SOFTMAX_LOSS` layer and reconstructs it using a `EUCLIDEAN_LOSS` layer), *loss weights* can be used to specify their relative importance. + + By convention, Caffe layer types with the suffix `_LOSS` contribute to the loss function, but other layers are assumed to be purely used for intermediate computations. + However, any layer can be used as a loss by adding a field `loss_weight: ` to a layer definition for each `top` blob produced by the layer. + Layers with the suffix `_LOSS` have an implicit `loss_weight: 1` for the first `top` blob (and `loss_weight: 0` for any additional `top`s); other layers have an implicit `loss_weight: 0` for all `top`s. + So, the above `SOFTMAX_LOSS` layer could be equivalently written as: + + layers { + name: "loss" + type: SOFTMAX_LOSS + bottom: "pred" + bottom: "label" + top: "loss" + loss_weight: 1 + } + + However, *any* layer able to backpropagate may be given a non-zero `loss_weight`, allowing one to, for example, regularize the activations produced by some intermediate layer(s) of the network if desired.
+For non-singleton outputs with an associated non-zero loss, the loss is computed simply by summing over all entries of the blob. + +The final loss in Caffe, then, is computed by summing the total weighted loss over the network, as in the following pseudo-code: + + loss := 0 + for layer in layers: + for top, loss_weight in layer.tops, layer.loss_weights: + loss += loss_weight * sum(top) diff --git a/caffe-crfrnn/docs/tutorial/net_layer_blob.md b/caffe-crfrnn/docs/tutorial/net_layer_blob.md new file mode 100644 index 00000000..1f0966f8 --- /dev/null +++ b/caffe-crfrnn/docs/tutorial/net_layer_blob.md @@ -0,0 +1,170 @@ +--- +title: Blobs, Layers, and Nets +--- +# Blobs, Layers, and Nets: anatomy of a Caffe model + +Deep networks are compositional models that are naturally represented as a collection of inter-connected layers that work on chunks of data. Caffe defines a net layer-by-layer in its own model schema. The network defines the entire model bottom-to-top from input data to loss. As data and derivatives flow through the network in the [forward and backward passes](forward_backward.html) Caffe stores, communicates, and manipulates the information as *blobs*: the blob is the standard array and unified memory interface for the framework. The layer comes next as the foundation of both model and computation. The net follows as the collection and connection of layers. The details of blob describe how information is stored and communicated in and across layers and nets. + +[Solving](solver.html) is configured separately to decouple modeling and optimization. + +We will go over the details of these components in more detail. + +## Blob storage and communication + +A Blob is a wrapper over the actual data being processed and passed along by Caffe, and also under the hood provides synchronization capability between the CPU and the GPU. 
Mathematically, a blob is a 4-dimensional array that stores things in the order of (Num, Channels, Height and Width), from major to minor, and stored in a C-contiguous fashion. Num (the name is kept for legacy reasons) is equivalent to the notion of "batch" as in minibatch SGD. + + Caffe stores and communicates data in 4-dimensional arrays called blobs. Blobs provide a unified memory interface, holding data, e.g., batches of images, model parameters, and derivatives for optimization. + + Blobs conceal the computational and mental overhead of mixed CPU/GPU operation by synchronizing from the CPU host to the GPU device as needed. Memory on the host and device is allocated on demand (lazily) for efficient memory usage. + + The conventional blob dimensions for data are number N x channel K x height H x width W. Blob memory is row-major in layout so the last / rightmost dimension changes fastest. For example, the value at index (n, k, h, w) is physically located at index ((n * K + k) * H + h) * W + w. + + - Number / N is the batch size of the data. Batch processing achieves better throughput for communication and device processing. For an ImageNet training batch of 256 images N = 256. + - Channel / K is the feature dimension e.g. for RGB images K = 3. + + Note that although we have designed blobs with their dimensions corresponding to image applications, they are named purely for notational purposes and it is totally valid for you to do non-image applications. For example, if you simply need fully-connected networks like the conventional multi-layer perceptron, use blobs of dimensions (Num, Channels, 1, 1) and call the InnerProductLayer (which we will cover soon). + + Caffe operations are general with respect to the channel dimension / K. Grayscale and hyperspectral imagery are fine. Caffe can likewise model and process arbitrary vectors in blobs with singleton spatial dimensions.
That is, the shape of a blob holding 1000 vectors of 16 feature dimensions is 1000 x 16 x 1 x 1. + + Parameter blob dimensions vary according to the type and configuration of the layer. For a convolution layer with 96 filters of 11 x 11 spatial dimension and 3 inputs the blob is 96 x 3 x 11 x 11. For an inner product / fully-connected layer with 1000 output channels and 1024 input channels the parameter blob is 1 x 1 x 1000 x 1024. + + For custom data it may be necessary to hack your own input preparation tool or data layer. However, once your data is in, your job is done. The modularity of layers accomplishes the rest of the work for you. + + ### Implementation Details + + As we are often interested in the values as well as the gradients of the blob, a Blob stores two chunks of memory, *data* and *diff*. The former is the normal data that we pass along, and the latter is the gradient computed by the network. + + Further, as the actual values could be stored either on the CPU or on the GPU, there are two different ways to access them: the const way, which does not change the values, and the mutable way, which changes the values: + + const Dtype* cpu_data() const; + Dtype* mutable_cpu_data(); + + (similarly for gpu and diff). + + The reason for this design is that a Blob uses a SyncedMem class to synchronize values between the CPU and GPU in order to hide the synchronization details and to minimize data transfer. A rule of thumb is, always use the const call if you do not want to change the values, and never store the pointers in your own object. Every time you work on a blob, call the functions to get the pointers, as the SyncedMem will need this to figure out when to copy data. + + In practice when GPUs are present, one loads data from the disk to a blob in CPU code, calls a device kernel to do GPU computation, and ferries the blob off to the next layer, ignoring low-level details while maintaining a high level of performance.
As long as all layers have GPU implementations, all the intermediate data and gradients will remain in the GPU. + + If you want to check out when a Blob will copy data, here is an illustrative example: + + // Assuming that data are on the CPU initially, and we have a blob. + const Dtype* foo; + Dtype* bar; + foo = blob.gpu_data(); // data copied cpu->gpu. + foo = blob.cpu_data(); // no data copied since both have up-to-date contents. + bar = blob.mutable_gpu_data(); // no data copied. + // ... some operations ... + bar = blob.mutable_gpu_data(); // no data copied when we are still on GPU. + foo = blob.cpu_data(); // data copied gpu->cpu, since the gpu side has modified the data + foo = blob.gpu_data(); // no data copied since both have up-to-date contents + bar = blob.mutable_cpu_data(); // still no data copied. + bar = blob.mutable_gpu_data(); // data copied cpu->gpu. + bar = blob.mutable_cpu_data(); // data copied gpu->cpu. + + ## Layer computation and connections + + The layer is the essence of a model and the fundamental unit of computation. Layers convolve filters, pool, take inner products, apply nonlinearities like rectified-linear and sigmoid and other elementwise transformations, normalize, load data, and compute losses like softmax and hinge. [See the layer catalogue](layers.html) for all operations. Most of the types needed for state-of-the-art deep learning tasks are there. + + A layer with bottom and top blob. + + A layer takes input through *bottom* connections and makes output through *top* connections. + + Each layer type defines three critical computations: *setup*, *forward*, and *backward*. + + - Setup: initialize the layer and its connections once at model initialization. + - Forward: given input from bottom compute the output and send to the top. + - Backward: given the gradient w.r.t. the top output compute the gradient w.r.t. the input and send to the bottom. A layer with parameters computes the gradient w.r.t.
its parameters and stores it internally. + + More specifically, there will be two Forward and Backward functions implemented, one for CPU and one for GPU. If you do not implement a GPU version, the layer will fall back to the CPU functions as a backup option. This may come in handy if you would like to do quick experiments, although it may come with additional data transfer cost (its inputs will be copied from GPU to CPU, and its outputs will be copied back from CPU to GPU). + + Layers have two key responsibilities for the operation of the network as a whole: a *forward pass* that takes the inputs and produces the outputs, and a *backward pass* that takes the gradient with respect to the output, and computes the gradients with respect to the parameters and to the inputs, which are in turn back-propagated to earlier layers. These passes are simply the composition of each layer's forward and backward. + + Developing custom layers requires minimal effort thanks to the compositionality of the network and modularity of the code. Define the setup, forward, and backward for the layer and it is ready for inclusion in a net. + + ## Net definition and operation + + The net jointly defines a function and its gradient by composition and auto-differentiation. The composition of every layer's output computes the function to do a given task, and the composition of every layer's backward computes the gradient from the loss to learn the task. Caffe models are end-to-end machine learning engines. + + The net is a set of layers connected in a computation graph -- a directed acyclic graph (DAG) to be exact. Caffe does all the bookkeeping for any DAG of layers to ensure correctness of the forward and backward passes. A typical net begins with a data layer that loads from disk and ends with a loss layer that computes the objective for a task such as classification or reconstruction. + + The net is defined as a set of layers and their connections in a plaintext modeling language.
+A simple logistic regression classifier + +Softmax Regression + +is defined by + + name: "LogReg" + layers { + name: "mnist" + type: DATA + top: "data" + top: "label" + data_param { + source: "input_leveldb" + batch_size: 64 + } + } + layers { + name: "ip" + type: INNER_PRODUCT + bottom: "data" + top: "ip" + inner_product_param { + num_output: 2 + } + } + layers { + name: "loss" + type: SOFTMAX_LOSS + bottom: "ip" + bottom: "label" + top: "loss" + } + +Model initialization is handled by `Net::Init()`. The initialization mainly does two things: scaffolding the overall DAG by creating the blobs and layers (for C++ geeks: the network will retain ownership of the blobs and layers during its lifetime), and calls the layers' `SetUp()` function. It also does a set of other bookkeeping things, such as validating the correctness of the overall network architecture. Also, during initialization the Net explains its initialization by logging to INFO as it goes: + + I0902 22:52:17.931977 2079114000 net.cpp:39] Initializing net from parameters: + name: "LogReg" + [...model prototxt printout...] 
+ # construct the network layer-by-layer + I0902 22:52:17.932152 2079114000 net.cpp:67] Creating Layer mnist + I0902 22:52:17.932165 2079114000 net.cpp:356] mnist -> data + I0902 22:52:17.932188 2079114000 net.cpp:356] mnist -> label + I0902 22:52:17.932200 2079114000 net.cpp:96] Setting up mnist + I0902 22:52:17.935807 2079114000 data_layer.cpp:135] Opening leveldb input_leveldb + I0902 22:52:17.937155 2079114000 data_layer.cpp:195] output data size: 64,1,28,28 + I0902 22:52:17.938570 2079114000 net.cpp:103] Top shape: 64 1 28 28 (50176) + I0902 22:52:17.938593 2079114000 net.cpp:103] Top shape: 64 1 1 1 (64) + I0902 22:52:17.938611 2079114000 net.cpp:67] Creating Layer ip + I0902 22:52:17.938617 2079114000 net.cpp:394] ip <- data + I0902 22:52:17.939177 2079114000 net.cpp:356] ip -> ip + I0902 22:52:17.939196 2079114000 net.cpp:96] Setting up ip + I0902 22:52:17.940289 2079114000 net.cpp:103] Top shape: 64 2 1 1 (128) + I0902 22:52:17.941270 2079114000 net.cpp:67] Creating Layer loss + I0902 22:52:17.941305 2079114000 net.cpp:394] loss <- ip + I0902 22:52:17.941314 2079114000 net.cpp:394] loss <- label + I0902 22:52:17.941323 2079114000 net.cpp:356] loss -> loss + # set up the loss and configure the backward pass + I0902 22:52:17.941328 2079114000 net.cpp:96] Setting up loss + I0902 22:52:17.941328 2079114000 net.cpp:103] Top shape: 1 1 1 1 (1) + I0902 22:52:17.941329 2079114000 net.cpp:109] with loss weight 1 + I0902 22:52:17.941779 2079114000 net.cpp:170] loss needs backward computation. + I0902 22:52:17.941787 2079114000 net.cpp:170] ip needs backward computation. + I0902 22:52:17.941794 2079114000 net.cpp:172] mnist does not need backward computation. + # determine outputs + I0902 22:52:17.941800 2079114000 net.cpp:208] This network produces output loss + # finish initialization and report memory usage + I0902 22:52:17.941810 2079114000 net.cpp:467] Collecting Learning Rate and Weight Decay. 
+ I0902 22:52:17.941818 2079114000 net.cpp:219] Network initialization done. + I0902 22:52:17.941824 2079114000 net.cpp:220] Memory required for data: 201476 + +Note that the construction of the network is device agnostic - recall our earlier explanation that blobs and layers hide implementation details from the model definition. After construction, the network is run on either CPU or GPU by setting a single switch defined in `Caffe::mode()` and set by `Caffe::set_mode()`. Layers come with corresponding CPU and GPU routines that produce identical results (up to numerical errors, and with tests to guard it). The CPU / GPU switch is seamless and independent of the model definition. For research and deployment alike it is best to divide model and implementation. + +### Model format + +The models are defined in plaintext protocol buffer schema (prototxt) while the learned models are serialized as binary protocol buffer (binaryproto) .caffemodel files. + +The model format is defined by the protobuf schema in [caffe.proto](https://github.com/BVLC/caffe/blob/master/src/caffe/proto/caffe.proto). The source file is mostly self-explanatory so one is encouraged to check it out. + +Caffe speaks [Google Protocol Buffer](https://code.google.com/p/protobuf/) for the following strengths: minimal-size binary strings when serialized, efficient serialization, a human-readable text format compatible with the binary version, and efficient interface implementations in multiple languages, most notably C++ and Python. This all contributes to the flexibility and extensibility of modeling in Caffe. 
diff --git a/caffe-crfrnn/docs/tutorial/solver.md b/caffe-crfrnn/docs/tutorial/solver.md new file mode 100644 index 00000000..8884ea0e --- /dev/null +++ b/caffe-crfrnn/docs/tutorial/solver.md @@ -0,0 +1,271 @@ +--- +title: Solver / Model Optimization +--- +# Solver + +The solver orchestrates model optimization by coordinating the network's forward inference and backward gradients to form parameter updates that attempt to improve the loss. +The responsibilities of learning are divided between the Solver for overseeing the optimization and generating parameter updates and the Net for yielding loss and gradients. + +The Caffe solvers are Stochastic Gradient Descent (SGD), Adaptive Gradient (ADAGRAD), and Nesterov's Accelerated Gradient (NAG). + +The solver + +1. scaffolds the optimization bookkeeping and creates the training network for learning and test network(s) for evaluation. +2. iteratively optimizes by calling forward / backward and updating parameters +3. (periodically) evaluates the test networks +4. snapshots the model and solver state throughout the optimization + +where each iteration + +1. calls network forward to compute the output and loss +2. calls network backward to compute the gradients +3. incorporates the gradients into parameter updates according to the solver method +4. updates the solver state according to learning rate, history, and method + +to take the weights all the way from initialization to learned model. + +Like Caffe models, Caffe solvers run in CPU / GPU modes. + +## Methods + +The solver methods address the general optimization problem of loss minimization. +For dataset $$D$$, the optimization objective is the average loss over all $$|D|$$ data instances throughout the dataset + +$$L(W) = \frac{1}{|D|} \sum_i^{|D|} f_W\left(X^{(i)}\right) + \lambda r(W)$$ + +where $$f_W\left(X^{(i)}\right)$$ is the loss on data instance $$X^{(i)}$$ and $$r(W)$$ is a regularization term with weight $$\lambda$$. 
+$$|D|$$ can be very large, so in practice, in each solver iteration we use a stochastic approximation of this objective, drawing a mini-batch of $$N \ll |D|$$ instances: + +$$L(W) \approx \frac{1}{N} \sum_i^N f_W\left(X^{(i)}\right) + \lambda r(W)$$ + +The model computes $$f_W$$ in the forward pass and the gradient $$\nabla f_W$$ in the backward pass. + +The parameter update $$\Delta W$$ is formed by the solver from the error gradient $$\nabla f_W$$, the regularization gradient $$\nabla r(W)$$, and other particulars specific to each method. + +### SGD + +**Stochastic gradient descent** (`solver_type: SGD`) updates the weights $$ W $$ by a linear combination of the negative gradient $$ \nabla L(W) $$ and the previous weight update $$ V_t $$. +The **learning rate** $$ \alpha $$ is the weight of the negative gradient. +The **momentum** $$ \mu $$ is the weight of the previous update. + +Formally, we have the following formulas to compute the update value $$ V_{t+1} $$ and the updated weights $$ W_{t+1} $$ at iteration $$ t+1 $$, given the previous weight update $$ V_t $$ and current weights $$ W_t $$: + +$$ +V_{t+1} = \mu V_t - \alpha \nabla L(W_t) +$$ + +$$ +W_{t+1} = W_t + V_{t+1} +$$ + +The learning "hyperparameters" ($$\alpha$$ and $$\mu$$) might require a bit of tuning for best results. +If you're not sure where to start, take a look at the "Rules of thumb" below, and for further information you might refer to Leon Bottou's [Stochastic Gradient Descent Tricks](http://research.microsoft.com/pubs/192769/tricks-2012.pdf) [1]. + +[1] L. Bottou. + [Stochastic Gradient Descent Tricks](http://research.microsoft.com/pubs/192769/tricks-2012.pdf). + *Neural Networks: Tricks of the Trade*: Springer, 2012. 
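To make the two SGD update equations concrete, here is a minimal pure-Python sketch that applies them to a toy quadratic loss. The helper names (`sgd_momentum_step`, `quadratic_grad`) and the toy loss are assumptions made for illustration; they are not part of Caffe, whose solver works on parameter blobs instead of Python lists.

```python
def sgd_momentum_step(w, v, grad, lr=0.01, momentum=0.9):
    """One SGD step: V_{t+1} = mu * V_t - alpha * grad, W_{t+1} = W_t + V_{t+1}."""
    v_next = [momentum * vi - lr * gi for vi, gi in zip(v, grad)]
    w_next = [wi + vi for wi, vi in zip(w, v_next)]
    return w_next, v_next


def quadratic_grad(w):
    # Gradient of the toy loss L(w) = 0.5 * sum(w_i^2), which is simply w.
    return list(w)


# Start away from the minimum at the origin with a zero update history.
w, v = [1.0, -2.0], [0.0, 0.0]
for _ in range(200):
    w, v = sgd_momentum_step(w, v, quadratic_grad(w))
```

With `lr=0.01` and `momentum=0.9` the weights spiral in towards the minimum at the origin; a much larger `lr` on the same loss would be expected to oscillate or diverge.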
+ +#### Rules of thumb for setting the learning rate $$ \alpha $$ and momentum $$ \mu $$ + +A good strategy for deep learning with SGD is to initialize the learning rate $$ \alpha $$ to a value around $$ \alpha \approx 0.01 = 10^{-2} $$, and to drop it by a constant factor (e.g., 10) throughout training whenever the loss begins to reach an apparent "plateau", repeating this several times. +Generally, you probably want to use a momentum $$ \mu = 0.9 $$ or similar value. +By smoothing the weight updates across iterations, momentum tends to make deep learning with SGD both more stable and faster. + +This was the strategy used by Krizhevsky et al. [1] in their famously winning CNN entry to the ILSVRC-2012 competition, and Caffe makes this strategy easy to implement in a `SolverParameter`, as in our reproduction of [1] at `./examples/imagenet/alexnet_solver.prototxt`. + +To use a learning rate policy like this, you can put the following lines somewhere in your solver prototxt file: + + base_lr: 0.01 # begin training at a learning rate of 0.01 = 1e-2 + + lr_policy: "step" # learning rate policy: drop the learning rate in "steps" + # by a factor of gamma every stepsize iterations + + gamma: 0.1 # drop the learning rate by a factor of 10 + # (i.e., multiply it by a factor of gamma = 0.1) + + stepsize: 100000 # drop the learning rate every 100K iterations + + max_iter: 350000 # train for 350K iterations total + + momentum: 0.9 + +Under the above settings, we'll always use `momentum` $$ \mu = 0.9 $$. +We'll begin training at a `base_lr` of $$ \alpha = 0.01 = 10^{-2} $$ for the first 100,000 iterations, then multiply the learning rate by `gamma` ($$ \gamma $$) and train at $$ \alpha' = \alpha \gamma = (0.01) (0.1) = 0.001 = 10^{-3} $$ for iterations 100K-200K, then at $$ \alpha'' = 10^{-4} $$ for iterations 200K-300K, and finally train until iteration 350K (since we have `max_iter: 350000`) at $$ \alpha''' = 10^{-5} $$. 
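The schedule described above can be checked with a few lines of Python. This is a minimal sketch, not Caffe's implementation: `step_learning_rate` is a hypothetical helper, and it assumes the `step` policy computes `base_lr * gamma ^ floor(iter / stepsize)`.

```python
def step_learning_rate(iteration, base_lr=0.01, gamma=0.1, stepsize=100000):
    """Learning rate under a "step" policy: base_lr * gamma ** floor(iter / stepsize)."""
    return base_lr * gamma ** (iteration // stepsize)


# Print the rate at the start of training and around each drop.
for iteration in (0, 99999, 100000, 200000, 300000, 349999):
    print(iteration, step_learning_rate(iteration))
```

The printed rates drop by a factor of 10 at iterations 100K, 200K, and 300K, matching the four phases listed above (up to floating-point rounding).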
+ +Note that the momentum setting $$ \mu $$ effectively multiplies the size of your updates by a factor of $$ \frac{1}{1 - \mu} $$ after many iterations of training, so if you increase $$ \mu $$, it may be a good idea to **decrease** $$ \alpha $$ accordingly (and vice versa). + +For example, with $$ \mu = 0.9 $$, we have an effective update size multiplier of $$ \frac{1}{1 - 0.9} = 10 $$. +If we increased the momentum to $$ \mu = 0.99 $$, we've increased our update size multiplier to 100, so we should drop $$ \alpha $$ (`base_lr`) by a factor of 10. + +Note also that the above settings are merely guidelines, and they're definitely not guaranteed to be optimal (or even work at all!) in every situation. +If learning diverges (e.g., you start to see very large or `NaN` or `inf` loss values or outputs), try dropping the `base_lr` (e.g., `base_lr: 0.001`) and re-training, repeating this until you find a `base_lr` value that works. + +[1] A. Krizhevsky, I. Sutskever, and G. Hinton. + [ImageNet Classification with Deep Convolutional Neural Networks](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf). + *Advances in Neural Information Processing Systems*, 2012. + +### AdaGrad + +The **adaptive gradient** (`solver_type: ADAGRAD`) method (Duchi et al. [1]) is a gradient-based optimization method (like SGD) that attempts to "find needles in haystacks in the form of very predictive but rarely seen features," in Duchi et al.'s words. 
+Given the update information from all previous iterations $$ \left( \nabla L(W) \right)_{t'} $$ for $$ t' \in \{1, 2, ..., t\} $$, +the update formulas proposed by [1] are as follows, specified for each component $$i$$ of the weights $$W$$: + +$$ +(W_{t+1})_i = +(W_t)_i - \alpha +\frac{\left( \nabla L(W_t) \right)_{i}}{ + \sqrt{\sum_{t'=1}^{t} \left( \nabla L(W_{t'}) \right)_i^2} +} +$$ + +Note that in practice, for weights $$ W \in \mathbb{R}^d $$, AdaGrad implementations (including the one in Caffe) use only $$ \mathcal{O}(d) $$ extra storage for the historical gradient information (rather than the $$ \mathcal{O}(dt) $$ storage that would be necessary to store each historical gradient individually). + +[1] J. Duchi, E. Hazan, and Y. Singer. + [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](http://www.magicbroom.info/Papers/DuchiHaSi10.pdf). + *The Journal of Machine Learning Research*, 2011. + +### NAG + +**Nesterov's accelerated gradient** (`solver_type: NAG`) was proposed by Nesterov [1] as an "optimal" method of convex optimization, achieving a convergence rate of $$ \mathcal{O}(1/t^2) $$ rather than the $$ \mathcal{O}(1/t) $$ rate of standard gradient descent. +Though the required assumptions to achieve the $$ \mathcal{O}(1/t^2) $$ convergence typically will not hold for deep networks trained with Caffe (e.g., due to non-smoothness and non-convexity), in practice NAG can be a very effective method for optimizing certain types of deep learning architectures, as demonstrated for deep MNIST autoencoders by Sutskever et al. [2]. 
+ +The weight update formulas look very similar to the SGD updates given above: + +$$ +V_{t+1} = \mu V_t - \alpha \nabla L(W_t + \mu V_t) +$$ + +$$ +W_{t+1} = W_t + V_{t+1} +$$ + +What distinguishes the method from SGD is the weight setting $$ W $$ on which we compute the error gradient $$ \nabla L(W) $$ -- in NAG we take the gradient on weights with added momentum $$ \nabla L(W_t + \mu V_t) $$; in SGD we simply take the gradient $$ \nabla L(W_t) $$ on the current weights themselves. + +[1] Y. Nesterov. + A Method of Solving a Convex Programming Problem with Convergence Rate $$\mathcal{O}(1/k^2)$$. + *Soviet Mathematics Doklady*, 1983. + +[2] I. Sutskever, J. Martens, G. Dahl, and G. Hinton. + [On the Importance of Initialization and Momentum in Deep Learning](http://www.cs.toronto.edu/~fritz/absps/momentum.pdf). + *Proceedings of the 30th International Conference on Machine Learning*, 2013. + +## Scaffolding + +The solver scaffolding prepares the optimization method and initializes the model to be learned in `Solver::Presolve()`. + + > caffe train -solver examples/mnist/lenet_solver.prototxt + I0902 13:35:56.474978 16020 caffe.cpp:90] Starting Optimization + I0902 13:35:56.475190 16020 solver.cpp:32] Initializing solver from parameters: + test_iter: 100 + test_interval: 500 + base_lr: 0.01 + display: 100 + max_iter: 10000 + lr_policy: "inv" + gamma: 0.0001 + power: 0.75 + momentum: 0.9 + weight_decay: 0.0005 + snapshot: 5000 + snapshot_prefix: "examples/mnist/lenet" + solver_mode: GPU + net: "examples/mnist/lenet_train_test.prototxt" + +Net initialization + + I0902 13:35:56.655681 16020 solver.cpp:72] Creating training net from net file: examples/mnist/lenet_train_test.prototxt + [...] 
+ I0902 13:35:56.656740 16020 net.cpp:56] Memory required for data: 0 + I0902 13:35:56.656791 16020 net.cpp:67] Creating Layer mnist + I0902 13:35:56.656811 16020 net.cpp:356] mnist -> data + I0902 13:35:56.656846 16020 net.cpp:356] mnist -> label + I0902 13:35:56.656874 16020 net.cpp:96] Setting up mnist + I0902 13:35:56.694052 16020 data_layer.cpp:135] Opening lmdb examples/mnist/mnist_train_lmdb + I0902 13:35:56.701062 16020 data_layer.cpp:195] output data size: 64,1,28,28 + I0902 13:35:56.701146 16020 data_layer.cpp:236] Initializing prefetch + I0902 13:35:56.701196 16020 data_layer.cpp:238] Prefetch initialized. + I0902 13:35:56.701212 16020 net.cpp:103] Top shape: 64 1 28 28 (50176) + I0902 13:35:56.701230 16020 net.cpp:103] Top shape: 64 1 1 1 (64) + [...] + I0902 13:35:56.703737 16020 net.cpp:67] Creating Layer ip1 + I0902 13:35:56.703753 16020 net.cpp:394] ip1 <- pool2 + I0902 13:35:56.703778 16020 net.cpp:356] ip1 -> ip1 + I0902 13:35:56.703797 16020 net.cpp:96] Setting up ip1 + I0902 13:35:56.728127 16020 net.cpp:103] Top shape: 64 500 1 1 (32000) + I0902 13:35:56.728142 16020 net.cpp:113] Memory required for data: 5039360 + I0902 13:35:56.728175 16020 net.cpp:67] Creating Layer relu1 + I0902 13:35:56.728194 16020 net.cpp:394] relu1 <- ip1 + I0902 13:35:56.728219 16020 net.cpp:345] relu1 -> ip1 (in-place) + I0902 13:35:56.728240 16020 net.cpp:96] Setting up relu1 + I0902 13:35:56.728256 16020 net.cpp:103] Top shape: 64 500 1 1 (32000) + I0902 13:35:56.728270 16020 net.cpp:113] Memory required for data: 5167360 + I0902 13:35:56.728287 16020 net.cpp:67] Creating Layer ip2 + I0902 13:35:56.728304 16020 net.cpp:394] ip2 <- ip1 + I0902 13:35:56.728333 16020 net.cpp:356] ip2 -> ip2 + I0902 13:35:56.728356 16020 net.cpp:96] Setting up ip2 + I0902 13:35:56.728690 16020 net.cpp:103] Top shape: 64 10 1 1 (640) + I0902 13:35:56.728705 16020 net.cpp:113] Memory required for data: 5169920 + I0902 13:35:56.728734 16020 net.cpp:67] Creating Layer loss + I0902 
13:35:56.728747 16020 net.cpp:394] loss <- ip2 + I0902 13:35:56.728767 16020 net.cpp:394] loss <- label + I0902 13:35:56.728786 16020 net.cpp:356] loss -> loss + I0902 13:35:56.728811 16020 net.cpp:96] Setting up loss + I0902 13:35:56.728837 16020 net.cpp:103] Top shape: 1 1 1 1 (1) + I0902 13:35:56.728849 16020 net.cpp:109] with loss weight 1 + I0902 13:35:56.728878 16020 net.cpp:113] Memory required for data: 5169924 + +Loss + + I0902 13:35:56.728893 16020 net.cpp:170] loss needs backward computation. + I0902 13:35:56.728909 16020 net.cpp:170] ip2 needs backward computation. + I0902 13:35:56.728924 16020 net.cpp:170] relu1 needs backward computation. + I0902 13:35:56.728938 16020 net.cpp:170] ip1 needs backward computation. + I0902 13:35:56.728953 16020 net.cpp:170] pool2 needs backward computation. + I0902 13:35:56.728970 16020 net.cpp:170] conv2 needs backward computation. + I0902 13:35:56.728984 16020 net.cpp:170] pool1 needs backward computation. + I0902 13:35:56.728998 16020 net.cpp:170] conv1 needs backward computation. + I0902 13:35:56.729014 16020 net.cpp:172] mnist does not need backward computation. + I0902 13:35:56.729027 16020 net.cpp:208] This network produces output loss + I0902 13:35:56.729053 16020 net.cpp:467] Collecting Learning Rate and Weight Decay. + I0902 13:35:56.729071 16020 net.cpp:219] Network initialization done. + I0902 13:35:56.729085 16020 net.cpp:220] Memory required for data: 5169924 + I0902 13:35:56.729277 16020 solver.cpp:156] Creating test net (#0) specified by net file: examples/mnist/lenet_train_test.prototxt + +Completion + + I0902 13:35:56.806970 16020 solver.cpp:46] Solver scaffolding done. + I0902 13:35:56.806984 16020 solver.cpp:165] Solving LeNet + + +## Updating Parameters + +The actual weight update is made by the solver then applied to the net parameters in `Solver::ComputeUpdateValue()`. 
+The `ComputeUpdateValue` method incorporates any weight decay $$ r(W) $$ into the weight gradients (which currently just contain the error gradients) to get the final gradient with respect to each network weight. +Then these gradients are scaled by the learning rate $$ \alpha $$ and the update to subtract is stored in each parameter Blob's `diff` field. +Finally, the `Blob::Update` method is called on each parameter blob, which performs the final update (subtracting the Blob's `diff` from its `data`). + +## Snapshotting and Resuming + +The solver snapshots the weights and its own state during training in `Solver::Snapshot()` and `Solver::SnapshotSolverState()`. +The weight snapshots export the learned model while the solver snapshots allow training to be resumed from a given point. +Training is resumed by `Solver::Restore()` and `Solver::RestoreSolverState()`. + +Weights are saved without extension while solver states are saved with `.solverstate` extension. +Both files will have an `_iter_N` suffix for the snapshot iteration number. + +Snapshotting is configured by: + + # The snapshot interval in iterations. + snapshot: 5000 + # File path prefix for snapshotting model weights and solver state. + # Note: this is relative to the invocation of the `caffe` utility, not the + # solver definition file. + snapshot_prefix: "/path/to/model" + # Snapshot the diff along with the weights. This can help debugging training + # but takes more storage. + snapshot_diff: false + # A final snapshot is saved at the end of training unless + # this flag is set to false. The default is true. + snapshot_after_train: true + +in the solver definition prototxt. 
diff --git a/caffe-crfrnn/examples/.gitignore b/caffe-crfrnn/examples/.gitignore new file mode 100644 index 00000000..29aa4e63 --- /dev/null +++ b/caffe-crfrnn/examples/.gitignore @@ -0,0 +1,2 @@ +*/*.caffemodel +*/*.solverstate diff --git a/caffe-crfrnn/examples/CMakeLists.txt b/caffe-crfrnn/examples/CMakeLists.txt new file mode 100644 index 00000000..663d7360 --- /dev/null +++ b/caffe-crfrnn/examples/CMakeLists.txt @@ -0,0 +1,31 @@ +file(GLOB_RECURSE examples_srcs "${PROJECT_SOURCE_DIR}/examples/*.cpp") + +foreach(source_file ${examples_srcs}) + # get file name + get_filename_component(name ${source_file} NAME_WE) + + # get folder name + get_filename_component(path ${source_file} PATH) + get_filename_component(folder ${path} NAME_WE) + + add_executable(${name} ${source_file}) + target_link_libraries(${name} ${Caffe_LINK}) + caffe_default_properties(${name}) + + # set back RUNTIME_OUTPUT_DIRECTORY + set_target_properties(${name} PROPERTIES + RUNTIME_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}/examples/${folder}") + + caffe_set_solution_folder(${name} examples) + + # install + install(TARGETS ${name} DESTINATION bin) + + if(UNIX OR APPLE) + # Funny command to make tutorials work + # TODO: remove in future as soon as naming is standardized everywhere + set(__outname ${PROJECT_BINARY_DIR}/examples/${folder}/${name}${Caffe_POSTFIX}) + add_custom_command(TARGET ${name} POST_BUILD + COMMAND ln -sf "${__outname}" "${__outname}.bin") + endif() +endforeach() diff --git a/caffe-crfrnn/include/caffe/blob.hpp b/caffe-crfrnn/include/caffe/blob.hpp new file mode 100644 index 00000000..ef10aea5 --- /dev/null +++ b/caffe-crfrnn/include/caffe/blob.hpp @@ -0,0 +1,144 @@ +#ifndef CAFFE_BLOB_HPP_ +#define CAFFE_BLOB_HPP_ + +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/syncedmem.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +/** + * @brief A wrapper around SyncedMemory holders serving as the basic + * computational unit through 
which Layer%s, Net%s, and Solver%s + * interact. + * + * TODO(dox): more thorough description. + */ +template <typename Dtype> +class Blob { + public: + Blob() + : data_(), diff_(), num_(0), channels_(0), height_(0), width_(0), + count_(0), capacity_(0) {} + explicit Blob(const int num, const int channels, const int height, + const int width); + /** + * @brief Change the dimensions of the blob, allocating new memory if + * necessary. + * + * This function can be called both to create an initial allocation + * of memory, and to adjust the dimensions of a top blob during Layer::Reshape + * or Layer::Forward. When changing the size of blob, memory will only be + * reallocated if sufficient memory does not already exist, and excess memory + * will never be freed. + * + * Note that reshaping an input blob and immediately calling Net::Backward is + * an error; either Net::Forward or Net::Reshape need to be called to + * propagate the new input shape to higher layers. + */ + void Reshape(const int num, const int channels, const int height, + const int width); + void ReshapeLike(const Blob& other); + inline int num() const { return num_; } + inline int channels() const { return channels_; } + inline int height() const { return height_; } + inline int width() const { return width_; } + inline int count() const { return count_; } + inline int offset(const int n, const int c = 0, const int h = 0, + const int w = 0) const { + CHECK_GE(n, 0); + CHECK_LE(n, num_); + CHECK_GE(channels_, 0); + CHECK_LE(c, channels_); + CHECK_GE(height_, 0); + CHECK_LE(h, height_); + CHECK_GE(width_, 0); + CHECK_LE(w, width_); + return ((n * channels_ + c) * height_ + h) * width_ + w; + } + /** + * @brief Copy from a source Blob. 
+ * + * @param source the Blob to copy from + * @param copy_diff if false, copy the data; if true, copy the diff + * @param reshape if false, require this Blob to be pre-shaped to the shape + * of other (and die otherwise); if true, Reshape this Blob to other's + * shape if necessary + */ + void CopyFrom(const Blob<Dtype>& source, bool copy_diff = false, + bool reshape = false); + + inline Dtype data_at(const int n, const int c, const int h, + const int w) const { + return *(cpu_data() + offset(n, c, h, w)); + } + + inline Dtype diff_at(const int n, const int c, const int h, + const int w) const { + return *(cpu_diff() + offset(n, c, h, w)); + } + + inline const shared_ptr<SyncedMemory>& data() const { + CHECK(data_); + return data_; + } + + inline const shared_ptr<SyncedMemory>& diff() const { + CHECK(diff_); + return diff_; + } + + const Dtype* cpu_data() const; + void set_cpu_data(Dtype* data); + const Dtype* gpu_data() const; + const Dtype* cpu_diff() const; + const Dtype* gpu_diff() const; + Dtype* mutable_cpu_data(); + Dtype* mutable_gpu_data(); + Dtype* mutable_cpu_diff(); + Dtype* mutable_gpu_diff(); + void Update(); + void FromProto(const BlobProto& proto); + void ToProto(BlobProto* proto, bool write_diff = false) const; + + /// @brief Compute the sum of absolute values (L1 norm) of the data. + Dtype asum_data() const; + /// @brief Compute the sum of absolute values (L1 norm) of the diff. + Dtype asum_diff() const; + + /** + * @brief Set the data_ shared_ptr to point to the SyncedMemory holding the + * data_ of Blob other -- useful in Layer%s which simply perform a copy + * in their Forward pass. + * + * This deallocates the SyncedMemory holding this Blob's data_, as + * shared_ptr calls its destructor when reset with the "=" operator. 
+ */ + void ShareData(const Blob& other); + /** + * @brief Set the diff_ shared_ptr to point to the SyncedMemory holding the + * diff_ of Blob other -- useful in Layer%s which simply perform a copy + * in their Forward pass. 
+ * + * This deallocates the SyncedMemory holding this Blob's diff_, as + * shared_ptr calls its destructor when reset with the "=" operator. + */ + void ShareDiff(const Blob& other); + + protected: + shared_ptr<SyncedMemory> data_; + shared_ptr<SyncedMemory> diff_; + int num_; + int channels_; + int height_; + int width_; + int count_; + int capacity_; + + DISABLE_COPY_AND_ASSIGN(Blob); +}; // class Blob + +} // namespace caffe + +#endif // CAFFE_BLOB_HPP_ diff --git a/caffe-crfrnn/include/caffe/caffe.hpp b/caffe-crfrnn/include/caffe/caffe.hpp new file mode 100644 index 00000000..3c829f2f --- /dev/null +++ b/caffe-crfrnn/include/caffe/caffe.hpp @@ -0,0 +1,19 @@ +// caffe.hpp is the header file that you need to include in your code. It wraps +// all the internal caffe header files into one for simpler inclusion. + +#ifndef CAFFE_CAFFE_HPP_ +#define CAFFE_CAFFE_HPP_ + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/layer_factory.hpp" +#include "caffe/net.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" +#include "caffe/util/benchmark.hpp" +#include "caffe/util/io.hpp" +#include "caffe/vision_layers.hpp" + +#endif // CAFFE_CAFFE_HPP_ diff --git a/caffe-crfrnn/include/caffe/common.hpp b/caffe-crfrnn/include/caffe/common.hpp new file mode 100644 index 00000000..81b2e9ae --- /dev/null +++ b/caffe-crfrnn/include/caffe/common.hpp @@ -0,0 +1,175 @@ +#ifndef CAFFE_COMMON_HPP_ +#define CAFFE_COMMON_HPP_ + +#include <boost/shared_ptr.hpp> +#include <gflags/gflags.h> +#include <glog/logging.h> + +#include <cmath> +#include <fstream> // NOLINT(readability/streams) +#include <iostream> // NOLINT(readability/streams) +#include <map> +#include <set> +#include <sstream> +#include <string> +#include <utility> // pair +#include <vector> + +#include "caffe/util/device_alternate.hpp" + +// gflags 2.1 issue: namespace google was changed to gflags without warning. 
+// Luckily we will be able to use GFLAGS_GFLAGS_H_ to detect if it is version +// 2.1. If yes, we will add a temporary solution to redirect the namespace. 
+// TODO(Yangqing): Once gflags solves the problem in a more elegant way, let's +// remove the following hack. +#ifndef GFLAGS_GFLAGS_H_ +namespace gflags = google; +#endif // GFLAGS_GFLAGS_H_ + +// Disable the copy and assignment operator for a class. +#define DISABLE_COPY_AND_ASSIGN(classname) \ +private:\ + classname(const classname&);\ + classname& operator=(const classname&) + +// Instantiate a class with float and double specifications. +#define INSTANTIATE_CLASS(classname) \ + char gInstantiationGuard##classname; \ + template class classname<float>; \ + template class classname<double> + +#define INSTANTIATE_LAYER_GPU_FORWARD(classname) \ + template void classname<float>::Forward_gpu( \ + const std::vector<Blob<float>*>& bottom, \ + const std::vector<Blob<float>*>& top); \ + template void classname<double>::Forward_gpu( \ + const std::vector<Blob<double>*>& bottom, \ + const std::vector<Blob<double>*>& top); + +#define INSTANTIATE_LAYER_GPU_BACKWARD(classname) \ + template void classname<float>::Backward_gpu( \ + const std::vector<Blob<float>*>& top, \ + const std::vector<bool>& propagate_down, \ + const std::vector<Blob<float>*>& bottom); \ + template void classname<double>::Backward_gpu( \ + const std::vector<Blob<double>*>& top, \ + const std::vector<bool>& propagate_down, \ + const std::vector<Blob<double>*>& bottom) + +#define INSTANTIATE_LAYER_GPU_FUNCS(classname) \ + INSTANTIATE_LAYER_GPU_FORWARD(classname); \ + INSTANTIATE_LAYER_GPU_BACKWARD(classname) + +// A simple macro to mark codes that are not implemented, so that when the code +// is executed we will see a fatal log. +#define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet" + +namespace caffe { + +// We will use the boost shared_ptr instead of the new C++11 one mainly +// because cuda does not work (at least now) well with C++11 features. +using boost::shared_ptr; + +// Common functions and classes from std that caffe often uses. 
+using std::fstream; +using std::ios; +using std::isnan; +using std::isinf; +using std::iterator; +using std::make_pair; +using std::map; +using std::ostringstream; +using std::pair; +using std::set; +using std::string; +using std::stringstream; +using std::vector; + +// A global initialization function that you should call in your main function. +// Currently it initializes google flags and google logging. +void GlobalInit(int* pargc, char*** pargv); + +// A singleton class to hold common caffe stuff, such as the handler that +// caffe is going to use for cublas, curand, etc. +class Caffe { + public: + ~Caffe(); + inline static Caffe& Get() { + if (!singleton_.get()) { + singleton_.reset(new Caffe()); + } + return *singleton_; + } + enum Brew { CPU, GPU }; + enum Phase { TRAIN, TEST }; + + + // This random number generator facade hides boost and CUDA rng + // implementation from one another (for cross-platform compatibility). + class RNG { + public: + RNG(); + explicit RNG(unsigned int seed); + explicit RNG(const RNG&); + RNG& operator=(const RNG&); + void* generator(); + private: + class Generator; + shared_ptr<Generator> generator_; + }; + + // Getters for boost rng, curand, and cublas handles + inline static RNG& rng_stream() { + if (!Get().random_generator_) { + Get().random_generator_.reset(new RNG()); + } + return *(Get().random_generator_); + } +#ifndef CPU_ONLY + inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; } + inline static curandGenerator_t curand_generator() { + return Get().curand_generator_; + } +#endif + + // Returns the mode: running on CPU or GPU. + inline static Brew mode() { return Get().mode_; } + // Returns the phase: TRAIN or TEST. + inline static Phase phase() { return Get().phase_; } + // The setters for the variables + // Sets the mode. 
It is recommended that you don't change the mode halfway + // into the program since that may cause allocation of pinned memory being + // freed in a non-pinned way, which may cause problems - I haven't verified + // it personally but better to note it here in the header file. + inline static void set_mode(Brew mode) { Get().mode_ = mode; } + // Sets the phase. + inline static void set_phase(Phase phase) { Get().phase_ = phase; } + // Sets the random seed of both boost and curand + static void set_random_seed(const unsigned int seed); + // Sets the device. Since we have cublas and curand stuff, set device also + // requires us to reset those values. + static void SetDevice(const int device_id); + // Prints the current GPU status. + static void DeviceQuery(); + + protected: +#ifndef CPU_ONLY + cublasHandle_t cublas_handle_; + curandGenerator_t curand_generator_; +#endif + shared_ptr<RNG> random_generator_; + + Brew mode_; + Phase phase_; + static shared_ptr<Caffe> singleton_; + + private: + // The private constructor to avoid duplicate instantiation. + Caffe(); + + DISABLE_COPY_AND_ASSIGN(Caffe); +}; + +} // namespace caffe + +#endif // CAFFE_COMMON_HPP_ diff --git a/caffe-crfrnn/include/caffe/common_layers.hpp b/caffe-crfrnn/include/caffe/common_layers.hpp new file mode 100644 index 00000000..a9ce6482 --- /dev/null +++ b/caffe-crfrnn/include/caffe/common_layers.hpp @@ -0,0 +1,510 @@ +#ifndef CAFFE_COMMON_LAYERS_HPP_ +#define CAFFE_COMMON_LAYERS_HPP_ + +#include <string> +#include <utility> +#include <vector> + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/data_layers.hpp" +#include "caffe/layer.hpp" +#include "caffe/loss_layers.hpp" +#include "caffe/neuron_layers.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +/** + * @brief Compute the index of the @f$ K @f$ max values for each datum across + * all dimensions @f$ (C \times H \times W) @f$. 
+ * + * Intended for use after a classification layer to produce a prediction. 
+ * If parameter out_max_val is set to true, output is a vector of pairs + * (max_ind, max_val) for each image. + * + * NOTE: does not implement Backwards operation. + */ +template <typename Dtype> +class ArgMaxLayer : public Layer<Dtype> { + public: + /** + * @param param provides ArgMaxParameter argmax_param, + * with ArgMaxLayer options: + * - top_k (\b optional uint, default 1). + * the number @f$ K @f$ of maximal items to output. + * - out_max_val (\b optional bool, default false). + * if set, output a vector of pairs (max_ind, max_val) for each image. + */ + explicit ArgMaxLayer(const LayerParameter& param) + : Layer<Dtype>(param) {} + virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_ARGMAX; + } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + /** + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$ + * @param top output Blob vector (length 1) + * -# @f$ (N \times 1 \times K \times 1) @f$ or, if out_max_val + * @f$ (N \times 2 \times K \times 1) @f$ + * the computed outputs @f$ + * y_n = \arg\max\limits_i x_{ni} + * @f$ (for @f$ K = 1 @f$). 
+ */ + virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + /// @brief Not implemented (non-differentiable function) + virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { + NOT_IMPLEMENTED; + } + bool out_max_val_; + size_t top_k_; +}; + +/** + * @brief Takes at least two Blob%s and concatenates them along either the num + * or channel dimension, outputting the result. 
+ */ +template <typename Dtype> +class ConcatLayer : public Layer<Dtype> { + public: + explicit ConcatLayer(const LayerParameter& param) + : Layer<Dtype>(param) {} + virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_CONCAT; + } + virtual inline int MinBottomBlobs() const { return 2; } + virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap<Dtype> coord_map() { + return DiagonalAffineMap<Dtype>::identity(2); + } + + protected: + /** + * @param bottom input Blob vector (length 2+) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x_1 @f$ + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x_2 @f$ + * -# ... + * - K @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x_K @f$ + * @param top output Blob vector (length 1) + * -# @f$ (KN \times C \times H \times W) @f$ if concat_dim == 0, or + * @f$ (N \times KC \times H \times W) @f$ if concat_dim == 1: + * the concatenated output @f$ + * y = [\begin{array}{cccc} x_1 & x_2 & ... & x_K \end{array}] + * @f$ + */ + virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + + /** + * @brief Computes the error gradient w.r.t. the concatenate inputs. 
+ * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (KN \times C \times H \times W) @f$ if concat_dim == 0, or + * @f$ (N \times KC \times H \times W) @f$ if concat_dim == 1: + * containing error gradients @f$ \frac{\partial E}{\partial y} @f$ + * with respect to concatenated outputs @f$ y @f$ + * @param propagate_down see Layer::Backward. 
+ * @param bottom input Blob vector (length K), into which the top gradient + * @f$ \frac{\partial E}{\partial y} @f$ is deconcatenated back to the + * inputs @f$ + * \left[ \begin{array}{cccc} + * \frac{\partial E}{\partial x_1} & + * \frac{\partial E}{\partial x_2} & + * ... & + * \frac{\partial E}{\partial x_K} + * \end{array} \right] = + * \frac{\partial E}{\partial y} + * @f$ + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + Blob col_bob_; + int count_; + int num_; + int channels_; + int height_; + int width_; + int concat_dim_; +}; + +/** + * @brief Compute elementwise operations, such as product and sum, + * along multiple input Blobs. + * + * TODO(dox): thorough documentation for Forward, Backward, and proto params. + */ +template +class EltwiseLayer : public Layer { + public: + explicit EltwiseLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_ELTWISE; + } + virtual inline int MinBottomBlobs() const { return 2; } + virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + EltwiseParameter_EltwiseOp op_; + vector coeffs_; + Blob max_idx_; + + bool stable_prod_grad_; +}; + +/** + * 
@brief Reshapes the input Blob into flat vectors. + * + * Note: because this layer does not change the input values -- merely the + * dimensions -- it can simply copy the input. The copy happens "virtually" + * (thus taking effectively 0 real time) by setting, in Forward, the data + * pointer of the top Blob to that of the bottom Blob (see Blob::ShareData), + * and in Backward, the diff pointer of the bottom Blob to that of the top Blob + * (see Blob::ShareDiff). + */ +template +class FlattenLayer : public Layer { + public: + explicit FlattenLayer(const LayerParameter& param) + : Layer(param) {} + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_FLATTEN; + } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + /** + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs + * @param top output Blob vector (length 1) + * -# @f$ (N \times CHW \times 1 \times 1) @f$ + * the outputs -- i.e., the (virtually) copied, flattened inputs + */ + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Computes the error gradient w.r.t. the flattened inputs. + * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * @param propagate_down see Layer::Backward. 
+ * @param bottom input Blob vector (length 1), into which the top error + * gradient is (virtually) copied + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + int count_; +}; + +/** + * @brief Also known as a "fully-connected" layer, computes an inner product + * with a set of learned weights, and (optionally) adds biases. + * + * TODO(dox): thorough documentation for Forward, Backward, and proto params. + */ +template +class InnerProductLayer : public Layer { + public: + explicit InnerProductLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_INNER_PRODUCT; + } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + int M_; + int K_; + int N_; + bool bias_term_; + Blob bias_multiplier_; +}; + +/** + * @brief Normalizes the input to have 0-mean and/or unit (1) variance. + * + * TODO(dox): thorough documentation for Forward, Backward, and proto params. 
+ */ +template +class MVNLayer : public Layer { + public: + explicit MVNLayer(const LayerParameter& param) + : Layer(param) {} + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_MVN; + } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + Blob mean_, variance_, temp_; + + /// sum_multiplier is used to carry out sum using BLAS + Blob sum_multiplier_; +}; + +/** + * @brief Ignores bottom blobs while producing no top blobs. (This is useful + * to suppress outputs during testing.) + */ +template +class SilenceLayer : public Layer { + public: + explicit SilenceLayer(const LayerParameter& param) + : Layer(param) {} + virtual void Reshape(const vector*>& bottom, + const vector*>& top) {} + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_SILENCE; + } + virtual inline int MinBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 0; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top) {} + // We can't define Forward_gpu here, since STUB_GPU will provide + // its own definition for CPU_ONLY mode. 
+ virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); +}; + +/** + * @brief Computes the softmax function. + * + * TODO(dox): thorough documentation for Forward, Backward, and proto params. + */ +template +class SoftmaxLayer : public Layer { + public: + explicit SoftmaxLayer(const LayerParameter& param) + : Layer(param) {} + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_SOFTMAX; + } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + /// sum_multiplier is used to carry out sum using BLAS + Blob sum_multiplier_; + /// scale is an intermediate Blob to hold temporary results. + Blob scale_; +}; + +#ifdef USE_CUDNN +/** + * @brief cuDNN implementation of SoftmaxLayer. + * Fallback to SoftmaxLayer for CPU mode. 
+ */ +template +class CuDNNSoftmaxLayer : public SoftmaxLayer { + public: + explicit CuDNNSoftmaxLayer(const LayerParameter& param) + : SoftmaxLayer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + virtual ~CuDNNSoftmaxLayer(); + + protected: + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + cudnnHandle_t handle_; + cudnnTensorDescriptor_t bottom_desc_; + cudnnTensorDescriptor_t top_desc_; +}; +#endif + +/** + * @brief Creates a "split" path in the network by copying the bottom Blob + * into multiple top Blob%s to be used by multiple consuming layers. + * + * TODO(dox): thorough documentation for Forward, Backward, and proto params. + */ +template +class SplitLayer : public Layer { + public: + explicit SplitLayer(const LayerParameter& param) + : Layer(param) {} + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_SPLIT; + } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int MinTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + int count_; +}; + +/** + * @brief Takes a Blob and slices it along either the num or channel dimension, + * outputting multiple sliced Blob results. 
+ * + * TODO(dox): thorough documentation for Forward, Backward, and proto params. + */ +template +class SliceLayer : public Layer { + public: + explicit SliceLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_SLICE; + } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int MinTopBlobs() const { return 2; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + Blob col_bob_; + int count_; + int num_; + int channels_; + int height_; + int width_; + int slice_dim_; + vector slice_point_; +}; + +} // namespace caffe + +#endif // CAFFE_COMMON_LAYERS_HPP_ diff --git a/caffe-crfrnn/include/caffe/data_layers.hpp b/caffe-crfrnn/include/caffe/data_layers.hpp new file mode 100644 index 00000000..3bda671c --- /dev/null +++ b/caffe-crfrnn/include/caffe/data_layers.hpp @@ -0,0 +1,339 @@ +#ifndef CAFFE_DATA_LAYERS_HPP_ +#define CAFFE_DATA_LAYERS_HPP_ + +#include +#include +#include + +#include "boost/scoped_ptr.hpp" +#include "hdf5.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/data_transformer.hpp" +#include "caffe/dataset.hpp" +#include "caffe/filler.hpp" +#include "caffe/internal_thread.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +/** + * @brief Provides base for data layers that feed blobs to the Net. 
+ * + * TODO(dox): thorough documentation for Forward and proto params. + */ +template +class BaseDataLayer : public Layer { + public: + explicit BaseDataLayer(const LayerParameter& param); + virtual ~BaseDataLayer() {} + // LayerSetUp: implements common data layer setup functionality, and calls + // DataLayerSetUp to do special data layer setup for individual layer types. + // This method may not be overridden except by the BasePrefetchingDataLayer. + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void DataLayerSetUp(const vector*>& bottom, + const vector*>& top) {} + // Data layers have no bottoms, so reshaping is trivial. + virtual void Reshape(const vector*>& bottom, + const vector*>& top) {} + + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) {} + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) {} + + protected: + TransformationParameter transform_param_; + DataTransformer data_transformer_; + Caffe::Phase phase_; + bool output_labels_; +}; + +template +class BasePrefetchingDataLayer : + public BaseDataLayer, public InternalThread { + public: + explicit BasePrefetchingDataLayer(const LayerParameter& param) + : BaseDataLayer(param) {} + virtual ~BasePrefetchingDataLayer() {} + // LayerSetUp: implements common data layer setup functionality, and calls + // DataLayerSetUp to do special data layer setup for individual layer types. + // This method may not be overridden. 
+ void LayerSetUp(const vector*>& bottom, + const vector*>& top); + + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + virtual void CreatePrefetchThread(); + virtual void JoinPrefetchThread(); + // The thread's function + virtual void InternalThreadEntry() {} + + protected: + Blob prefetch_data_; + Blob prefetch_label_; + Blob transformed_data_; +}; + +template +class DataLayer : public BasePrefetchingDataLayer { + public: + explicit DataLayer(const LayerParameter& param) + : BasePrefetchingDataLayer(param) {} + virtual ~DataLayer(); + virtual void DataLayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_DATA; + } + virtual inline int ExactNumBottomBlobs() const { return 0; } + virtual inline int MinTopBlobs() const { return 1; } + virtual inline int MaxTopBlobs() const { return 2; } + + protected: + virtual void InternalThreadEntry(); + + shared_ptr > dataset_; + Dataset::const_iterator iter_; +}; + +/** + * @brief Provides data to the Net generated by a Filler. + * + * TODO(dox): thorough documentation for Forward and proto params. + */ +template +class DummyDataLayer : public Layer { + public: + explicit DummyDataLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + // Data layers have no bottoms, so reshaping is trivial. 
+ virtual void Reshape(const vector*>& bottom, + const vector*>& top) {} + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_DUMMY_DATA; + } + virtual inline int ExactNumBottomBlobs() const { return 0; } + virtual inline int MinTopBlobs() const { return 1; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) {} + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) {} + + vector > > fillers_; + vector refill_; +}; + +/** + * @brief Provides data to the Net from HDF5 files. + * + * TODO(dox): thorough documentation for Forward and proto params. + */ +template +class HDF5DataLayer : public Layer { + public: + explicit HDF5DataLayer(const LayerParameter& param) + : Layer(param) {} + virtual ~HDF5DataLayer(); + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + // Data layers have no bottoms, so reshaping is trivial. 
+ virtual void Reshape(const vector*>& bottom, + const vector*>& top) {} + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_HDF5_DATA; + } + virtual inline int ExactNumBottomBlobs() const { return 0; } + virtual inline int MinTopBlobs() const { return 1; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) {} + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) {} + virtual void LoadHDF5FileData(const char* filename); + + std::vector hdf_filenames_; + unsigned int num_files_; + unsigned int current_file_; + hsize_t current_row_; + std::vector > > hdf_blobs_; +}; + +/** + * @brief Write blobs to disk as HDF5 files. + * + * TODO(dox): thorough documentation for Forward and proto params. + */ +template +class HDF5OutputLayer : public Layer { + public: + explicit HDF5OutputLayer(const LayerParameter& param); + virtual ~HDF5OutputLayer(); + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top) {} + // Data layers have no bottoms, so reshaping is trivial. 
+ virtual void Reshape(const vector*>& bottom, + const vector*>& top) {} + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_HDF5_OUTPUT; + } + // TODO: no limit on the number of blobs + virtual inline int ExactNumBottomBlobs() const { return 2; } + virtual inline int ExactNumTopBlobs() const { return 0; } + + inline std::string file_name() const { return file_name_; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void SaveBlobs(); + + std::string file_name_; + hid_t file_id_; + Blob data_blob_; + Blob label_blob_; +}; + +/** + * @brief Provides data to the Net from image files. + * + * TODO(dox): thorough documentation for Forward and proto params. + */ +template +class ImageDataLayer : public BasePrefetchingDataLayer { + public: + explicit ImageDataLayer(const LayerParameter& param) + : BasePrefetchingDataLayer(param) {} + virtual ~ImageDataLayer(); + virtual void DataLayerSetUp(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_IMAGE_DATA; + } + virtual inline int ExactNumBottomBlobs() const { return 0; } + virtual inline int ExactNumTopBlobs() const { return 2; } + + protected: + shared_ptr prefetch_rng_; + virtual void ShuffleImages(); + virtual void InternalThreadEntry(); + + vector > lines_; + int lines_id_; +}; + +/** + * @brief Provides data to the Net from memory. + * + * TODO(dox): thorough documentation for Forward and proto params. 
+ */ +template +class MemoryDataLayer : public BaseDataLayer { + public: + explicit MemoryDataLayer(const LayerParameter& param) + : BaseDataLayer(param), has_new_data_(false) {} + virtual void DataLayerSetUp(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_MEMORY_DATA; + } + virtual inline int ExactNumBottomBlobs() const { return 0; } + virtual inline int ExactNumTopBlobs() const { return 2; } + + virtual void AddDatumVector(const vector& datum_vector); + + // Reset should accept const pointers, but can't, because the memory + // will be given to Blob, which is mutable + void Reset(Dtype* data, Dtype* label, int n); + + int batch_size() { return batch_size_; } + int channels() { return channels_; } + int height() { return height_; } + int width() { return width_; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + + int batch_size_, channels_, height_, width_, size_; + Dtype* data_; + Dtype* labels_; + int n_; + int pos_; + Blob added_data_; + Blob added_label_; + bool has_new_data_; +}; + +/** + * @brief Provides data to the Net from windows of image files, specified + * by a window data file. + * + * TODO(dox): thorough documentation for Forward and proto params. 
+ */ +template +class WindowDataLayer : public BasePrefetchingDataLayer { + public: + explicit WindowDataLayer(const LayerParameter& param) + : BasePrefetchingDataLayer(param) {} + virtual ~WindowDataLayer(); + virtual void DataLayerSetUp(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_WINDOW_DATA; + } + virtual inline int ExactNumBottomBlobs() const { return 0; } + virtual inline int ExactNumTopBlobs() const { return 2; } + + protected: + virtual unsigned int PrefetchRand(); + virtual void InternalThreadEntry(); + + shared_ptr prefetch_rng_; + vector > > image_database_; + enum WindowField { IMAGE_INDEX, LABEL, OVERLAP, X1, Y1, X2, Y2, NUM }; + vector > fg_windows_; + vector > bg_windows_; + Blob data_mean_; + vector mean_values_; + bool has_mean_file_; + bool has_mean_values_; + bool cache_images_; + vector > image_database_cache_; +}; + +} // namespace caffe + +#endif // CAFFE_DATA_LAYERS_HPP_ diff --git a/caffe-crfrnn/include/caffe/data_transformer.hpp b/caffe-crfrnn/include/caffe/data_transformer.hpp new file mode 100644 index 00000000..84ebba28 --- /dev/null +++ b/caffe-crfrnn/include/caffe/data_transformer.hpp @@ -0,0 +1,109 @@ +#ifndef CAFFE_DATA_TRANSFORMER_HPP +#define CAFFE_DATA_TRANSFORMER_HPP + +#ifndef OSX +#include +#endif + +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +/** + * @brief Applies common transformations to the input data, such as + * scaling, mirroring, subtracting the image mean... + */ +template +class DataTransformer { + public: + explicit DataTransformer(const TransformationParameter& param); + virtual ~DataTransformer() {} + + /** + * @brief Initializes the random number generator if needed by the + * transformation. + */ + void InitRand(); + + /** + * @brief Applies the transformation defined in the data layer's + * transform_param block to the data. 
+ * + * @param datum + * Datum containing the data to be transformed. + * @param transformed_blob + * This is the destination blob. It can be part of top blob's data if + * set_cpu_data() is used. See data_layer.cpp for an example. + */ + void Transform(const Datum& datum, Blob* transformed_blob); + + /** + * @brief Applies the transformation defined in the data layer's + * transform_param block to a vector of Datum. + * + * @param datum_vector + * A vector of Datum containing the data to be transformed. + * @param transformed_blob + * This is the destination blob. It can be part of top blob's data if + * set_cpu_data() is used. See memory_layer.cpp for an example. + */ + void Transform(const vector & datum_vector, + Blob* transformed_blob); + + /** + * @brief Applies the transformation defined in the data layer's + * transform_param block to a cv::Mat + * + * @param cv_img + * cv::Mat containing the data to be transformed. + * @param transformed_blob + * This is the destination blob. It can be part of top blob's data if + * set_cpu_data() is used. See image_data_layer.cpp for an example. + */ +#ifndef OSX + void Transform(const cv::Mat& cv_img, Blob* transformed_blob); +#endif + + /** + * @brief Applies the same transformation defined in the data layer's + * transform_param block to all the num images in an input_blob. + * + * @param input_blob + * A Blob containing the data to be transformed. It applies the same + * transformation to all the num images in the blob. + * @param transformed_blob + * This is the destination blob; it will contain as many images as the + * input blob. It can be part of top blob's data. + */ + void Transform(Blob* input_blob, Blob* transformed_blob); + + protected: + /** + * @brief Generates a random integer from Uniform({0, 1, ..., n-1}). + * + * @param n + * The upper bound (exclusive) of the random number. + * @return + * A uniformly random integer value from ({0, 1, ..., n-1}). 
+ */ + virtual int Rand(int n); + + void Transform(const Datum& datum, Dtype* transformed_data); + // Transformation parameters + TransformationParameter param_; + + + shared_ptr rng_; + Caffe::Phase phase_; + Blob data_mean_; + vector mean_values_; +}; + +} // namespace caffe + +#endif // CAFFE_DATA_TRANSFORMER_HPP + diff --git a/caffe-crfrnn/include/caffe/dataset.hpp b/caffe-crfrnn/include/caffe/dataset.hpp new file mode 100644 index 00000000..1dd8458c --- /dev/null +++ b/caffe-crfrnn/include/caffe/dataset.hpp @@ -0,0 +1,241 @@ +#ifndef CAFFE_DATASET_H_ +#define CAFFE_DATASET_H_ + +#include + +#include +#include +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +namespace dataset_internal { + +using google::protobuf::Message; + +template +struct static_assertion {}; +template<> +struct static_assertion { + enum { + DEFAULT_CODER_NOT_AVAILABLE + }; +}; + +template +struct DefaultCoder { + using static_assertion::DEFAULT_CODER_NOT_AVAILABLE; + static bool serialize(const T& obj, string* serialized); + static bool serialize(const T& obj, vector* serialized); + static bool deserialize(const string& serialized, T* obj); + static bool deserialize(const char* data, size_t size, T* obj); +}; + +template <> +struct DefaultCoder { + static bool serialize(const Message& obj, string* serialized) { + return obj.SerializeToString(serialized); + } + + static bool serialize(const Message& obj, vector* serialized) { + serialized->resize(obj.ByteSize()); + return obj.SerializeWithCachedSizesToArray( + reinterpret_cast(serialized->data())); + } + + static bool deserialize(const string& serialized, Message* obj) { + return obj->ParseFromString(serialized); + } + + static bool deserialize(const char* data, size_t size, Message* obj) { + return obj->ParseFromArray(data, size); + } +}; + +template <> +struct DefaultCoder : public DefaultCoder { }; + +template <> +struct DefaultCoder { + static bool serialize(string 
obj, string* serialized) { + *serialized = obj; + return true; + } + + static bool serialize(const string& obj, vector* serialized) { + vector temp(obj.data(), obj.data() + obj.size()); + serialized->swap(temp); + return true; + } + + static bool deserialize(const string& serialized, string* obj) { + *obj = serialized; + return true; + } + + static bool deserialize(const char* data, size_t size, string* obj) { + string temp_string(data, size); + obj->swap(temp_string); + return true; + } +}; + +template <> +struct DefaultCoder > { + static bool serialize(vector obj, string* serialized) { + string tmp(obj.data(), obj.size()); + serialized->swap(tmp); + return true; + } + + static bool serialize(const vector& obj, vector* serialized) { + *serialized = obj; + return true; + } + + static bool deserialize(const string& serialized, vector* obj) { + vector tmp(serialized.data(), serialized.data() + serialized.size()); + obj->swap(tmp); + return true; + } + + static bool deserialize(const char* data, size_t size, vector* obj) { + vector tmp(data, data + size); + obj->swap(tmp); + return true; + } +}; + +} // namespace dataset_internal + +template , + typename VCoder = dataset_internal::DefaultCoder > +class Dataset { + public: + enum Mode { + New, + ReadWrite, + ReadOnly + }; + + typedef K key_type; + typedef V value_type; + + struct KV { + K key; + V value; + }; + + virtual bool open(const string& filename, Mode mode) = 0; + virtual bool put(const K& key, const V& value) = 0; + virtual bool get(const K& key, V* value) = 0; + virtual bool first_key(K* key) = 0; + virtual bool last_key(K* key) = 0; + virtual bool commit() = 0; + virtual void close() = 0; + + virtual void keys(vector* keys) = 0; + + Dataset() { } + virtual ~Dataset() { } + + class iterator; + typedef iterator const_iterator; + + virtual const_iterator begin() const = 0; + virtual const_iterator cbegin() const = 0; + virtual const_iterator end() const = 0; + virtual const_iterator cend() const = 0; + + 
protected: + class DatasetState; + + public: + class iterator : public std::iterator { + public: + typedef KV T; + typedef T value_type; + typedef T& reference_type; + typedef T* pointer_type; + + iterator() + : parent_(NULL) { } + + iterator(const Dataset* parent, shared_ptr state) + : parent_(parent), + state_(state) { } + + iterator(const iterator& other) + : parent_(other.parent_), + state_(other.state_ ? other.state_->clone() + : shared_ptr()) { } + + iterator& operator=(iterator copy) { + copy.swap(*this); + return *this; + } + + void swap(iterator& other) throw() { + std::swap(this->parent_, other.parent_); + std::swap(this->state_, other.state_); + } + + bool operator==(const iterator& other) const { + return parent_->equal(state_, other.state_); + } + + bool operator!=(const iterator& other) const { + return !(*this == other); + } + + iterator& operator++() { + parent_->increment(&state_); + return *this; + } + iterator operator++(int) { + iterator copy(*this); + parent_->increment(&state_); + return copy; + } + + reference_type operator*() const { + return parent_->dereference(state_); + } + + pointer_type operator->() const { + return &parent_->dereference(state_); + } + + protected: + const Dataset* parent_; + shared_ptr state_; + }; + + protected: + class DatasetState { + public: + virtual ~DatasetState() { } + virtual shared_ptr clone() = 0; + }; + + virtual bool equal(shared_ptr state1, + shared_ptr state2) const = 0; + virtual void increment(shared_ptr* state) const = 0; + virtual KV& dereference( + shared_ptr state) const = 0; +}; + +} // namespace caffe + +#define INSTANTIATE_DATASET(type) \ + template class type; \ + template class type >; \ + template class type; + +#endif // CAFFE_DATASET_H_ diff --git a/caffe-crfrnn/include/caffe/dataset_factory.hpp b/caffe-crfrnn/include/caffe/dataset_factory.hpp new file mode 100644 index 00000000..57db49bf --- /dev/null +++ b/caffe-crfrnn/include/caffe/dataset_factory.hpp @@ -0,0 +1,20 @@ +#ifndef 
CAFFE_DATASET_FACTORY_H_ +#define CAFFE_DATASET_FACTORY_H_ + +#include + +#include "caffe/common.hpp" +#include "caffe/dataset.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +template +shared_ptr > DatasetFactory(const DataParameter_DB& type); + +template +shared_ptr > DatasetFactory(const string& type); + +} // namespace caffe + +#endif // CAFFE_DATASET_FACTORY_H_ diff --git a/caffe-crfrnn/include/caffe/filler.hpp b/caffe-crfrnn/include/caffe/filler.hpp new file mode 100644 index 00000000..136ce958 --- /dev/null +++ b/caffe-crfrnn/include/caffe/filler.hpp @@ -0,0 +1,188 @@ +// Fillers are random number generators that fill a blob using the specified +// algorithm. The expectation is that they are only going to be used during +// initialization time and will not involve any GPUs. + +#ifndef CAFFE_FILLER_HPP +#define CAFFE_FILLER_HPP + +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/syncedmem.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +/// @brief Fills a Blob with constant or randomly-generated data. +template +class Filler { + public: + explicit Filler(const FillerParameter& param) : filler_param_(param) {} + virtual ~Filler() {} + virtual void Fill(Blob* blob) = 0; + protected: + FillerParameter filler_param_; +}; // class Filler + + +/// @brief Fills a Blob with constant values @f$ x = 0 @f$. +template +class ConstantFiller : public Filler { + public: + explicit ConstantFiller(const FillerParameter& param) + : Filler(param) {} + virtual void Fill(Blob* blob) { + Dtype* data = blob->mutable_cpu_data(); + const int count = blob->count(); + const Dtype value = this->filler_param_.value(); + CHECK(count); + for (int i = 0; i < count; ++i) { + data[i] = value; + } + CHECK_EQ(this->filler_param_.sparse(), -1) + << "Sparsity not supported by this Filler."; + } +}; + +/// @brief Fills a Blob with uniformly distributed values @f$ x\sim U(a, b) @f$. 
+template +class UniformFiller : public Filler { + public: + explicit UniformFiller(const FillerParameter& param) + : Filler(param) {} + virtual void Fill(Blob* blob) { + CHECK(blob->count()); + caffe_rng_uniform(blob->count(), Dtype(this->filler_param_.min()), + Dtype(this->filler_param_.max()), blob->mutable_cpu_data()); + CHECK_EQ(this->filler_param_.sparse(), -1) + << "Sparsity not supported by this Filler."; + } +}; + +/// @brief Fills a Blob with Gaussian-distributed values @f$ x = a @f$. +template +class GaussianFiller : public Filler { + public: + explicit GaussianFiller(const FillerParameter& param) + : Filler(param) {} + virtual void Fill(Blob* blob) { + Dtype* data = blob->mutable_cpu_data(); + CHECK(blob->count()); + caffe_rng_gaussian(blob->count(), Dtype(this->filler_param_.mean()), + Dtype(this->filler_param_.std()), blob->mutable_cpu_data()); + int sparse = this->filler_param_.sparse(); + CHECK_GE(sparse, -1); + if (sparse >= 0) { + // Sparse initialization is implemented for "weight" blobs; i.e. matrices. + // These have num == channels == 1; height is number of inputs; width is + // number of outputs. The 'sparse' variable specifies the mean number + // of non-zero input weights for a given output. + CHECK_EQ(blob->num(), 1); + CHECK_EQ(blob->channels(), 1); + int num_inputs = blob->height(); + Dtype non_zero_probability = Dtype(sparse) / Dtype(num_inputs); + rand_vec_.reset(new SyncedMemory(blob->count() * sizeof(int))); + int* mask = reinterpret_cast(rand_vec_->mutable_cpu_data()); + caffe_rng_bernoulli(blob->count(), non_zero_probability, mask); + for (int i = 0; i < blob->count(); ++i) { + data[i] *= mask[i]; + } + } + } + + protected: + shared_ptr rand_vec_; +}; + +/** @brief Fills a Blob with values @f$ x \in [0, 1] @f$ + * such that @f$ \forall i \sum_j x_{ij} = 1 @f$. 
+ */ +template +class PositiveUnitballFiller : public Filler { + public: + explicit PositiveUnitballFiller(const FillerParameter& param) + : Filler(param) {} + virtual void Fill(Blob* blob) { + Dtype* data = blob->mutable_cpu_data(); + DCHECK(blob->count()); + caffe_rng_uniform(blob->count(), 0, 1, blob->mutable_cpu_data()); + // We expect the filler to not be called very frequently, so we will + // just use a simple implementation + int dim = blob->count() / blob->num(); + CHECK(dim); + for (int i = 0; i < blob->num(); ++i) { + Dtype sum = 0; + for (int j = 0; j < dim; ++j) { + sum += data[i * dim + j]; + } + for (int j = 0; j < dim; ++j) { + data[i * dim + j] /= sum; + } + } + CHECK_EQ(this->filler_param_.sparse(), -1) + << "Sparsity not supported by this Filler."; + } +}; + +/** + * @brief Fills a Blob with values @f$ x \sim U(-a, +a) @f$ where @f$ a @f$ + * is set inversely proportional to the number of incoming nodes. + * + * A Filler based on the paper [Glorot and Bengio 2010]: Understanding + * the difficulty of training deep feedforward neural networks, but does not + * use the fan_out value. + * + * It fills the incoming matrix by randomly sampling uniform data from + * [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number + * of input nodes. You should make sure the input blob has shape (num, a, b, c) + * where a * b * c = fan_in. + * + * TODO(dox): make notation in above comment consistent with rest & use LaTeX. 
+ */ +template +class XavierFiller : public Filler { + public: + explicit XavierFiller(const FillerParameter& param) + : Filler(param) {} + virtual void Fill(Blob* blob) { + CHECK(blob->count()); + int fan_in = blob->count() / blob->num(); + Dtype scale = sqrt(Dtype(3) / fan_in); + caffe_rng_uniform(blob->count(), -scale, scale, + blob->mutable_cpu_data()); + CHECK_EQ(this->filler_param_.sparse(), -1) + << "Sparsity not supported by this Filler."; + } +}; + + +/** + * @brief Get a specific filler from the specification given in FillerParameter. + * + * Ideally this would be replaced by a factory pattern, but we will leave it + * this way for now. + */ +template +Filler* GetFiller(const FillerParameter& param) { + const std::string& type = param.type(); + if (type == "constant") { + return new ConstantFiller(param); + } else if (type == "gaussian") { + return new GaussianFiller(param); + } else if (type == "positive_unitball") { + return new PositiveUnitballFiller(param); + } else if (type == "uniform") { + return new UniformFiller(param); + } else if (type == "xavier") { + return new XavierFiller(param); + } else { + CHECK(false) << "Unknown filler name: " << param.type(); + } + return (Filler*)(NULL); +} + +} // namespace caffe + +#endif // CAFFE_FILLER_HPP_ diff --git a/caffe-crfrnn/include/caffe/internal_thread.hpp b/caffe-crfrnn/include/caffe/internal_thread.hpp new file mode 100644 index 00000000..6a106e6e --- /dev/null +++ b/caffe-crfrnn/include/caffe/internal_thread.hpp @@ -0,0 +1,50 @@ +#ifndef CAFFE_INTERNAL_THREAD_HPP_ +#define CAFFE_INTERNAL_THREAD_HPP_ + +#include "caffe/common.hpp" + +namespace caffe { + +/** + * A minimal wrapper for boost::thread to force host compilation for boost + * Defined in caffe/util/thread.hpp + */ +class Thread { + public: + template + Thread(Callable func, A1 a1); + void join(); + bool joinable(); + private: + void* thread_; +}; + +/** + * Virtual class encapsulate boost::thread for use in base class + * The child class 
will acquire the ability to run a single thread, + * by reimplementing the virtual function InternalThreadEntry. + */ +class InternalThread { + public: + InternalThread() : thread_(NULL) {} + virtual ~InternalThread(); + + /** Returns true if the thread was successfully started. **/ + bool StartInternalThread(); + + /** Will not return until the internal thread has exited. */ + bool WaitForInternalThreadToExit(); + + bool is_started() const { return thread_ != NULL && thread_->joinable(); } + + protected: + /* Implement this method in your subclass + with the code you want your thread to run. */ + virtual void InternalThreadEntry() {} + + caffe::Thread* thread_; +}; + +} // namespace caffe + +#endif diff --git a/caffe-crfrnn/include/caffe/layer.hpp b/caffe-crfrnn/include/caffe/layer.hpp new file mode 100644 index 00000000..6c768b9b --- /dev/null +++ b/caffe-crfrnn/include/caffe/layer.hpp @@ -0,0 +1,492 @@ +#ifndef CAFFE_LAYER_H_ +#define CAFFE_LAYER_H_ + +#include +#include +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/layer_factory.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/coords.hpp" +#include "caffe/util/device_alternate.hpp" + +namespace caffe { + +template class Net; + +/** + * @brief An interface for the units of computation which can be composed into a + * Net. + * + * Layer&s must implement a Forward function, in which they take their input + * (bottom) Blob&s (if any) and compute their output Blob&s (if any). + * They may also implement a Backward function, in which they compute the error + * gradients with respect to their input Blob&s, given the error gradients with + * respect to their output Blob&s. + */ +template +class Layer { + public: + /** + * You should not implement your own constructor. Any set up code should go + * to SetUp(), where the dimensions of the bottom blobs are provided to the + * layer. 
+ */ + explicit Layer(const LayerParameter& param) + : layer_param_(param) { + // The only thing we do is to copy blobs if there are any. + if (layer_param_.blobs_size() > 0) { + blobs_.resize(layer_param_.blobs_size()); + for (int i = 0; i < layer_param_.blobs_size(); ++i) { + blobs_[i].reset(new Blob()); + blobs_[i]->FromProto(layer_param_.blobs(i)); + } + } + } + virtual ~Layer() {} + + /** + * @brief Implements common layer setup functionality. + * + * @param bottom the preshaped input blobs + * @param top + * the allocated but unshaped output blobs, to be shaped by Reshape + * + * Checks that the number of bottom and top blobs is correct. + * Calls LayerSetUp to do special layer setup for individual layer types, + * followed by Reshape to set up sizes of top blobs and internal buffers. + * Sets up the loss weight multiplier blobs for any non-zero loss weights. + * This method may not be overridden. + */ + void SetUp(const vector*>& bottom, + const vector*>& top) { + CheckBlobCounts(bottom, top); + LayerSetUp(bottom, top); + Reshape(bottom, top); + SetLossWeights(top); + } + + /** + * @brief Does layer-specific setup: your layer should implement this function + * as well as Reshape. + * + * @param bottom + * the preshaped input blobs, whose data fields store the input data for + * this layer + * @param top + * the allocated but unshaped output blobs + * + * This method should do one-time layer specific setup. This includes reading + * and processing relevant parameters from the layer_param_. + * Setting up the shapes of top blobs and internal buffers should be done in + * Reshape, which will be called before the forward pass to + * adjust the top blob sizes. + */ + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top) {} + + /** + * @brief Adjust the shapes of top blobs and internal buffers to accommodate + * the shapes of the bottom blobs. 
+ * + * @param bottom the input blobs, with the requested input shapes + * @param top the top blobs, which should be reshaped as needed + * + * This method should reshape top blobs as needed according to the shapes + * of the bottom (input) blobs, as well as reshaping any internal buffers + * and making any other necessary adjustments so that the layer can + * accommodate the bottom blobs. + */ + virtual void Reshape(const vector*>& bottom, + const vector*>& top) = 0; + + /** + * @brief Given the bottom blobs, compute the top blobs and the loss. + * + * @param bottom + * the input blobs, whose data fields store the input data for this layer + * @param top + * the preshaped output blobs, whose data fields will store this layer's + * outputs + * \return The total loss from the layer. + * + * The Forward wrapper calls the relevant device wrapper function + * (Forward_cpu or Forward_gpu) to compute the top blob values given the + * bottom blobs. If the layer has any non-zero loss_weights, the wrapper + * then computes and returns the loss. + * + * Your layer should implement Forward_cpu and (optionally) Forward_gpu. + */ + inline Dtype Forward(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Given the top blob error gradients, compute the bottom blob error + * gradients. + * + * @param top + * the output blobs, whose diff fields store the gradient of the error + * with respect to themselves + * @param propagate_down + * a vector with equal length to bottom, with each index indicating + * whether to propagate the error gradients down to the bottom blob at + * the corresponding index + * @param bottom + * the input blobs, whose diff fields will store the gradient of the error + * with respect to themselves after Backward is run + * + * The Backward wrapper calls the relevant device wrapper function + * (Backward_cpu or Backward_gpu) to compute the bottom blob diffs given the + * top blob diffs. 
+ * + * Your layer should implement Backward_cpu and (optionally) Backward_gpu. + */ + inline void Backward(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom); + + /** + * @brief Returns the vector of learnable parameter blobs. + */ + vector > >& blobs() { + return blobs_; + } + + /** + * @brief Returns the layer parameter. + */ + const LayerParameter& layer_param() const { return layer_param_; } + + /** + * @brief Writes the layer parameter to a protocol buffer + */ + virtual void ToProto(LayerParameter* param, bool write_diff = false); + + /** + * @brief Returns the scalar loss associated with a top blob at a given index. + */ + inline Dtype loss(const int top_index) const { + return (loss_.size() > top_index) ? loss_[top_index] : Dtype(0); + } + + /** + * @brief Sets the loss associated with a top blob at a given index. + */ + inline void set_loss(const int top_index, const Dtype value) { + if (loss_.size() <= top_index) { + loss_.resize(top_index + 1, Dtype(0)); + } + loss_[top_index] = value; + } + + /** + * @brief Returns the layer type as an enum value. + */ + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_NONE; + } + + /** + * @brief Returns the layer type name. + */ + virtual inline const string& type_name() const { + return LayerParameter_LayerType_Name(type()); + } + + /** + * @brief Returns the exact number of bottom blobs required by the layer, + * or -1 if no exact number is required. + * + * This method should be overridden to return a non-negative value if your + * layer expects some exact number of bottom blobs. + */ + virtual inline int ExactNumBottomBlobs() const { return -1; } + /** + * @brief Returns the minimum number of bottom blobs required by the layer, + * or -1 if no minimum number is required. + * + * This method should be overridden to return a non-negative value if your + * layer expects some minimum number of bottom blobs. 
+ */ + virtual inline int MinBottomBlobs() const { return -1; } + /** + * @brief Returns the maximum number of bottom blobs required by the layer, + * or -1 if no maximum number is required. + * + * This method should be overridden to return a non-negative value if your + * layer expects some maximum number of bottom blobs. + */ + virtual inline int MaxBottomBlobs() const { return -1; } + /** + * @brief Returns the exact number of top blobs required by the layer, + * or -1 if no exact number is required. + * + * This method should be overridden to return a non-negative value if your + * layer expects some exact number of top blobs. + */ + virtual inline int ExactNumTopBlobs() const { return -1; } + /** + * @brief Returns the minimum number of top blobs required by the layer, + * or -1 if no minimum number is required. + * + * This method should be overridden to return a non-negative value if your + * layer expects some minimum number of top blobs. + */ + virtual inline int MinTopBlobs() const { return -1; } + /** + * @brief Returns the maximum number of top blobs required by the layer, + * or -1 if no maximum number is required. + * + * This method should be overridden to return a non-negative value if your + * layer expects some maximum number of top blobs. + */ + virtual inline int MaxTopBlobs() const { return -1; } + /** + * @brief Returns true if the layer requires an equal number of bottom and + * top blobs. + * + * This method should be overridden to return true if your layer expects an + * equal number of bottom and top blobs. + */ + virtual inline bool EqualNumBottomTopBlobs() const { return false; } + + /** + * @brief Return whether "anonymous" top blobs are created automatically + * by the layer. + * + * If this method returns true, Net::Init will create enough "anonymous" top + * blobs to fulfill the requirement specified by ExactNumTopBlobs() or + * MinTopBlobs(). 
+ */ + virtual inline bool AutoTopBlobs() const { return false; } + + /** + * @brief Return whether to allow force_backward for a given bottom blob + * index. + * + * If AllowForceBackward(i) == false, we will ignore the force_backward + * setting and backpropagate to blob i only if it needs gradient information + * (as is done when force_backward == false). + */ + virtual inline bool AllowForceBackward(const int bottom_index) const { + return true; + } + + /** + * @brief Specifies whether the layer should compute gradients w.r.t. a + * parameter at a particular index given by param_id. + * + * You can safely ignore false values and always compute gradients + * for all parameters, but possibly with wasteful computation. + */ + inline bool param_propagate_down(const int param_id) { + return (param_propagate_down_.size() > param_id) ? + param_propagate_down_[param_id] : false; + } + /** + * @brief Sets whether the layer should compute gradients w.r.t. a + * parameter at a particular index given by param_id. + */ + inline void set_param_propagate_down(const int param_id, const bool value) { + if (param_propagate_down_.size() <= param_id) { + param_propagate_down_.resize(param_id + 1, true); + } + param_propagate_down_[param_id] = value; + } + + virtual DiagonalAffineMap coord_map() { + NOT_IMPLEMENTED; + // suppress warnings + return DiagonalAffineMap(vector >()); + } + + /** + * @brief Used by Net to give layers a pointer to their owning net. + */ + void set_net(Net* net) { net_ = net; } + + protected: + /** The protobuf that stores the layer parameters */ + LayerParameter layer_param_; + /** The vector that stores the learnable parameters as a set of blobs. */ + vector > > blobs_; + /** Vector indicating whether to compute the diff of each param blob. */ + vector param_propagate_down_; + + /** The vector that indicates whether each top blob has a non-zero weight in + * the objective function. */ + vector loss_; + + /** The net to which this layer belongs. 
*/ + Net* net_; + + /** @brief Using the CPU device, compute the layer output. */ + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top) = 0; + /** + * @brief Using the GPU device, compute the layer output. + * Fall back to Forward_cpu() if unavailable. + */ + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top) { + // LOG(WARNING) << "Using CPU code as backup."; + return Forward_cpu(bottom, top); + } + + /** + * @brief Using the CPU device, compute the gradients for any parameters and + * for the bottom blobs if propagate_down is true. + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) = 0; + /** + * @brief Using the GPU device, compute the gradients for any parameters and + * for the bottom blobs if propagate_down is true. + * Fall back to Backward_cpu() if unavailable. + */ + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + // LOG(WARNING) << "Using CPU code as backup."; + Backward_cpu(top, propagate_down, bottom); + } + + /** + * Called by the parent Layer's SetUp to check that the number of bottom + * and top Blobs provided as input match the expected numbers specified by + * the {ExactNum,Min,Max}{Bottom,Top}Blobs() functions. 
+ */ + virtual void CheckBlobCounts(const vector*>& bottom, + const vector*>& top) { + if (ExactNumBottomBlobs() >= 0) { + CHECK_EQ(ExactNumBottomBlobs(), bottom.size()) + << type_name() << " Layer takes " << ExactNumBottomBlobs() + << " bottom blob(s) as input."; + } + if (MinBottomBlobs() >= 0) { + CHECK_LE(MinBottomBlobs(), bottom.size()) + << type_name() << " Layer takes at least " << MinBottomBlobs() + << " bottom blob(s) as input."; + } + if (MaxBottomBlobs() >= 0) { + CHECK_GE(MaxBottomBlobs(), bottom.size()) + << type_name() << " Layer takes at most " << MaxBottomBlobs() + << " bottom blob(s) as input."; + } + if (ExactNumTopBlobs() >= 0) { + CHECK_EQ(ExactNumTopBlobs(), top.size()) + << type_name() << " Layer produces " << ExactNumTopBlobs() + << " top blob(s) as output."; + } + if (MinTopBlobs() >= 0) { + CHECK_LE(MinTopBlobs(), top.size()) + << type_name() << " Layer produces at least " << MinTopBlobs() + << " top blob(s) as output."; + } + if (MaxTopBlobs() >= 0) { + CHECK_GE(MaxTopBlobs(), top.size()) + << type_name() << " Layer produces at most " << MaxTopBlobs() + << " top blob(s) as output."; + } + if (EqualNumBottomTopBlobs()) { + CHECK_EQ(bottom.size(), top.size()) + << type_name() << " Layer produces one top blob as output for each " + << "bottom blob input."; + } + } + + /** + * Called by SetUp to initialize the weights associated with any top blobs in + * the loss function. Store non-zero loss weights in the diff blob. 
+ */ + inline void SetLossWeights(const vector*>& top) { + const int num_loss_weights = layer_param_.loss_weight_size(); + if (num_loss_weights) { + CHECK_EQ(top.size(), num_loss_weights) << "loss_weight must be " + "unspecified or specified once per top blob."; + for (int top_id = 0; top_id < top.size(); ++top_id) { + const Dtype loss_weight = layer_param_.loss_weight(top_id); + if (loss_weight == Dtype(0)) { continue; } + this->set_loss(top_id, loss_weight); + const int count = top[top_id]->count(); + Dtype* loss_multiplier = top[top_id]->mutable_cpu_diff(); + caffe_set(count, loss_weight, loss_multiplier); + } + } + } + + DISABLE_COPY_AND_ASSIGN(Layer); +}; // class Layer + +// Forward and backward wrappers. You should implement the cpu and +// gpu specific implementations instead, and should not change these +// functions. +template +inline Dtype Layer::Forward(const vector*>& bottom, + const vector*>& top) { + Dtype loss = 0; + switch (Caffe::mode()) { + case Caffe::CPU: + Forward_cpu(bottom, top); + for (int top_id = 0; top_id < top.size(); ++top_id) { + if (!this->loss(top_id)) { continue; } + const int count = top[top_id]->count(); + const Dtype* data = top[top_id]->cpu_data(); + const Dtype* loss_weights = top[top_id]->cpu_diff(); + loss += caffe_cpu_dot(count, data, loss_weights); + } + break; + case Caffe::GPU: + Forward_gpu(bottom, top); +#ifndef CPU_ONLY + for (int top_id = 0; top_id < top.size(); ++top_id) { + if (!this->loss(top_id)) { continue; } + const int count = top[top_id]->count(); + const Dtype* data = top[top_id]->gpu_data(); + const Dtype* loss_weights = top[top_id]->gpu_diff(); + Dtype blob_loss = 0; + caffe_gpu_dot(count, data, loss_weights, &blob_loss); + loss += blob_loss; + } +#endif + break; + default: + LOG(FATAL) << "Unknown caffe mode."; + } + return loss; +} + +template +inline void Layer::Backward(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + switch (Caffe::mode()) { + case Caffe::CPU: + 
Backward_cpu(top, propagate_down, bottom); + break; + case Caffe::GPU: + Backward_gpu(top, propagate_down, bottom); + break; + default: + LOG(FATAL) << "Unknown caffe mode."; + } +} + +// Serialize LayerParameter to protocol buffer +template +void Layer::ToProto(LayerParameter* param, bool write_diff) { + param->Clear(); + param->CopyFrom(layer_param_); + param->clear_blobs(); + for (int i = 0; i < blobs_.size(); ++i) { + blobs_[i]->ToProto(param->add_blobs(), write_diff); + } +} + +} // namespace caffe + +#endif // CAFFE_LAYER_H_ diff --git a/caffe-crfrnn/include/caffe/layer_factory.hpp b/caffe-crfrnn/include/caffe/layer_factory.hpp new file mode 100644 index 00000000..c1fd6aa0 --- /dev/null +++ b/caffe-crfrnn/include/caffe/layer_factory.hpp @@ -0,0 +1,118 @@ +/** + * @brief A layer factory that allows one to register layers. + * During runtime, registered layers could be called by passing a LayerParameter + * protobuffer to the CreateLayer function: + * + * LayerRegistry::CreateLayer(param); + * + * There are two ways to register a layer. Assuming that we have a layer like: + * + * template + * class MyAwesomeLayer : public Layer { + * // your implementations + * }; + * + * and its type is defined in the protobuffer as + * + * enum LayerType { + * // other definitions + * AWESOME = 46, + * } + * + * If the layer is going to be created simply by its constructor, in your c++ + * file, add the following line: + * + * REGISTER_LAYER_CLASS(AWESOME, MyAwesomeLayer); + * + * Or, if the layer is going to be created by another creator function, in the + * format of: + * + * template + * Layer GetMyAwesomeLayer(const LayerParameter& param) { + * // your implementation + * } + * + * (for example, when your layer has multiple backends, see GetConvolutionLayer + * for a use case), then you can register the creator function instead, like + * + * REGISTER_LAYER_CREATOR(AWESOME, GetMyAwesomeLayer) + * + * Note that each layer type should only be registered once. 
+ */ + +#ifndef CAFFE_LAYER_FACTORY_H_ +#define CAFFE_LAYER_FACTORY_H_ + +#include + +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +template +class Layer; + +template +class LayerRegistry { + public: + typedef Layer* (*Creator)(const LayerParameter&); + typedef std::map CreatorRegistry; + + static CreatorRegistry& Registry() { + static CreatorRegistry* g_registry_ = new CreatorRegistry(); + return *g_registry_; + } + + // Adds a creator. + static void AddCreator(const LayerParameter_LayerType& type, + Creator creator) { + CreatorRegistry& registry = Registry(); + CHECK_EQ(registry.count(type), 0) + << "Layer type " << type << " already registered."; + registry[type] = creator; + } + + // Get a layer using a LayerParameter. + static Layer* CreateLayer(const LayerParameter& param) { + LOG(INFO) << "Creating layer " << param.name(); + const LayerParameter_LayerType& type = param.type(); + CreatorRegistry& registry = Registry(); + CHECK_EQ(registry.count(type), 1); + return registry[type](param); + } + + private: + // Layer registry should never be instantiated - everything is done with its + // static variables. 
+ LayerRegistry() {} +}; + + +template +class LayerRegisterer { + public: + LayerRegisterer(const LayerParameter_LayerType& type, + Layer* (*creator)(const LayerParameter&)) { + // LOG(INFO) << "Registering layer type: " << type; + LayerRegistry::AddCreator(type, creator); + } +}; + + +#define REGISTER_LAYER_CREATOR(type, creator) \ + static LayerRegisterer g_creator_f_##type( \ + LayerParameter_LayerType_##type, creator); \ + static LayerRegisterer g_creator_d_##type( \ + LayerParameter_LayerType_##type, creator) + +#define REGISTER_LAYER_CLASS(type, clsname) \ + template \ + Layer* Creator_##clsname(const LayerParameter& param) { \ + return new clsname(param); \ + } \ + REGISTER_LAYER_CREATOR(type, Creator_##clsname) + +} // namespace caffe + +#endif // CAFFE_LAYER_FACTORY_H_ diff --git a/caffe-crfrnn/include/caffe/leveldb_dataset.hpp b/caffe-crfrnn/include/caffe/leveldb_dataset.hpp new file mode 100644 index 00000000..d58c181b --- /dev/null +++ b/caffe-crfrnn/include/caffe/leveldb_dataset.hpp @@ -0,0 +1,90 @@ +#ifndef CAFFE_LEVELDB_DATASET_H_ +#define CAFFE_LEVELDB_DATASET_H_ + +#include +#include + +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/dataset.hpp" + +namespace caffe { + +template , + typename VCoder = dataset_internal::DefaultCoder > +class LeveldbDataset : public Dataset { + public: + typedef Dataset Base; + typedef typename Base::key_type key_type; + typedef typename Base::value_type value_type; + typedef typename Base::DatasetState DatasetState; + typedef typename Base::Mode Mode; + typedef typename Base::const_iterator const_iterator; + typedef typename Base::KV KV; + + bool open(const string& filename, Mode mode); + bool put(const K& key, const V& value); + bool get(const K& key, V* value); + bool first_key(K* key); + bool last_key(K* key); + bool commit(); + void close(); + + void keys(vector* keys); + + const_iterator begin() const; + const_iterator cbegin() const; + const_iterator end() const; + const_iterator 
cend() const; + + protected: + class LeveldbState : public DatasetState { + public: + explicit LeveldbState(shared_ptr db, + shared_ptr iter) + : DatasetState(), + db_(db), + iter_(iter) { } + + ~LeveldbState() { + // This order is very important. + // Iterators must be destroyed before their associated DB + // is destroyed. + iter_.reset(); + db_.reset(); + } + + shared_ptr clone() { + shared_ptr new_iter; + + CHECK(iter_.get()); + new_iter.reset(db_->NewIterator(leveldb::ReadOptions())); + CHECK(iter_->Valid()); + new_iter->Seek(iter_->key()); + CHECK(new_iter->Valid()); + + return shared_ptr(new LeveldbState(db_, new_iter)); + } + + shared_ptr db_; + shared_ptr iter_; + KV kv_pair_; + }; + + bool equal(shared_ptr state1, + shared_ptr state2) const; + void increment(shared_ptr* state) const; + KV& dereference(shared_ptr state) const; + + shared_ptr db_; + shared_ptr batch_; + bool read_only_; +}; + +} // namespace caffe + +#endif // CAFFE_LEVELDB_DATASET_H_ diff --git a/caffe-crfrnn/include/caffe/lmdb_dataset.hpp b/caffe-crfrnn/include/caffe/lmdb_dataset.hpp new file mode 100644 index 00000000..ac1e5ee2 --- /dev/null +++ b/caffe-crfrnn/include/caffe/lmdb_dataset.hpp @@ -0,0 +1,95 @@ +#ifndef CAFFE_LMDB_DATASET_H_ +#define CAFFE_LMDB_DATASET_H_ + +#include +#include +#include + +#include "lmdb.h" + +#include "caffe/common.hpp" +#include "caffe/dataset.hpp" + +namespace caffe { + +template , + typename VCoder = dataset_internal::DefaultCoder > +class LmdbDataset : public Dataset { + public: + typedef Dataset Base; + typedef typename Base::key_type key_type; + typedef typename Base::value_type value_type; + typedef typename Base::DatasetState DatasetState; + typedef typename Base::Mode Mode; + typedef typename Base::const_iterator const_iterator; + typedef typename Base::KV KV; + + LmdbDataset() + : env_(NULL), + dbi_(0), + write_txn_(NULL), + read_txn_(NULL) { } + + bool open(const string& filename, Mode mode); + bool put(const K& key, const V& value); + bool 
get(const K& key, V* value); + bool first_key(K* key); + bool last_key(K* key); + bool commit(); + void close(); + + void keys(vector* keys); + + const_iterator begin() const; + const_iterator cbegin() const; + const_iterator end() const; + const_iterator cend() const; + + protected: + class LmdbState : public DatasetState { + public: + explicit LmdbState(MDB_cursor* cursor, MDB_txn* txn, const MDB_dbi* dbi) + : DatasetState(), + cursor_(cursor), + txn_(txn), + dbi_(dbi) { } + + shared_ptr clone() { + CHECK(cursor_); + + MDB_cursor* new_cursor; + int retval; + + retval = mdb_cursor_open(txn_, *dbi_, &new_cursor); + CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); + MDB_val key; + MDB_val val; + retval = mdb_cursor_get(cursor_, &key, &val, MDB_GET_CURRENT); + CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); + retval = mdb_cursor_get(new_cursor, &key, &val, MDB_SET); + CHECK_EQ(MDB_SUCCESS, retval) << mdb_strerror(retval); + + return shared_ptr(new LmdbState(new_cursor, txn_, dbi_)); + } + + MDB_cursor* cursor_; + MDB_txn* txn_; + const MDB_dbi* dbi_; + KV kv_pair_; + }; + + bool equal(shared_ptr state1, + shared_ptr state2) const; + void increment(shared_ptr* state) const; + KV& dereference(shared_ptr state) const; + + MDB_env* env_; + MDB_dbi dbi_; + MDB_txn* write_txn_; + MDB_txn* read_txn_; +}; + +} // namespace caffe + +#endif // CAFFE_LMDB_DATASET_H_ diff --git a/caffe-crfrnn/include/caffe/loss_layers.hpp b/caffe-crfrnn/include/caffe/loss_layers.hpp new file mode 100644 index 00000000..6729cc4f --- /dev/null +++ b/caffe-crfrnn/include/caffe/loss_layers.hpp @@ -0,0 +1,775 @@ +#ifndef CAFFE_LOSS_LAYERS_HPP_ +#define CAFFE_LOSS_LAYERS_HPP_ + +#include +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/neuron_layers.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +const float kLOG_THRESHOLD = 1e-20; + +/** + * @brief Computes the classification accuracy for a 
one-of-many + * classification task. + */ +template +class AccuracyLayer : public Layer { + public: + /** + * @param param provides AccuracyParameter accuracy_param, + * with AccuracyLayer options: + * - top_k (\b optional, default 1). + * Sets the maximum rank @f$ k @f$ at which a prediction is considered + * correct. For example, if @f$ k = 5 @f$, a prediction is counted + * correct if the correct label is among the top 5 predicted labels. + */ + explicit AccuracyLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_ACCURACY; + } + + virtual inline int ExactNumBottomBlobs() const { return 2; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + /** + * @param bottom input Blob vector (length 2) + * -# @f$ (N \times C \times H \times W) @f$ + * the predictions @f$ x @f$, a Blob with values in + * @f$ [-\infty, +\infty] @f$ indicating the predicted score for each of + * the @f$ K = CHW @f$ classes. Each @f$ x_n @f$ is mapped to a predicted + * label @f$ \hat{l}_n @f$ given by its maximal index: + * @f$ \hat{l}_n = \arg\max\limits_k x_{nk} @f$ + * -# @f$ (N \times 1 \times 1 \times 1) @f$ + * the labels @f$ l @f$, an integer-valued Blob with values + * @f$ l_n \in [0, 1, 2, ..., K - 1] @f$ + * indicating the correct class label among the @f$ K @f$ classes + * @param top output Blob vector (length 1) + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * the computed accuracy: @f$ + * \frac{1}{N} \sum\limits_{n=1}^N \delta\{ \hat{l}_n = l_n \} + * @f$, where @f$ + * \delta\{\mathrm{condition}\} = \left\{ + * \begin{array}{lr} + * 1 & \mbox{if condition} \\ + * 0 & \mbox{otherwise} + * \end{array} \right. 
+ * @f$ + */ + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + + + /// @brief Not implemented -- AccuracyLayer cannot be used as a loss. + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + for (int i = 0; i < propagate_down.size(); ++i) { + if (propagate_down[i]) { NOT_IMPLEMENTED; } + } + } + + int top_k_; +}; + +/** + * @brief An interface for Layer%s that take two Blob%s as input -- usually + * (1) predictions and (2) ground-truth labels -- and output a + * singleton Blob representing the loss. + * + * LossLayers are typically only capable of backpropagating to their first input + * -- the predictions. + */ +template +class LossLayer : public Layer { + public: + explicit LossLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp( + const vector*>& bottom, const vector*>& top); + virtual void Reshape( + const vector*>& bottom, const vector*>& top); + + virtual inline int ExactNumBottomBlobs() const { return 2; } + + /** + * @brief For convenience and backwards compatibility, instruct the Net to + * automatically allocate a single top Blob for LossLayers, into which + * they output their singleton loss, (even if the user didn't specify + * one in the prototxt, etc.). + */ + virtual inline bool AutoTopBlobs() const { return true; } + virtual inline int ExactNumTopBlobs() const { return 1; } + /** + * We usually cannot backpropagate to the labels; ignore force_backward for + * these inputs. + */ + virtual inline bool AllowForceBackward(const int bottom_index) const { + return bottom_index != 1; + } +}; + +/** + * @brief Computes the contrastive loss @f$ + * E = \frac{1}{2N} \sum\limits_{n=1}^N \left(y\right) d + + * \left(1-y\right) \max \left(margin-d, 0\right) + * @f$ where @f$ + * d = \left| \left| a_n - b_n \right| \right|_2^2 @f$. This can be + * used to train siamese networks. 
+ * + * @param bottom input Blob vector (length 3) + * -# @f$ (N \times C \times 1 \times 1) @f$ + * the features @f$ a \in [-\infty, +\infty]@f$ + * -# @f$ (N \times C \times 1 \times 1) @f$ + * the features @f$ b \in [-\infty, +\infty]@f$ + * -# @f$ (N \times 1 \times 1 \times 1) @f$ + * the binary similarity @f$ y \in [0, 1]@f$ + * @param top output Blob vector (length 1) + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * the computed contrastive loss: @f$ E = + * \frac{1}{2N} \sum\limits_{n=1}^N \left(y\right) d + + * \left(1-y\right) \max \left(margin-d, 0\right) + * @f$ where @f$ + * d = \left| \left| a_n - b_n \right| \right|_2^2 @f$. + * This can be used to train siamese networks. + */ +template +class ContrastiveLossLayer : public LossLayer { + public: + explicit ContrastiveLossLayer(const LayerParameter& param) + : LossLayer(param), diff_() {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + + virtual inline int ExactNumBottomBlobs() const { return 3; } + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_CONTRASTIVE_LOSS; + } + /** + * Unlike most loss layers, in the ContrastiveLossLayer we can backpropagate + * to the first two inputs. + */ + virtual inline bool AllowForceBackward(const int bottom_index) const { + return bottom_index != 2; + } + + protected: + /// @copydoc ContrastiveLossLayer + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Computes the Contrastive error gradient w.r.t. the inputs. + * + * Computes the gradients with respect to the two input vectors (bottom[0] and + * bottom[1]), but not the similarity label (bottom[2]). 
+ * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$, + * as @f$ \lambda @f$ is the coefficient of this layer's output + * @f$\ell_i@f$ in the overall Net loss + * @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence + * @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$. + * (*Assuming that this top Blob is not used as a bottom (input) by any + * other layer of the Net.) + * @param propagate_down see Layer::Backward. + * @param bottom input Blob vector (length 2) + * -# @f$ (N \times C \times 1 \times 1) @f$ + * the features @f$a@f$; Backward fills their diff with + * gradients if propagate_down[0] + * -# @f$ (N \times C \times 1 \times 1) @f$ + * the features @f$b@f$; Backward fills their diff with gradients if + * propagate_down[1] + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + Blob diff_; // cached for backward pass + Blob dist_sq_; // cached for backward pass + Blob diff_sq_; // tmp storage for gpu forward pass + Blob summer_vec_; // tmp storage for gpu forward pass +}; + +/** + * @brief Computes the Euclidean (L2) loss @f$ + * E = \frac{1}{2N} \sum\limits_{n=1}^N \left| \left| \hat{y}_n - y_n + * \right| \right|_2^2 @f$ for real-valued regression tasks. 
+ * + * @param bottom input Blob vector (length 2) + * -# @f$ (N \times C \times H \times W) @f$ + * the predictions @f$ \hat{y} \in [-\infty, +\infty]@f$ + * -# @f$ (N \times C \times H \times W) @f$ + * the targets @f$ y \in [-\infty, +\infty]@f$ + * @param top output Blob vector (length 1) + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * the computed Euclidean loss: @f$ E = + * \frac{1}{2N} \sum\limits_{n=1}^N \left| \left| \hat{y}_n - y_n + * \right| \right|_2^2 @f$ + * + * This can be used for least-squares regression tasks. An InnerProductLayer + * input to an EuclideanLossLayer exactly formulates a linear least squares + * regression problem. With non-zero weight decay the problem becomes one of + * ridge regression -- see src/caffe/test/test_sgd_solver.cpp for a concrete + * example wherein we check that the gradients computed for a Net with exactly + * this structure match hand-computed gradient formulas for ridge regression. + * + * (Note: Caffe, and SGD in general, is certainly \b not the best way to solve + * linear least squares problems! We use it only as an instructive example.) + */ +template +class EuclideanLossLayer : public LossLayer { + public: + explicit EuclideanLossLayer(const LayerParameter& param) + : LossLayer(param), diff_() {} + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_EUCLIDEAN_LOSS; + } + + /** + * Unlike most loss layers, in the EuclideanLossLayer we can backpropagate + * to both inputs -- override to return true and always allow force_backward. + */ + virtual inline bool AllowForceBackward(const int bottom_index) const { + return true; + } + + protected: + /// @copydoc EuclideanLossLayer + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Computes the Euclidean error gradient w.r.t. the inputs. 
+ * + * Unlike other children of LossLayer, EuclideanLossLayer \b can compute + * gradients with respect to the label inputs bottom[1] (but still only will + * if propagate_down[1] is set, due to being produced by learnable parameters + * or if force_backward is set). In fact, this layer is "commutative" -- the + * result is the same regardless of the order of the two bottoms. + * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$, + * as @f$ \lambda @f$ is the coefficient of this layer's output + * @f$\ell_i@f$ in the overall Net loss + * @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence + * @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$. + * (*Assuming that this top Blob is not used as a bottom (input) by any + * other layer of the Net.) + * @param propagate_down see Layer::Backward. + * @param bottom input Blob vector (length 2) + * -# @f$ (N \times C \times H \times W) @f$ + * the predictions @f$\hat{y}@f$; Backward fills their diff with + * gradients @f$ + * \frac{\partial E}{\partial \hat{y}} = + * \frac{1}{N} \sum\limits_{n=1}^N (\hat{y}_n - y_n) + * @f$ if propagate_down[0] + * -# @f$ (N \times C \times H \times W) @f$ + * the targets @f$y@f$; Backward fills their diff with gradients + * @f$ \frac{\partial E}{\partial y} = + * \frac{1}{N} \sum\limits_{n=1}^N (y_n - \hat{y}_n) + * @f$ if propagate_down[1] + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + Blob diff_; +}; + +/** + * @brief Computes the hinge loss for a one-of-many classification task. 
+ * + * @param bottom input Blob vector (length 2) + * -# @f$ (N \times C \times H \times W) @f$ + * the predictions @f$ t @f$, a Blob with values in + * @f$ [-\infty, +\infty] @f$ indicating the predicted score for each of + * the @f$ K = CHW @f$ classes. In an SVM, @f$ t @f$ is the result of + * taking the inner product @f$ X^T W @f$ of the D-dimensional features + * @f$ X \in \mathcal{R}^{D \times N} @f$ and the learned hyperplane + * parameters @f$ W \in \mathcal{R}^{D \times K} @f$, so a Net with just + * an InnerProductLayer (with num_output = K) providing predictions to a + * HingeLossLayer and no other learnable parameters or losses is + * equivalent to an SVM. + * -# @f$ (N \times 1 \times 1 \times 1) @f$ + * the labels @f$ l @f$, an integer-valued Blob with values + * @f$ l_n \in [0, 1, 2, ..., K - 1] @f$ + * indicating the correct class label among the @f$ K @f$ classes + * @param top output Blob vector (length 1) + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * the computed hinge loss: @f$ E = + * \frac{1}{N} \sum\limits_{n=1}^N \sum\limits_{k=1}^K + * [\max(0, 1 - \delta\{l_n = k\} t_{nk})] ^ p + * @f$, for the @f$ L^p @f$ norm + * (defaults to @f$ p = 1 @f$, the L1 norm; L2 norm, as in L2-SVM, + * is also available), and @f$ + * \delta\{\mathrm{condition}\} = \left\{ + * \begin{array}{lr} + * 1 & \mbox{if condition} \\ + * -1 & \mbox{otherwise} + * \end{array} \right. + * @f$ + * + * In an SVM, @f$ t \in \mathcal{R}^{N \times K} @f$ is the result of taking + * the inner product @f$ X^T W @f$ of the features + * @f$ X \in \mathcal{R}^{D \times N} @f$ + * and the learned hyperplane parameters + * @f$ W \in \mathcal{R}^{D \times K} @f$. So, a Net with just an + * InnerProductLayer (with num_output = @f$K@f$) providing predictions to a + * HingeLossLayer is equivalent to an SVM (assuming it has no other learned + * parameters outside the InnerProductLayer and no other losses outside the + * HingeLossLayer). 
+ */ +template +class HingeLossLayer : public LossLayer { + public: + explicit HingeLossLayer(const LayerParameter& param) + : LossLayer(param) {} + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_HINGE_LOSS; + } + + protected: + /// @copydoc HingeLossLayer + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Computes the hinge loss error gradient w.r.t. the predictions. + * + * Gradients cannot be computed with respect to the label inputs (bottom[1]), + * so this method ignores bottom[1] and requires !propagate_down[1], crashing + * if propagate_down[1] is set. + * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$, + * as @f$ \lambda @f$ is the coefficient of this layer's output + * @f$\ell_i@f$ in the overall Net loss + * @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence + * @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$. + * (*Assuming that this top Blob is not used as a bottom (input) by any + * other layer of the Net.) + * @param propagate_down see Layer::Backward. + * propagate_down[1] must be false as we can't compute gradients with + * respect to the labels. + * @param bottom input Blob vector (length 2) + * -# @f$ (N \times C \times H \times W) @f$ + * the predictions @f$t@f$; Backward computes diff + * @f$ \frac{\partial E}{\partial t} @f$ + * -# @f$ (N \times 1 \times 1 \times 1) @f$ + * the labels -- ignored as we can't compute their error gradients + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); +}; + +/** + * @brief A generalization of MultinomialLogisticLossLayer that takes an + * "information gain" (infogain) matrix specifying the "value" of all label + * pairs. 
+ * + * Equivalent to the MultinomialLogisticLossLayer if the infogain matrix is the + * identity. + * + * @param bottom input Blob vector (length 2-3) + * -# @f$ (N \times C \times H \times W) @f$ + * the predictions @f$ \hat{p} @f$, a Blob with values in + * @f$ [0, 1] @f$ indicating the predicted probability of each of the + * @f$ K = CHW @f$ classes. Each prediction vector @f$ \hat{p}_n @f$ + * should sum to 1 as in a probability distribution: @f$ + * \forall n \sum\limits_{k=1}^K \hat{p}_{nk} = 1 @f$. + * -# @f$ (N \times 1 \times 1 \times 1) @f$ + * the labels @f$ l @f$, an integer-valued Blob with values + * @f$ l_n \in [0, 1, 2, ..., K - 1] @f$ + * indicating the correct class label among the @f$ K @f$ classes + * -# @f$ (1 \times 1 \times K \times K) @f$ + * (\b optional) the infogain matrix @f$ H @f$. This must be provided as + * the third bottom blob input if not provided as the infogain_mat in the + * InfogainLossParameter. If @f$ H = I @f$, this layer is equivalent to the + * MultinomialLogisticLossLayer. + * @param top output Blob vector (length 1) + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * the computed infogain multinomial logistic loss: @f$ E = + * \frac{-1}{N} \sum\limits_{n=1}^N H_{l_n} \log(\hat{p}_n) = + * \frac{-1}{N} \sum\limits_{n=1}^N \sum\limits_{k=1}^{K} H_{l_n,k} + * \log(\hat{p}_{n,k}) + * @f$, where @f$ H_{l_n} @f$ denotes row @f$l_n@f$ of @f$H@f$. + */ +template +class InfogainLossLayer : public LossLayer { + public: + explicit InfogainLossLayer(const LayerParameter& param) + : LossLayer(param), infogain_() {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + // InfogainLossLayer takes 2-3 bottom Blobs; if there are 3 the third should + // be the infogain matrix. (Otherwise the infogain matrix is loaded from a + // file specified by LayerParameter.) 
+ virtual inline int ExactNumBottomBlobs() const { return -1; } + virtual inline int MinBottomBlobs() const { return 2; } + virtual inline int MaxBottomBlobs() const { return 3; } + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_INFOGAIN_LOSS; + } + + protected: + /// @copydoc InfogainLossLayer + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Computes the infogain loss error gradient w.r.t. the predictions. + * + * Gradients cannot be computed with respect to the label inputs (bottom[1]), + * so this method ignores bottom[1] and requires !propagate_down[1], crashing + * if propagate_down[1] is set. (The same applies to the infogain matrix, if + * provided as bottom[2] rather than in the layer_param.) + * + * @param top output Blob vector (length 1), providing the error gradient + * with respect to the outputs + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$, + * as @f$ \lambda @f$ is the coefficient of this layer's output + * @f$\ell_i@f$ in the overall Net loss + * @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence + * @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$. + * (*Assuming that this top Blob is not used as a bottom (input) by any + * other layer of the Net.) + * @param propagate_down see Layer::Backward. 
+ * propagate_down[1] must be false as we can't compute gradients with + * respect to the labels (similarly for propagate_down[2] and the + * infogain matrix, if provided as bottom[2]) + * @param bottom input Blob vector (length 2-3) + * -# @f$ (N \times C \times H \times W) @f$ + * the predictions @f$ \hat{p} @f$; Backward computes diff + * @f$ \frac{\partial E}{\partial \hat{p}} @f$ + * -# @f$ (N \times 1 \times 1 \times 1) @f$ + * the labels -- ignored as we can't compute their error gradients + * -# @f$ (1 \times 1 \times K \times K) @f$ + * (\b optional) the information gain matrix -- ignored as its error + * gradient computation is not implemented. + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + Blob infogain_; +}; + + + + +/** + * @brief Computes the multinomial logistic loss for a one-of-many + * classification task, directly taking a predicted probability + * distribution as input. + * + * When predictions are not already a probability distribution, you should + * instead use the SoftmaxWithLossLayer, which maps predictions to a + * distribution using the SoftmaxLayer, before computing the multinomial + * logistic loss. The SoftmaxWithLossLayer should be preferred over separate + * SoftmaxLayer + MultinomialLogisticLossLayer + * as its gradient computation is more numerically stable. + * + * @param bottom input Blob vector (length 2) + * -# @f$ (N \times C \times H \times W) @f$ + * the predictions @f$ \hat{p} @f$, a Blob with values in + * @f$ [0, 1] @f$ indicating the predicted probability of each of the + * @f$ K = CHW @f$ classes. Each prediction vector @f$ \hat{p}_n @f$ + * should sum to 1 as in a probability distribution: @f$ + * \forall n \sum\limits_{k=1}^K \hat{p}_{nk} = 1 @f$. 
+ * -# @f$ (N \times 1 \times 1 \times 1) @f$ + * the labels @f$ l @f$, an integer-valued Blob with values + * @f$ l_n \in [0, 1, 2, ..., K - 1] @f$ + * indicating the correct class label among the @f$ K @f$ classes + * @param top output Blob vector (length 1) + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * the computed multinomial logistic loss: @f$ E = + * \frac{-1}{N} \sum\limits_{n=1}^N \log(\hat{p}_{n,l_n}) + * @f$ + */ +template +class MultinomialLogisticLossLayer : public LossLayer { + public: + explicit MultinomialLogisticLossLayer(const LayerParameter& param) + : LossLayer(param) {} + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_MULTINOMIAL_LOGISTIC_LOSS; + } + + protected: + /// @copydoc MultinomialLogisticLossLayer + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Computes the multinomial logistic loss error gradient w.r.t. the + * predictions. + * + * Gradients cannot be computed with respect to the label inputs (bottom[1]), + * so this method ignores bottom[1] and requires !propagate_down[1], crashing + * if propagate_down[1] is set. + * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$, + * as @f$ \lambda @f$ is the coefficient of this layer's output + * @f$\ell_i@f$ in the overall Net loss + * @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence + * @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$. + * (*Assuming that this top Blob is not used as a bottom (input) by any + * other layer of the Net.) + * @param propagate_down see Layer::Backward. + * propagate_down[1] must be false as we can't compute gradients with + * respect to the labels. 
+ * @param bottom input Blob vector (length 2) + * -# @f$ (N \times C \times H \times W) @f$ + * the predictions @f$ \hat{p} @f$; Backward computes diff + * @f$ \frac{\partial E}{\partial \hat{p}} @f$ + * -# @f$ (N \times 1 \times 1 \times 1) @f$ + * the labels -- ignored as we can't compute their error gradients + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); +}; + +/** + * @brief Computes the cross-entropy (logistic) loss @f$ + * E = \frac{-1}{N} \sum\limits_{n=1}^N \left[ + * p_n \log \hat{p}_n + + * (1 - p_n) \log(1 - \hat{p}_n) + * \right] + * @f$, often used for predicting targets interpreted as probabilities. + * + * This layer is implemented rather than separate + * SigmoidLayer + CrossEntropyLayer + * as its gradient computation is more numerically stable. + * At test time, this layer can be replaced simply by a SigmoidLayer. + * + * @param bottom input Blob vector (length 2) + * -# @f$ (N \times C \times H \times W) @f$ + * the scores @f$ x \in [-\infty, +\infty]@f$, + * which this layer maps to probability predictions + * @f$ \hat{p}_n = \sigma(x_n) \in [0, 1] @f$ + * using the sigmoid function @f$ \sigma(.) @f$ (see SigmoidLayer). 
+ * -# @f$ (N \times C \times H \times W) @f$ + * the targets @f$ p \in [0, 1] @f$ + * @param top output Blob vector (length 1) + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * the computed cross-entropy loss: @f$ + * E = \frac{-1}{N} \sum\limits_{n=1}^N \left[ + * p_n \log \hat{p}_n + (1 - p_n) \log(1 - \hat{p}_n) + * \right] + * @f$ + */ +template +class SigmoidCrossEntropyLossLayer : public LossLayer { + public: + explicit SigmoidCrossEntropyLossLayer(const LayerParameter& param) + : LossLayer(param), + sigmoid_layer_(new SigmoidLayer(param)), + sigmoid_output_(new Blob()) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_SIGMOID_CROSS_ENTROPY_LOSS; + } + + protected: + /// @copydoc SigmoidCrossEntropyLossLayer + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Computes the sigmoid cross-entropy loss error gradient w.r.t. the + * predictions. + * + * Gradients cannot be computed with respect to the target inputs (bottom[1]), + * so this method ignores bottom[1] and requires !propagate_down[1], crashing + * if propagate_down[1] is set. + * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$, + * as @f$ \lambda @f$ is the coefficient of this layer's output + * @f$\ell_i@f$ in the overall Net loss + * @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence + * @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$. + * (*Assuming that this top Blob is not used as a bottom (input) by any + * other layer of the Net.) + * @param propagate_down see Layer::Backward. 
+ * propagate_down[1] must be false as gradient computation with respect + * to the targets is not implemented. + * @param bottom input Blob vector (length 2) + * -# @f$ (N \times C \times H \times W) @f$ + * the predictions @f$x@f$; Backward computes diff + * @f$ \frac{\partial E}{\partial x} = + * \frac{1}{N} \sum\limits_{n=1}^N (\hat{p}_n - p_n) + * @f$ + * -# @f$ (N \times C \times H \times W) @f$ + * the targets -- ignored as gradient computation with respect to them + * is not implemented + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + /// The internal SigmoidLayer used to map predictions to probabilities. + shared_ptr > sigmoid_layer_; + /// sigmoid_output stores the output of the SigmoidLayer. + shared_ptr > sigmoid_output_; + /// bottom vector holder to call the underlying SigmoidLayer::Forward + vector*> sigmoid_bottom_vec_; + /// top vector holder to call the underlying SigmoidLayer::Forward + vector*> sigmoid_top_vec_; +}; + +// Forward declare SoftmaxLayer for use in SoftmaxWithLossLayer. +template class SoftmaxLayer; + +/** + * @brief Computes the multinomial logistic loss for a one-of-many + * classification task, passing real-valued predictions through a + * softmax to get a probability distribution over classes. + * + * This layer should be preferred over separate + * SoftmaxLayer + MultinomialLogisticLossLayer + * as its gradient computation is more numerically stable. + * At test time, this layer can be replaced simply by a SoftmaxLayer. + * + * @param bottom input Blob vector (length 2) + * -# @f$ (N \times C \times H \times W) @f$ + * the predictions @f$ x @f$, a Blob with values in + * @f$ [-\infty, +\infty] @f$ indicating the predicted score for each of + * the @f$ K = CHW @f$ classes. 
This layer maps these scores to a + * probability distribution over classes using the softmax function + * @f$ \hat{p}_{nk} = \exp(x_{nk}) / + * \left[\sum_{k'} \exp(x_{nk'})\right] @f$ (see SoftmaxLayer). + * -# @f$ (N \times 1 \times 1 \times 1) @f$ + * the labels @f$ l @f$, an integer-valued Blob with values + * @f$ l_n \in [0, 1, 2, ..., K - 1] @f$ + * indicating the correct class label among the @f$ K @f$ classes + * @param top output Blob vector (length 1) + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * the computed cross-entropy classification loss: @f$ E = + * \frac{-1}{N} \sum\limits_{n=1}^N \log(\hat{p}_{n,l_n}) + * @f$, for softmax output class probabilities @f$ \hat{p} @f$ + */ +template +class SoftmaxWithLossLayer : public LossLayer { + public: + explicit SoftmaxWithLossLayer(const LayerParameter& param) + : LossLayer(param), + softmax_layer_(new SoftmaxLayer(param)) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_SOFTMAX_LOSS; + } + virtual inline int ExactNumBottomBlobs() const { return -1; } + virtual inline int MinBottomBlobs() const { return 2; } + virtual inline int MaxBottomBlobs() const { return 3; } + virtual inline int ExactNumTopBlobs() const { return -1; } + virtual inline int MinTopBlobs() const { return 1; } + virtual inline int MaxTopBlobs() const { return 2; } + + protected: + /// @copydoc SoftmaxWithLossLayer + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + // TODO(Yangqing): implement the GPU version of softmax. + + /** + * @brief Computes the softmax loss error gradient w.r.t. the predictions. + * + * Gradients cannot be computed with respect to the label inputs (bottom[1]), + * so this method ignores bottom[1] and requires !propagate_down[1], crashing + * if propagate_down[1] is set. 
+ * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (1 \times 1 \times 1 \times 1) @f$ + * This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$, + * as @f$ \lambda @f$ is the coefficient of this layer's output + * @f$\ell_i@f$ in the overall Net loss + * @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence + * @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$. + * (*Assuming that this top Blob is not used as a bottom (input) by any + * other layer of the Net.) + * @param propagate_down see Layer::Backward. + * propagate_down[1] must be false as we can't compute gradients with + * respect to the labels. + * @param bottom input Blob vector (length 2) + * -# @f$ (N \times C \times H \times W) @f$ + * the predictions @f$ x @f$; Backward computes diff + * @f$ \frac{\partial E}{\partial x} @f$ + * -# @f$ (N \times 1 \times 1 \times 1) @f$ + * the labels -- ignored as we can't compute their error gradients + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + /// The internal SoftmaxLayer used to map predictions to a distribution. + shared_ptr > softmax_layer_; + /// prob stores the output probability predictions from the SoftmaxLayer. + Blob prob_; + /// bottom vector holder used in call to the underlying SoftmaxLayer::Forward + vector*> softmax_bottom_vec_; + /// top vector holder used in call to the underlying SoftmaxLayer::Forward + vector*> softmax_top_vec_; + /// Whether to ignore instances with a certain label. + bool has_ignore_label_; + /// The label indicating that an instance should be ignored. + int ignore_label_; + /// Whether to normalize the loss by the total number of values present + /// (otherwise just by the batch size). 
+ bool normalize_; +}; + +} // namespace caffe + +#endif // CAFFE_LOSS_LAYERS_HPP_ diff --git a/caffe-crfrnn/include/caffe/net.hpp b/caffe-crfrnn/include/caffe/net.hpp new file mode 100644 index 00000000..1d06dc45 --- /dev/null +++ b/caffe-crfrnn/include/caffe/net.hpp @@ -0,0 +1,234 @@ +#ifndef CAFFE_NET_HPP_ +#define CAFFE_NET_HPP_ + +#include +#include +#include +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +/** + * @brief Connects Layer%s together into a directed acyclic graph (DAG) + * specified by a NetParameter. + * + * TODO(dox): more thorough description. + */ +template +class Net { + public: + explicit Net(const NetParameter& param); + explicit Net(const string& param_file); + virtual ~Net() {} + + /// @brief Initialize a network with a NetParameter. + void Init(const NetParameter& param); + + /** + * @brief Run Forward with the input Blob%s already fed separately. + * + * You can get the input blobs using input_blobs(). + */ + const vector*>& ForwardPrefilled(Dtype* loss = NULL); + + /** + * The From and To variants of Forward and Backward operate on the + * (topological) ordering by which the net is specified. For general DAG + * networks, note that (1) computing from one layer to another might entail + * extra computation on unrelated branches, and (2) computation starting in + * the middle may be incorrect if all of the layers of a fan-in are not + * included. + */ + Dtype ForwardFromTo(int start, int end); + Dtype ForwardFrom(int start); + Dtype ForwardTo(int end); + /// @brief Run forward using a set of bottom blobs, and return the result. 
+ const vector*>& Forward(const vector* > & bottom, + Dtype* loss = NULL); + /** + * @brief Run forward using a serialized BlobProtoVector and return the + * result as a serialized BlobProtoVector + */ + string Forward(const string& input_blob_protos, Dtype* loss = NULL); + + /** + * The network backward should take no input and output, since it solely + * computes the gradient w.r.t the parameters, and the data has already been + * provided during the forward pass. + */ + void Backward(); + void BackwardFromTo(int start, int end); + void BackwardFrom(int start); + void BackwardTo(int end); + + /** + * @brief Reshape all layers from bottom to top. + * + * This is useful to propagate changes to layer sizes without running + * a forward pass, e.g. to compute output feature size. + */ + void Reshape(); + + Dtype ForwardBackward(const vector* > & bottom) { + Dtype loss; + Forward(bottom, &loss); + Backward(); + return loss; + } + + /// @brief Updates the network weights based on the diff values computed. + void Update(); + + /** + * @brief For an already initialized net, implicitly copies (i.e., using no + * additional memory) the pre-trained layers from another Net. + */ + void ShareTrainedLayersWith(Net* other); + // For an already initialized net, CopyTrainedLayersFrom() copies the already + // trained layers from another net parameter instance. + /** + * @brief For an already initialized net, copies the pre-trained layers from + * another Net. + */ + void CopyTrainedLayersFrom(const NetParameter& param); + void CopyTrainedLayersFrom(const string trained_filename); + /// @brief Writes the net to a proto. + void ToProto(NetParameter* param, bool write_diff = false); + + /// @brief returns the network name. 
+ inline const string& name() { return name_; } + /// @brief returns the layer names + inline const vector& layer_names() { return layer_names_; } + /// @brief returns the blob names + inline const vector& blob_names() { return blob_names_; } + /// @brief returns the blobs + inline const vector > >& blobs() { return blobs_; } + /// @brief returns the layers + inline const vector > >& layers() { return layers_; } + /** + * @brief returns the bottom vecs for each layer -- usually you won't + * need this unless you do per-layer checks such as gradients. + */ + inline vector*> >& bottom_vecs() { return bottom_vecs_; } + /** + * @brief returns the top vecs for each layer -- usually you won't + * need this unless you do per-layer checks such as gradients. + */ + inline vector*> >& top_vecs() { return top_vecs_; } + inline vector >& bottom_need_backward() { + return bottom_need_backward_; + } + inline vector& blob_loss_weights() { + return blob_loss_weights_; + } + /// @brief returns the parameters + inline vector > >& params() { return params_; } + /// @brief returns the parameter learning rate multipliers + inline vector& params_lr() { return params_lr_; } + inline vector& params_weight_decay() { return params_weight_decay_; } + const map& param_names_index() { return param_names_index_; } + /// @brief Input and output blob numbers + inline int num_inputs() { return net_input_blobs_.size(); } + inline int num_outputs() { return net_output_blobs_.size(); } + inline vector*>& input_blobs() { return net_input_blobs_; } + inline vector*>& output_blobs() { return net_output_blobs_; } + inline vector& input_blob_indices() { return net_input_blob_indices_; } + inline vector& output_blob_indices() { return net_output_blob_indices_; } + bool has_blob(const string& blob_name); + const shared_ptr > blob_by_name(const string& blob_name); + bool has_layer(const string& layer_name); + const shared_ptr > layer_by_name(const string& layer_name); + + void set_debug_info(const bool 
value) { debug_info_ = value; } + + // Helpers for Init. + /** + * @brief Remove layers that the user specified should be excluded given the current + * phase, level, and stage. + */ + static void FilterNet(const NetParameter& param, + NetParameter* param_filtered); + /// @brief return whether NetState state meets NetStateRule rule + static bool StateMeetsRule(const NetState& state, const NetStateRule& rule, + const string& layer_name); + + protected: + // Helpers for Init. + /// @brief Append a new input or top blob to the net. + void AppendTop(const NetParameter& param, const int layer_id, + const int top_id, set* available_blobs, + map* blob_name_to_idx); + /// @brief Append a new bottom blob to the net. + int AppendBottom(const NetParameter& param, const int layer_id, + const int bottom_id, set* available_blobs, + map* blob_name_to_idx); + /// @brief Append a new parameter blob to the net. + void AppendParam(const NetParameter& param, const int layer_id, + const int param_id); + + /// @brief Helper for displaying debug info in Forward. + void ForwardDebugInfo(const int layer_id); + /// @brief Helper for displaying debug info in Backward. + void BackwardDebugInfo(const int layer_id); + /// @brief Helper for displaying debug info in Update. + void UpdateDebugInfo(const int param_id); + + /// @brief Get misc parameters, e.g. the LR multiplier and weight decay. + void GetLearningRateAndWeightDecay(); + + /// @brief Individual layers in the net + vector > > layers_; + vector layer_names_; + map layer_names_index_; + vector layer_need_backward_; + /// @brief the blobs storing intermediate results between the layer. + vector > > blobs_; + vector blob_names_; + map blob_names_index_; + vector blob_need_backward_; + /// bottom_vecs stores the vectors containing the input for each layer. + /// They don't actually host the blobs (blobs_ does), so we simply store + /// pointers. 
+ vector*> > bottom_vecs_; + vector > bottom_id_vecs_; + vector > bottom_need_backward_; + /// top_vecs stores the vectors containing the output for each layer + vector*> > top_vecs_; + vector > top_id_vecs_; + /// Vector of weight in the loss (or objective) function of each net blob, + /// indexed by blob_id. + vector blob_loss_weights_; + vector param_owners_; + vector param_display_names_; + vector > param_layer_indices_; + map param_names_index_; + /// blob indices for the input and the output of the net + vector net_input_blob_indices_; + vector net_output_blob_indices_; + vector*> net_input_blobs_; + vector*> net_output_blobs_; + string name_; + /// The parameters in the network. + vector > > params_; + /// the learning rate multipliers + vector params_lr_; + /// the weight decay multipliers + vector params_weight_decay_; + /// The bytes of memory used by this net + size_t memory_used_; + /// Whether to compute and display debug info for the net. + bool debug_info_; + + DISABLE_COPY_AND_ASSIGN(Net); +}; + + +} // namespace caffe + +#endif // CAFFE_NET_HPP_ diff --git a/caffe-crfrnn/include/caffe/neuron_layers.hpp b/caffe-crfrnn/include/caffe/neuron_layers.hpp new file mode 100644 index 00000000..7b761bce --- /dev/null +++ b/caffe-crfrnn/include/caffe/neuron_layers.hpp @@ -0,0 +1,679 @@ +#ifndef CAFFE_NEURON_LAYERS_HPP_ +#define CAFFE_NEURON_LAYERS_HPP_ + +#include +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +#define HDF5_DATA_DATASET_NAME "data" +#define HDF5_DATA_LABEL_NAME "label" + +namespace caffe { + +/** + * @brief An interface for layers that take one blob as input (@f$ x @f$) + * and produce one equally-sized blob as output (@f$ y @f$), where + * each element of the output depends only on the corresponding input + * element. 
+ */ +template +class NeuronLayer : public Layer { + public: + explicit NeuronLayer(const LayerParameter& param) + : Layer(param) {} + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_NONE; + } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } +}; + +/** + * @brief Computes @f$ y = |x| @f$ + * + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$ + * @param top output Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the computed outputs @f$ y = |x| @f$ + */ +template +class AbsValLayer : public NeuronLayer { + public: + explicit AbsValLayer(const LayerParameter& param) + : NeuronLayer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_ABSVAL; + } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + /// @copydoc AbsValLayer + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Computes the error gradient w.r.t. the absolute value inputs. + * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (N \times C \times H \times W) @f$ + * containing error gradients @f$ \frac{\partial E}{\partial y} @f$ + * with respect to computed outputs @f$ y @f$ + * @param propagate_down see Layer::Backward. 
+ * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$; Backward fills their diff with + * gradients @f$ + * \frac{\partial E}{\partial x} = + * \mathrm{sign}(x) \frac{\partial E}{\partial y} + * @f$ if propagate_down[0] + */ + virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); + virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); +}; + +/** + * @brief Computes @f$ y = x + \log(1 + \exp(-x)) @f$ if @f$ x > 0 @f$; + * @f$ y = \log(1 + \exp(x)) @f$ otherwise. + * + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$ + * @param top output Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the computed outputs @f$ + * y = \left\{ + * \begin{array}{ll} + * x + \log(1 + \exp(-x)) & \mbox{if } x > 0 \\ + * \log(1 + \exp(x)) & \mbox{otherwise} + * \end{array} \right. + * @f$ + */ +template <typename Dtype> +class BNLLLayer : public NeuronLayer<Dtype> { + public: + explicit BNLLLayer(const LayerParameter& param) + : NeuronLayer<Dtype>(param) {} + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_BNLL; + } + + protected: + /// @copydoc BNLLLayer + virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + + /** + * @brief Computes the error gradient w.r.t. the BNLL inputs. + * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (N \times C \times H \times W) @f$ + * containing error gradients @f$ \frac{\partial E}{\partial y} @f$ + * with respect to computed outputs @f$ y @f$ + * @param propagate_down see Layer::Backward. 
+ * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$; Backward fills their diff with + * gradients @f$ + * \frac{\partial E}{\partial x} + * @f$ if propagate_down[0] + */ + virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); + virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); +}; + +/** + * @brief During training only, sets a random portion of @f$x@f$ to 0, scaling + * the remaining elements by @f$ 1 / (1 - p) @f$. + * + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$ + * @param top output Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the computed outputs @f$ y @f$ (see Forward_cpu for the train and + * test definitions) + */ +template <typename Dtype> +class DropoutLayer : public NeuronLayer<Dtype> { + public: + /** + * @param param provides DropoutParameter dropout_param, + * with DropoutLayer options: + * - dropout_ratio (\b optional, default 0.5). + * Sets the probability @f$ p @f$ that any given unit is dropped. + */ + explicit DropoutLayer(const LayerParameter& param) + : NeuronLayer<Dtype>(param) {} + virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_DROPOUT; + } + + protected: + /** + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$ + * @param top output Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the computed outputs. At training time, we have @f$ + * y_{\mbox{train}} = \left\{ + * \begin{array}{ll} + * \frac{x}{1 - p} & \mbox{if } u > p \\ + * 0 & \mbox{otherwise} + * \end{array} \right. + * @f$, where @f$ u \sim U(0, 1)@f$ is generated independently for each + * input at each iteration. 
At test time, we simply have + * @f$ y_{\mbox{test}} = \mathbb{E}[y_{\mbox{train}}] = x @f$. + */ + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + /// when divided by UINT_MAX, the randomly generated values @f$u\sim U(0,1)@f$ + Blob rand_vec_; + /// the probability @f$ p @f$ of dropping any input + Dtype threshold_; + /// the scale for undropped inputs at train time @f$ 1 / (1 - p) @f$ + Dtype scale_; + unsigned int uint_thres_; +}; + +/** + * @brief Computes @f$ y = \gamma ^ {\alpha x + \beta} @f$, + * as specified by the scale @f$ \alpha @f$, shift @f$ \beta @f$, + * and base @f$ \gamma @f$. + */ +template +class ExpLayer : public NeuronLayer { + public: + /** + * @param param provides ExpParameter exp_param, + * with ExpLayer options: + * - scale (\b optional, default 1) the scale @f$ \alpha @f$ + * - shift (\b optional, default 0) the shift @f$ \beta @f$ + * - base (\b optional, default -1 for a value of @f$ e \approx 2.718 @f$) + * the base @f$ \gamma @f$ + */ + explicit ExpLayer(const LayerParameter& param) + : NeuronLayer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_EXP; + } + + protected: + /** + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$ + * @param top output Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the computed outputs @f$ + * y = \gamma ^ {\alpha x + \beta} + * @f$ + */ + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& 
top); + + /** + * @brief Computes the error gradient w.r.t. the exp inputs. + * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (N \times C \times H \times W) @f$ + * containing error gradients @f$ \frac{\partial E}{\partial y} @f$ + * with respect to computed outputs @f$ y @f$ + * @param propagate_down see Layer::Backward. + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$; Backward fills their diff with + * gradients @f$ + * \frac{\partial E}{\partial x} = + * \frac{\partial E}{\partial y} y \alpha \log_e(\gamma) + * @f$ if propagate_down[0] + */ + virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); + virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); + + Dtype inner_scale_, outer_scale_; +}; + +/** + * @brief Computes @f$ y = (\alpha x + \beta) ^ \gamma @f$, + * as specified by the scale @f$ \alpha @f$, shift @f$ \beta @f$, + * and power @f$ \gamma @f$. 
+ */ +template +class PowerLayer : public NeuronLayer { + public: + /** + * @param param provides PowerParameter power_param, + * with PowerLayer options: + * - scale (\b optional, default 1) the scale @f$ \alpha @f$ + * - shift (\b optional, default 0) the shift @f$ \beta @f$ + * - power (\b optional, default 1) the power @f$ \gamma @f$ + */ + explicit PowerLayer(const LayerParameter& param) + : NeuronLayer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_POWER; + } + + protected: + /** + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$ + * @param top output Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the computed outputs @f$ + * y = (\alpha x + \beta) ^ \gamma + * @f$ + */ + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Computes the error gradient w.r.t. the power inputs. + * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (N \times C \times H \times W) @f$ + * containing error gradients @f$ \frac{\partial E}{\partial y} @f$ + * with respect to computed outputs @f$ y @f$ + * @param propagate_down see Layer::Backward. 
+ * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$; Backward fills their diff with + * gradients @f$ + * \frac{\partial E}{\partial x} = + * \frac{\partial E}{\partial y} + * \alpha \gamma (\alpha x + \beta) ^ {\gamma - 1} = + * \frac{\partial E}{\partial y} + * \frac{\alpha \gamma y}{\alpha x + \beta} + * @f$ if propagate_down[0] + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + /// @brief @f$ \gamma @f$ from layer_param_.power_param() + Dtype power_; + /// @brief @f$ \alpha @f$ from layer_param_.power_param() + Dtype scale_; + /// @brief @f$ \beta @f$ from layer_param_.power_param() + Dtype shift_; + /// @brief Result of @f$ \alpha \gamma @f$ + Dtype diff_scale_; +}; + +/** + * @brief Rectified Linear Unit non-linearity @f$ y = \max(0, x) @f$. + * The simple max is fast to compute, and the function does not saturate. + */ +template +class ReLULayer : public NeuronLayer { + public: + /** + * @param param provides ReLUParameter relu_param, + * with ReLULayer options: + * - negative_slope (\b optional, default 0). + * the value @f$ \nu @f$ by which negative values are multiplied. + */ + explicit ReLULayer(const LayerParameter& param) + : NeuronLayer(param) {} + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_RELU; + } + + protected: + /** + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$ + * @param top output Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the computed outputs @f$ + * y = \max(0, x) + * @f$ by default. If a non-zero negative_slope @f$ \nu @f$ is provided, + * the computed outputs are @f$ y = \max(0, x) + \nu \min(0, x) @f$. 
+ */ + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Computes the error gradient w.r.t. the ReLU inputs. + * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (N \times C \times H \times W) @f$ + * containing error gradients @f$ \frac{\partial E}{\partial y} @f$ + * with respect to computed outputs @f$ y @f$ + * @param propagate_down see Layer::Backward. + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$; Backward fills their diff with + * gradients @f$ + * \frac{\partial E}{\partial x} = \left\{ + * \begin{array}{lr} + * 0 & \mathrm{if} \; x \le 0 \\ + * \frac{\partial E}{\partial y} & \mathrm{if} \; x > 0 + * \end{array} \right. + * @f$ if propagate_down[0], by default. + * If a non-zero negative_slope @f$ \nu @f$ is provided, + * the computed gradients are @f$ + * \frac{\partial E}{\partial x} = \left\{ + * \begin{array}{lr} + * \nu \frac{\partial E}{\partial y} & \mathrm{if} \; x \le 0 \\ + * \frac{\partial E}{\partial y} & \mathrm{if} \; x > 0 + * \end{array} \right. + * @f$. + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); +}; + +#ifdef USE_CUDNN +/** + * @brief CuDNN acceleration of ReLULayer. 
+ */ +template +class CuDNNReLULayer : public ReLULayer { + public: + explicit CuDNNReLULayer(const LayerParameter& param) + : ReLULayer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + virtual ~CuDNNReLULayer(); + + protected: + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + cudnnHandle_t handle_; + cudnnTensorDescriptor_t bottom_desc_; + cudnnTensorDescriptor_t top_desc_; +}; +#endif + +/** + * @brief Sigmoid function non-linearity @f$ + * y = (1 + \exp(-x))^{-1} + * @f$, a classic choice in neural networks. + * + * Note that the gradient vanishes as the values move away from 0. + * The ReLULayer is often a better choice for this reason. + */ +template +class SigmoidLayer : public NeuronLayer { + public: + explicit SigmoidLayer(const LayerParameter& param) + : NeuronLayer(param) {} + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_SIGMOID; + } + + protected: + /** + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$ + * @param top output Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the computed outputs @f$ + * y = (1 + \exp(-x))^{-1} + * @f$ + */ + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + + /** + * @brief Computes the error gradient w.r.t. the sigmoid inputs. + * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (N \times C \times H \times W) @f$ + * containing error gradients @f$ \frac{\partial E}{\partial y} @f$ + * with respect to computed outputs @f$ y @f$ + * @param propagate_down see Layer::Backward. 
+ * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$; Backward fills their diff with + * gradients @f$ + * \frac{\partial E}{\partial x} + * = \frac{\partial E}{\partial y} y (1 - y) + * @f$ if propagate_down[0] + */ + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); +}; + +#ifdef USE_CUDNN +/** + * @brief CuDNN acceleration of SigmoidLayer. + */ +template +class CuDNNSigmoidLayer : public SigmoidLayer { + public: + explicit CuDNNSigmoidLayer(const LayerParameter& param) + : SigmoidLayer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + virtual ~CuDNNSigmoidLayer(); + + protected: + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + cudnnHandle_t handle_; + cudnnTensorDescriptor_t bottom_desc_; + cudnnTensorDescriptor_t top_desc_; +}; +#endif + +/** + * @brief TanH hyperbolic tangent non-linearity @f$ + * y = \frac{\exp(2x) - 1}{\exp(2x) + 1} + * @f$, popular in auto-encoders. + * + * Note that the gradient vanishes as the values move away from 0. + * The ReLULayer is often a better choice for this reason. 
+ */ +template <typename Dtype> +class TanHLayer : public NeuronLayer<Dtype> { + public: + explicit TanHLayer(const LayerParameter& param) + : NeuronLayer<Dtype>(param) {} + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_TANH; + } + + protected: + /** + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$ + * @param top output Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the computed outputs @f$ + * y = \frac{\exp(2x) - 1}{\exp(2x) + 1} + * @f$ + */ + virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + + /** + * @brief Computes the error gradient w.r.t. the tanh inputs. + * + * @param top output Blob vector (length 1), providing the error gradient with + * respect to the outputs + * -# @f$ (N \times C \times H \times W) @f$ + * containing error gradients @f$ \frac{\partial E}{\partial y} @f$ + * with respect to computed outputs @f$ y @f$ + * @param propagate_down see Layer::Backward. + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$; Backward fills their diff with + * gradients @f$ + * \frac{\partial E}{\partial x} + * = \frac{\partial E}{\partial y} + * \left(1 - \left[\frac{\exp(2x) - 1}{\exp(2x) + 1} \right]^2 \right) + * = \frac{\partial E}{\partial y} (1 - y^2) + * @f$ if propagate_down[0] + */ + virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); + virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); +}; + +#ifdef USE_CUDNN +/** + * @brief CuDNN acceleration of TanHLayer. 
+ */ +template +class CuDNNTanHLayer : public TanHLayer { + public: + explicit CuDNNTanHLayer(const LayerParameter& param) + : TanHLayer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + virtual ~CuDNNTanHLayer(); + + protected: + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + cudnnHandle_t handle_; + cudnnTensorDescriptor_t bottom_desc_; + cudnnTensorDescriptor_t top_desc_; +}; +#endif + +/** + * @brief Tests whether the input exceeds a threshold: outputs 1 for inputs + * above threshold; 0 otherwise. + */ +template +class ThresholdLayer : public NeuronLayer { + public: + /** + * @param param provides ThresholdParameter threshold_param, + * with ThresholdLayer options: + * - threshold (\b optional, default 0). + * the threshold value @f$ t @f$ to which the input values are compared. + */ + explicit ThresholdLayer(const LayerParameter& param) + : NeuronLayer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_THRESHOLD; + } + + protected: + /** + * @param bottom input Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the inputs @f$ x @f$ + * @param top output Blob vector (length 1) + * -# @f$ (N \times C \times H \times W) @f$ + * the computed outputs @f$ + * y = \left\{ + * \begin{array}{lr} + * 0 & \mathrm{if} \; x \le t \\ + * 1 & \mathrm{if} \; x > t + * \end{array} \right. 
+ * @f$ + */ + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + /// @brief Not implemented (non-differentiable function) + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + NOT_IMPLEMENTED; + } + + Dtype threshold_; +}; + +} // namespace caffe + +#endif // CAFFE_NEURON_LAYERS_HPP_ diff --git a/caffe-crfrnn/include/caffe/solver.hpp b/caffe-crfrnn/include/caffe/solver.hpp new file mode 100644 index 00000000..fde66208 --- /dev/null +++ b/caffe-crfrnn/include/caffe/solver.hpp @@ -0,0 +1,146 @@ +#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_ +#define CAFFE_OPTIMIZATION_SOLVER_HPP_ + +#include +#include + +#include "caffe/net.hpp" + +namespace caffe { + +/** + * @brief An interface for classes that perform optimization on Net%s. + * + * Requires implementation of ComputeUpdateValue to compute a parameter update + * given the current state of the Net parameters. + */ +template +class Solver { + public: + explicit Solver(const SolverParameter& param); + explicit Solver(const string& param_file); + void Init(const SolverParameter& param); + void InitTrainNet(); + void InitTestNets(); + // The main entry of the solver function. In default, iter will be zero. Pass + // in a non-zero iter number to resume training for a pre-trained net. + virtual void Solve(const char* resume_file = NULL); + inline void Solve(const string resume_file) { Solve(resume_file.c_str()); } + void Step(int iters); + virtual ~Solver() {} + inline shared_ptr > net() { return net_; } + inline const vector > >& test_nets() { + return test_nets_; + } + int iter() { return iter_; } + + protected: + // Get the update value for the current iteration. + virtual void ComputeUpdateValue() = 0; + // The Solver::Snapshot function implements the basic snapshotting utility + // that stores the learned net. 
You should implement the SnapshotSolverState() + // function that produces a SolverState protocol buffer that needs to be + // written to disk together with the learned net. + void Snapshot(); + // The test routine + void TestAll(); + void Test(const int test_net_id = 0); + virtual void SnapshotSolverState(SolverState* state) = 0; + // The Restore function implements how one should restore the solver to a + // previously snapshotted state. You should implement the RestoreSolverState() + // function that restores the state from a SolverState protocol buffer. + void Restore(const char* resume_file); + virtual void RestoreSolverState(const SolverState& state) = 0; + void DisplayOutputBlobs(const int net_id); + + SolverParameter param_; + int iter_; + int current_step_; + shared_ptr > net_; + vector > > test_nets_; + + DISABLE_COPY_AND_ASSIGN(Solver); +}; + + +/** + * @brief Optimizes the parameters of a Net using + * stochastic gradient descent (SGD) with momentum. + */ +template +class SGDSolver : public Solver { + public: + explicit SGDSolver(const SolverParameter& param) + : Solver(param) { PreSolve(); } + explicit SGDSolver(const string& param_file) + : Solver(param_file) { PreSolve(); } + + const vector > >& history() { return history_; } + + protected: + void PreSolve(); + Dtype GetLearningRate(); + virtual void ComputeUpdateValue(); + virtual void SnapshotSolverState(SolverState * state); + virtual void RestoreSolverState(const SolverState& state); + // history maintains the historical momentum data. + // update maintains update related data and is not needed in snapshots. 
+ // temp maintains other information that might be needed in computation + // of gradients/updates and is not needed in snapshots + vector > > history_, update_, temp_; + + DISABLE_COPY_AND_ASSIGN(SGDSolver); +}; + +template +class NesterovSolver : public SGDSolver { + public: + explicit NesterovSolver(const SolverParameter& param) + : SGDSolver(param) {} + explicit NesterovSolver(const string& param_file) + : SGDSolver(param_file) {} + + protected: + virtual void ComputeUpdateValue(); + + DISABLE_COPY_AND_ASSIGN(NesterovSolver); +}; + +template +class AdaGradSolver : public SGDSolver { + public: + explicit AdaGradSolver(const SolverParameter& param) + : SGDSolver(param) { constructor_sanity_check(); } + explicit AdaGradSolver(const string& param_file) + : SGDSolver(param_file) { constructor_sanity_check(); } + + protected: + virtual void ComputeUpdateValue(); + void constructor_sanity_check() { + CHECK_EQ(0, this->param_.momentum()) + << "Momentum cannot be used with AdaGrad."; + } + + DISABLE_COPY_AND_ASSIGN(AdaGradSolver); +}; + +template +Solver* GetSolver(const SolverParameter& param) { + SolverParameter_SolverType type = param.solver_type(); + + switch (type) { + case SolverParameter_SolverType_SGD: + return new SGDSolver(param); + case SolverParameter_SolverType_NESTEROV: + return new NesterovSolver(param); + case SolverParameter_SolverType_ADAGRAD: + return new AdaGradSolver(param); + default: + LOG(FATAL) << "Unknown SolverType: " << type; + } + return (Solver*) NULL; +} + +} // namespace caffe + +#endif // CAFFE_OPTIMIZATION_SOLVER_HPP_ diff --git a/caffe-crfrnn/include/caffe/syncedmem.hpp b/caffe-crfrnn/include/caffe/syncedmem.hpp new file mode 100644 index 00000000..2564e071 --- /dev/null +++ b/caffe-crfrnn/include/caffe/syncedmem.hpp @@ -0,0 +1,73 @@ +#ifndef CAFFE_SYNCEDMEM_HPP_ +#define CAFFE_SYNCEDMEM_HPP_ + +#include + +#include "caffe/common.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +// Theoretically, CaffeMallocHost 
and CaffeFreeHost should simply call the +// cudaMallocHost and cudaFreeHost functions in order to create pinned memory. +// However, those calls rely on the existence of a CUDA GPU (I don't know +// why that is a must since allocating memory should not be accessing the +// GPU resource, but it just creates an error as of CUDA 5.0) and will cause +// problems when running on a machine without a GPU. Thus, we simply define +// these two functions for safety and possible future change if the problem +// of calling CUDA functions disappears in a future version. +// +// In practice, although we are creating unpinned memory here, as long as we +// are constantly accessing it the memory pages almost always stay in +// physical memory (assuming we have large enough memory installed), and +// this does not seem to create a memory bottleneck here. + +inline void CaffeMallocHost(void** ptr, size_t size) { + *ptr = malloc(size); + CHECK(*ptr) << "host allocation of size " << size << " failed"; +} + +inline void CaffeFreeHost(void* ptr) { + free(ptr); +} + + +/** + * @brief Manages memory allocation and synchronization between the host (CPU) + * and device (GPU). + * + * TODO(dox): more thorough description. 
+ */ +class SyncedMemory { + public: + SyncedMemory() + : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED), + own_cpu_data_(false) {} + explicit SyncedMemory(size_t size) + : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED), + own_cpu_data_(false) {} + ~SyncedMemory(); + const void* cpu_data(); + void set_cpu_data(void* data); + const void* gpu_data(); + void* mutable_cpu_data(); + void* mutable_gpu_data(); + enum SyncedHead { UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED }; + SyncedHead head() { return head_; } + size_t size() { return size_; } + + private: + void to_cpu(); + void to_gpu(); + void* cpu_ptr_; + void* gpu_ptr_; + size_t size_; + SyncedHead head_; + bool own_cpu_data_; + + DISABLE_COPY_AND_ASSIGN(SyncedMemory); +}; // class SyncedMemory + +} // namespace caffe + +#endif // CAFFE_SYNCEDMEM_HPP_ diff --git a/caffe-crfrnn/include/caffe/test/test_caffe_main.hpp b/caffe-crfrnn/include/caffe/test/test_caffe_main.hpp new file mode 100644 index 00000000..438acf2b --- /dev/null +++ b/caffe-crfrnn/include/caffe/test/test_caffe_main.hpp @@ -0,0 +1,76 @@ +// The main caffe test code. Your test cpp code should include this hpp +// to allow a main function to be compiled into the binary. 
+#ifndef CAFFE_TEST_TEST_CAFFE_MAIN_HPP_ +#define CAFFE_TEST_TEST_CAFFE_MAIN_HPP_ + +#include +#include + +#include +#include + +#include "caffe/common.hpp" + +using std::cout; +using std::endl; + +#ifdef CMAKE_BUILD + #include +#else + #define CUDA_TEST_DEVICE -1 + #define CMAKE_SOURCE_DIR "src/" + #define EXAMPLES_SOURCE_DIR "examples/" + #define CMAKE_EXT "" +#endif + +int main(int argc, char** argv); + +namespace caffe { + +template +class MultiDeviceTest : public ::testing::Test { + public: + typedef typename TypeParam::Dtype Dtype; + protected: + MultiDeviceTest() { + Caffe::set_mode(TypeParam::device); + } + virtual ~MultiDeviceTest() {} +}; + +typedef ::testing::Types TestDtypes; + +struct FloatCPU { + typedef float Dtype; + static const Caffe::Brew device = Caffe::CPU; +}; + +struct DoubleCPU { + typedef double Dtype; + static const Caffe::Brew device = Caffe::CPU; +}; + +#ifdef CPU_ONLY + +typedef ::testing::Types TestDtypesAndDevices; + +#else + +struct FloatGPU { + typedef float Dtype; + static const Caffe::Brew device = Caffe::GPU; +}; + +struct DoubleGPU { + typedef double Dtype; + static const Caffe::Brew device = Caffe::GPU; +}; + +typedef ::testing::Types + TestDtypesAndDevices; + +#endif + +} // namespace caffe + +#endif // CAFFE_TEST_TEST_CAFFE_MAIN_HPP_ diff --git a/caffe-crfrnn/include/caffe/test/test_gradient_check_util.hpp b/caffe-crfrnn/include/caffe/test/test_gradient_check_util.hpp new file mode 100644 index 00000000..cc5dcbad --- /dev/null +++ b/caffe-crfrnn/include/caffe/test/test_gradient_check_util.hpp @@ -0,0 +1,260 @@ +#ifndef CAFFE_TEST_GRADIENT_CHECK_UTIL_H_ +#define CAFFE_TEST_GRADIENT_CHECK_UTIL_H_ + +#include +#include + +#include +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/net.hpp" + +namespace caffe { + +// The gradient checker adds a L2 normalization loss function on top of the +// top blobs, and checks the gradient. 
+template +class GradientChecker { + public: + // kink and kink_range specify an ignored nonsmooth region of the form + // kink - kink_range <= |feature value| <= kink + kink_range, + // which accounts for all nonsmoothness in use by caffe + GradientChecker(const Dtype stepsize, const Dtype threshold, + const unsigned int seed = 1701, const Dtype kink = 0., + const Dtype kink_range = -1) + : stepsize_(stepsize), threshold_(threshold), seed_(seed), + kink_(kink), kink_range_(kink_range) {} + // Checks the gradient of a layer, with provided bottom layers and top + // layers. + // Note that after the gradient check, we do not guarantee that the data + // stored in the layer parameters and the blobs are unchanged. + void CheckGradient(Layer* layer, const vector*>& bottom, + const vector*>& top, int check_bottom = -1) { + layer->SetUp(bottom, top); + CheckGradientSingle(layer, bottom, top, check_bottom, -1, -1); + } + void CheckGradientExhaustive(Layer* layer, + const vector*>& bottom, const vector*>& top, + int check_bottom = -1); + + // CheckGradientEltwise can be used to test layers that perform element-wise + // computation only (e.g., neuron layers) -- where (d y_i) / (d x_j) = 0 when + // i != j. + void CheckGradientEltwise(Layer* layer, + const vector*>& bottom, const vector*>& top); + + void CheckGradientSingle(Layer* layer, + const vector*>& bottom, const vector*>& top, + int check_bottom, int top_id, int top_data_id, bool element_wise = false); + + // Checks the gradient of a network. This network should not have any data + // layers or loss layers, since the function does not explicitly deal with + // such cases yet. All input blobs and parameter blobs are going to be + // checked, layer-by-layer to avoid numerical problems to accumulate. 
+ void CheckGradientNet(const Net& net, + const vector*>& input); + + protected: + Dtype GetObjAndGradient(const Layer& layer, + const vector*>& top, int top_id = -1, int top_data_id = -1); + Dtype stepsize_; + Dtype threshold_; + unsigned int seed_; + Dtype kink_; + Dtype kink_range_; +}; + + +template +void GradientChecker::CheckGradientSingle(Layer* layer, + const vector*>& bottom, const vector*>& top, + int check_bottom, int top_id, int top_data_id, bool element_wise) { + if (element_wise) { + CHECK_EQ(0, layer->blobs().size()); + CHECK_LE(0, top_id); + CHECK_LE(0, top_data_id); + const int top_count = top[top_id]->count(); + for (int blob_id = 0; blob_id < bottom.size(); ++blob_id) { + CHECK_EQ(top_count, bottom[blob_id]->count()); + } + } + // First, figure out what blobs we need to check against, and zero init + // parameter blobs. + vector*> blobs_to_check; + vector propagate_down(bottom.size(), check_bottom < 0); + for (int i = 0; i < layer->blobs().size(); ++i) { + Blob* blob = layer->blobs()[i].get(); + caffe_set(blob->count(), static_cast(0), blob->mutable_cpu_diff()); + blobs_to_check.push_back(blob); + } + if (check_bottom < 0) { + for (int i = 0; i < bottom.size(); ++i) { + blobs_to_check.push_back(bottom[i]); + } + } else { + CHECK_LT(check_bottom, bottom.size()); + blobs_to_check.push_back(bottom[check_bottom]); + propagate_down[check_bottom] = true; + } + // Compute the gradient analytically using Backward + Caffe::set_random_seed(seed_); + // Ignore the loss from the layer (it's just the weighted sum of the losses + // from the top blobs, whose gradients we may want to test individually). 
+ layer->Forward(bottom, top); + // Get additional loss from the objective + GetObjAndGradient(*layer, top, top_id, top_data_id); + layer->Backward(top, propagate_down, bottom); + // Store computed gradients for all checked blobs + vector > > + computed_gradient_blobs(blobs_to_check.size()); + for (int blob_id = 0; blob_id < blobs_to_check.size(); ++blob_id) { + Blob* current_blob = blobs_to_check[blob_id]; + computed_gradient_blobs[blob_id].reset(new Blob()); + computed_gradient_blobs[blob_id]->ReshapeLike(*current_blob); + const int count = blobs_to_check[blob_id]->count(); + const Dtype* diff = blobs_to_check[blob_id]->cpu_diff(); + Dtype* computed_gradients = + computed_gradient_blobs[blob_id]->mutable_cpu_data(); + caffe_copy(count, diff, computed_gradients); + } + // Compute derivative of top w.r.t. each bottom and parameter input using + // finite differencing. + // LOG(ERROR) << "Checking " << blobs_to_check.size() << " blobs."; + for (int blob_id = 0; blob_id < blobs_to_check.size(); ++blob_id) { + Blob* current_blob = blobs_to_check[blob_id]; + const Dtype* computed_gradients = + computed_gradient_blobs[blob_id]->cpu_data(); + // LOG(ERROR) << "Blob " << blob_id << ": checking " + // << current_blob->count() << " parameters."; + for (int feat_id = 0; feat_id < current_blob->count(); ++feat_id) { + // For an element-wise layer, we only need to do finite differencing to + // compute the derivative of top[top_id][top_data_id] w.r.t. + // bottom[blob_id][i] only for i == top_data_id. For any other + // i != top_data_id, we know the derivative is 0 by definition, and simply + // check that that's true. + Dtype estimated_gradient = 0; + Dtype positive_objective = 0; + Dtype negative_objective = 0; + if (!element_wise || (feat_id == top_data_id)) { + // Do finite differencing. + // Compute loss with stepsize_ added to input. 
+ current_blob->mutable_cpu_data()[feat_id] += stepsize_; + Caffe::set_random_seed(seed_); + layer->Forward(bottom, top); + positive_objective = + GetObjAndGradient(*layer, top, top_id, top_data_id); + // Compute loss with stepsize_ subtracted from input. + current_blob->mutable_cpu_data()[feat_id] -= stepsize_ * 2; + Caffe::set_random_seed(seed_); + layer->Forward(bottom, top); + negative_objective = + GetObjAndGradient(*layer, top, top_id, top_data_id); + // Recover original input value. + current_blob->mutable_cpu_data()[feat_id] += stepsize_; + estimated_gradient = (positive_objective - negative_objective) / + stepsize_ / 2.; + } + Dtype computed_gradient = computed_gradients[feat_id]; + Dtype feature = current_blob->cpu_data()[feat_id]; + // LOG(ERROR) << "debug: " << current_blob->cpu_data()[feat_id] << " " + // << current_blob->cpu_diff()[feat_id]; + if (kink_ - kink_range_ > fabs(feature) + || fabs(feature) > kink_ + kink_range_) { + // We check relative accuracy, but for too small values, we threshold + // the scale factor by 1. 
+ Dtype scale = std::max( + std::max(fabs(computed_gradient), fabs(estimated_gradient)), 1.); + EXPECT_NEAR(computed_gradient, estimated_gradient, threshold_ * scale) + << "debug: (top_id, top_data_id, blob_id, feat_id)=" + << top_id << "," << top_data_id << "," << blob_id << "," << feat_id + << "; feat = " << feature + << "; objective+ = " << positive_objective + << "; objective- = " << negative_objective; + } + // LOG(ERROR) << "Feature: " << current_blob->cpu_data()[feat_id]; + // LOG(ERROR) << "computed gradient: " << computed_gradient + // << " estimated_gradient: " << estimated_gradient; + } + } +} + +template +void GradientChecker::CheckGradientExhaustive(Layer* layer, + const vector*>& bottom, const vector*>& top, + int check_bottom) { + layer->SetUp(bottom, top); + CHECK_GT(top.size(), 0) << "Exhaustive mode requires at least one top blob."; + // LOG(ERROR) << "Exhaustive Mode."; + for (int i = 0; i < top.size(); ++i) { + // LOG(ERROR) << "Exhaustive: blob " << i << " size " << top[i]->count(); + for (int j = 0; j < top[i]->count(); ++j) { + // LOG(ERROR) << "Exhaustive: blob " << i << " data " << j; + CheckGradientSingle(layer, bottom, top, check_bottom, i, j); + } + } +} + +template +void GradientChecker::CheckGradientEltwise(Layer* layer, + const vector*>& bottom, const vector*>& top) { + layer->SetUp(bottom, top); + CHECK_GT(top.size(), 0) << "Eltwise mode requires at least one top blob."; + const int check_bottom = -1; + const bool element_wise = true; + for (int i = 0; i < top.size(); ++i) { + for (int j = 0; j < top[i]->count(); ++j) { + CheckGradientSingle(layer, bottom, top, check_bottom, i, j, element_wise); + } + } +} + +template +void GradientChecker::CheckGradientNet( + const Net& net, const vector*>& input) { + const vector > >& layers = net.layers(); + vector*> >& bottom_vecs = net.bottom_vecs(); + vector*> >& top_vecs = net.top_vecs(); + for (int i = 0; i < layers.size(); ++i) { + net.Forward(input); + LOG(ERROR) << "Checking gradient for " 
<< layers[i]->layer_param().name(); + CheckGradientExhaustive(*(layers[i].get()), bottom_vecs[i], top_vecs[i]); + } +} + +template +Dtype GradientChecker::GetObjAndGradient(const Layer& layer, + const vector*>& top, int top_id, int top_data_id) { + Dtype loss = 0; + if (top_id < 0) { + // the loss will be half of the sum of squares of all outputs + for (int i = 0; i < top.size(); ++i) { + Blob* top_blob = top[i]; + const Dtype* top_blob_data = top_blob->cpu_data(); + Dtype* top_blob_diff = top_blob->mutable_cpu_diff(); + int count = top_blob->count(); + for (int j = 0; j < count; ++j) { + loss += top_blob_data[j] * top_blob_data[j]; + } + // set the diff: simply the data. + caffe_copy(top_blob->count(), top_blob_data, top_blob_diff); + } + loss /= 2.; + } else { + // the loss will be the top_data_id-th element in the top_id-th blob. + for (int i = 0; i < top.size(); ++i) { + Blob* top_blob = top[i]; + Dtype* top_blob_diff = top_blob->mutable_cpu_diff(); + caffe_set(top_blob->count(), Dtype(0), top_blob_diff); + } + const Dtype loss_weight = 2; + loss = top[top_id]->cpu_data()[top_data_id] * loss_weight; + top[top_id]->mutable_cpu_diff()[top_data_id] = loss_weight; + } + return loss; +} + +} // namespace caffe + +#endif // CAFFE_TEST_GRADIENT_CHECK_UTIL_H_ diff --git a/caffe-crfrnn/include/caffe/util/benchmark.hpp b/caffe-crfrnn/include/caffe/util/benchmark.hpp new file mode 100644 index 00000000..d6358277 --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/benchmark.hpp @@ -0,0 +1,52 @@ +#ifndef CAFFE_UTIL_BENCHMARK_H_ +#define CAFFE_UTIL_BENCHMARK_H_ + +#include + +#include "caffe/util/device_alternate.hpp" + +namespace caffe { + +class Timer { + public: + Timer(); + virtual ~Timer(); + virtual void Start(); + virtual void Stop(); + virtual float MilliSeconds(); + virtual float MicroSeconds(); + virtual float Seconds(); + + inline bool initted() { return initted_; } + inline bool running() { return running_; } + inline bool has_run_at_least_once() { return 
has_run_at_least_once_; } + + protected: + void Init(); + + bool initted_; + bool running_; + bool has_run_at_least_once_; +#ifndef CPU_ONLY + cudaEvent_t start_gpu_; + cudaEvent_t stop_gpu_; +#endif + boost::posix_time::ptime start_cpu_; + boost::posix_time::ptime stop_cpu_; + float elapsed_milliseconds_; + float elapsed_microseconds_; +}; + +class CPUTimer : public Timer { + public: + explicit CPUTimer(); + virtual ~CPUTimer() {} + virtual void Start(); + virtual void Stop(); + virtual float MilliSeconds(); + virtual float MicroSeconds(); +}; + +} // namespace caffe + +#endif // CAFFE_UTIL_BENCHMARK_H_ diff --git a/caffe-crfrnn/include/caffe/util/coords.hpp b/caffe-crfrnn/include/caffe/util/coords.hpp new file mode 100644 index 00000000..5032fc60 --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/coords.hpp @@ -0,0 +1,61 @@ +#ifndef CAFFE_UTIL_COORDS_H_ +#define CAFFE_UTIL_COORDS_H_ + +#include +#include +#include + +namespace caffe { + +template +class DiagonalAffineMap { + public: + explicit DiagonalAffineMap(const vector > coefs) + : coefs_(coefs) { } + static DiagonalAffineMap identity(const int nd) { + return DiagonalAffineMap(vector >(nd, make_pair(1, 0))); + } + + inline DiagonalAffineMap compose(const DiagonalAffineMap& other) const { + CHECK_EQ(coefs_.size(), other.coefs_.size()) + << "Attempt to compose DiagonalAffineMaps of different dimensions"; + DiagonalAffineMap out; + transform(coefs_.begin(), coefs_.end(), other.coefs_.begin(), + std::back_inserter(out.coefs_), &compose_coefs); + return out; + } + inline DiagonalAffineMap inv() const { + DiagonalAffineMap out; + transform(coefs_.begin(), coefs_.end(), std::back_inserter(out.coefs_), + &inv_coefs); + return out; + } + inline vector > coefs() { return coefs_; } + + private: + DiagonalAffineMap() { } + static inline pair compose_coefs(pair left, + pair right) { + return make_pair(left.first * right.first, + left.first * right.second + left.second); + } + static inline pair inv_coefs(pair coefs) { + 
return make_pair(1 / coefs.first, - coefs.second / coefs.first); + } + vector > coefs_; +}; + +template +DiagonalAffineMap FilterMap(const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_h, const int pad_w) { + vector > coefs; + coefs.push_back(make_pair(stride_h, + static_cast(kernel_h - 1) / 2 - pad_h)); + coefs.push_back(make_pair(stride_w, + static_cast(kernel_w - 1) / 2 - pad_w)); + return DiagonalAffineMap(coefs); +} + +} // namespace caffe + +#endif // CAFFE_UTIL_COORDS_H_ diff --git a/caffe-crfrnn/include/caffe/util/cudnn.hpp b/caffe-crfrnn/include/caffe/util/cudnn.hpp new file mode 100644 index 00000000..05b78851 --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/cudnn.hpp @@ -0,0 +1,134 @@ +#ifndef CAFFE_UTIL_CUDNN_H_ +#define CAFFE_UTIL_CUDNN_H_ +#ifdef USE_CUDNN + +#include + +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" + +#define CUDNN_CHECK(condition) \ + do { \ + cudnnStatus_t status = condition; \ + CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << " "\ + << cudnnGetErrorString(status); \ + } while (0) + +inline const char* cudnnGetErrorString(cudnnStatus_t status) { + switch (status) { + case CUDNN_STATUS_SUCCESS: + return "CUDNN_STATUS_SUCCESS"; + case CUDNN_STATUS_NOT_INITIALIZED: + return "CUDNN_STATUS_NOT_INITIALIZED"; + case CUDNN_STATUS_ALLOC_FAILED: + return "CUDNN_STATUS_ALLOC_FAILED"; + case CUDNN_STATUS_BAD_PARAM: + return "CUDNN_STATUS_BAD_PARAM"; + case CUDNN_STATUS_INTERNAL_ERROR: + return "CUDNN_STATUS_INTERNAL_ERROR"; + case CUDNN_STATUS_INVALID_VALUE: + return "CUDNN_STATUS_INVALID_VALUE"; + case CUDNN_STATUS_ARCH_MISMATCH: + return "CUDNN_STATUS_ARCH_MISMATCH"; + case CUDNN_STATUS_MAPPING_ERROR: + return "CUDNN_STATUS_MAPPING_ERROR"; + case CUDNN_STATUS_EXECUTION_FAILED: + return "CUDNN_STATUS_EXECUTION_FAILED"; + case CUDNN_STATUS_NOT_SUPPORTED: + return "CUDNN_STATUS_NOT_SUPPORTED"; + case CUDNN_STATUS_LICENSE_ERROR: + return "CUDNN_STATUS_LICENSE_ERROR"; + } + return "Unknown 
cudnn status"; +} + +namespace caffe { + +namespace cudnn { + +template class dataType; +template<> class dataType { + public: + static const cudnnDataType_t type = CUDNN_DATA_FLOAT; + static float oneval, zeroval; + static const void *one, *zero; +}; +template<> class dataType { + public: + static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; + static double oneval, zeroval; + static const void *one, *zero; +}; + +template +inline void createTensor4dDesc(cudnnTensorDescriptor_t* desc) { + CUDNN_CHECK(cudnnCreateTensorDescriptor(desc)); +} + +template +inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc, + int n, int c, int h, int w, + int stride_n, int stride_c, int stride_h, int stride_w) { + CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType::type, + n, c, h, w, stride_n, stride_c, stride_h, stride_w)); +} + +template +inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc, + int n, int c, int h, int w) { + const int stride_w = 1; + const int stride_h = w * stride_w; + const int stride_c = h * stride_h; + const int stride_n = c * stride_c; + setTensor4dDesc(desc, n, c, h, w, + stride_n, stride_c, stride_h, stride_w); +} + +template +inline void createFilterDesc(cudnnFilterDescriptor_t* desc, + int n, int c, int h, int w) { + CUDNN_CHECK(cudnnCreateFilterDescriptor(desc)); + CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType::type, + n, c, h, w)); +} + +template +inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) { + CUDNN_CHECK(cudnnCreateConvolutionDescriptor(conv)); +} + +template +inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv, + cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter, + int pad_h, int pad_w, int stride_h, int stride_w) { + CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv, + pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION)); +} + +template +inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc, + PoolingParameter_PoolMethod poolmethod, 
cudnnPoolingMode_t* mode, + int h, int w, int pad_h, int pad_w, int stride_h, int stride_w) { + switch (poolmethod) { + case PoolingParameter_PoolMethod_MAX: + *mode = CUDNN_POOLING_MAX; + break; + case PoolingParameter_PoolMethod_AVE: + *mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + break; + default: + LOG(FATAL) << "Unknown pooling method."; + } + CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc)); + CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode, h, w, + pad_h, pad_w, stride_h, stride_w)); +} + +} // namespace cudnn + +} // namespace caffe + +#endif // USE_CUDNN +#endif // CAFFE_UTIL_CUDNN_H_ + + diff --git a/caffe-crfrnn/include/caffe/util/device_alternate.hpp b/caffe-crfrnn/include/caffe/util/device_alternate.hpp new file mode 100644 index 00000000..5a45691b --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/device_alternate.hpp @@ -0,0 +1,102 @@ +#ifndef CAFFE_UTIL_DEVICE_ALTERNATE_H_ +#define CAFFE_UTIL_DEVICE_ALTERNATE_H_ + +#ifdef CPU_ONLY // CPU-only Caffe. + +#include + +// Stub out GPU calls as unavailable. + +#define NO_GPU LOG(FATAL) << "CPU-only Mode: cannot make GPU call." + +#define STUB_GPU(classname) \ +template \ +void classname::Forward_gpu(const vector*>& bottom, \ + const vector*>& top) { NO_GPU; } \ +template \ +void classname::Backward_gpu(const vector*>& top, \ + const vector& propagate_down, \ + const vector*>& bottom) { NO_GPU; } \ + +#define STUB_GPU_FORWARD(classname, funcname) \ +template \ +void classname::funcname##_##gpu(const vector*>& bottom, \ + const vector*>& top) { NO_GPU; } \ + +#define STUB_GPU_BACKWARD(classname, funcname) \ +template \ +void classname::funcname##_##gpu(const vector*>& top, \ + const vector& propagate_down, \ + const vector*>& bottom) { NO_GPU; } \ + +#else // Normal GPU + CPU Caffe. + +#include +#include +#include +#include +#include // cuda driver types +#ifdef USE_CUDNN // cuDNN acceleration library. 
+#include "caffe/util/cudnn.hpp" +#endif + +// +// CUDA macros +// + +// CUDA: various checks for different function calls. +#define CUDA_CHECK(condition) \ + /* Code block avoids redefinition of cudaError_t error */ \ + do { \ + cudaError_t error = condition; \ + CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \ + } while (0) + +#define CUBLAS_CHECK(condition) \ + do { \ + cublasStatus_t status = condition; \ + CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << " " \ + << caffe::cublasGetErrorString(status); \ + } while (0) + +#define CURAND_CHECK(condition) \ + do { \ + curandStatus_t status = condition; \ + CHECK_EQ(status, CURAND_STATUS_SUCCESS) << " " \ + << caffe::curandGetErrorString(status); \ + } while (0) + +// CUDA: grid stride looping +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +// CUDA: check for error after kernel execution and exit loudly if there is one. +#define CUDA_POST_KERNEL_CHECK CUDA_CHECK(cudaPeekAtLastError()) + +namespace caffe { + +// CUDA: library error reporting. +const char* cublasGetErrorString(cublasStatus_t error); +const char* curandGetErrorString(curandStatus_t error); + +// CUDA: thread number configuration. +// Use 1024 threads per block, which requires cuda sm_2x or above, +// or fall back to attempt compatibility (best of luck to you). +#if __CUDA_ARCH__ >= 200 + const int CAFFE_CUDA_NUM_THREADS = 1024; +#else + const int CAFFE_CUDA_NUM_THREADS = 512; +#endif + +// CUDA: number of blocks for threads. 
+inline int CAFFE_GET_BLOCKS(const int N) { + return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS; +} + +} // namespace caffe + +#endif // CPU_ONLY + +#endif // CAFFE_UTIL_DEVICE_ALTERNATE_H_ diff --git a/caffe-crfrnn/include/caffe/util/im2col.hpp b/caffe-crfrnn/include/caffe/util/im2col.hpp new file mode 100644 index 00000000..0051e2fa --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/im2col.hpp @@ -0,0 +1,32 @@ +#ifndef _CAFFE_UTIL_IM2COL_HPP_ +#define _CAFFE_UTIL_IM2COL_HPP_ + +namespace caffe { + +template <typename Dtype> +void im2col_cpu(const Dtype* data_im, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, Dtype* data_col); + +template <typename Dtype> +void col2im_cpu(const Dtype* data_col, const int channels, + const int height, const int width, const int patch_h, const int patch_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, Dtype* data_im); + +template <typename Dtype> +void im2col_gpu(const Dtype* data_im, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, Dtype* data_col); + +template <typename Dtype> +void col2im_gpu(const Dtype* data_col, const int channels, + const int height, const int width, const int patch_h, const int patch_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, Dtype* data_im); + +} // namespace caffe + +#endif // CAFFE_UTIL_IM2COL_HPP_ diff --git a/caffe-crfrnn/include/caffe/util/insert_splits.hpp b/caffe-crfrnn/include/caffe/util/insert_splits.hpp new file mode 100644 index 00000000..446abb81 --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/insert_splits.hpp @@ -0,0 +1,26 @@ +#ifndef _CAFFE_UTIL_INSERT_SPLITS_HPP_ +#define _CAFFE_UTIL_INSERT_SPLITS_HPP_ + +#include + +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +// Copy NetParameters with SplitLayers added to 
replace any shared bottom +// blobs with unique bottom blobs provided by the SplitLayer. +void InsertSplits(const NetParameter& param, NetParameter* param_split); + +void ConfigureSplitLayer(const string& layer_name, const string& blob_name, + const int blob_idx, const int split_count, const float loss_weight, + LayerParameter* split_layer_param); + +string SplitLayerName(const string& layer_name, const string& blob_name, + const int blob_idx); + +string SplitBlobName(const string& layer_name, const string& blob_name, + const int blob_idx, const int split_idx); + +} // namespace caffe + +#endif // CAFFE_UTIL_INSERT_SPLITS_HPP_ diff --git a/caffe-crfrnn/include/caffe/util/io.hpp b/caffe-crfrnn/include/caffe/util/io.hpp new file mode 100644 index 00000000..64df0155 --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/io.hpp @@ -0,0 +1,184 @@ +#ifndef CAFFE_UTIL_IO_H_ +#define CAFFE_UTIL_IO_H_ + +#ifndef OSX +#include +#endif + +#include +#include + +#include "google/protobuf/message.h" +#include "hdf5.h" +#include "hdf5_hl.h" + +#include "caffe/blob.hpp" +#include "caffe/proto/caffe.pb.h" + +#define HDF5_NUM_DIMS 4 + +namespace caffe { + +using ::google::protobuf::Message; + +inline void MakeTempFilename(string* temp_filename) { + temp_filename->clear(); + *temp_filename = "/tmp/caffe_test.XXXXXX"; + char* temp_filename_cstr = new char[temp_filename->size() + 1]; + // NOLINT_NEXT_LINE(runtime/printf) + strcpy(temp_filename_cstr, temp_filename->c_str()); + int fd = mkstemp(temp_filename_cstr); + CHECK_GE(fd, 0) << "Failed to open a temporary file at: " << *temp_filename; + close(fd); + *temp_filename = temp_filename_cstr; + delete[] temp_filename_cstr; +} + +inline void MakeTempDir(string* temp_dirname) { + temp_dirname->clear(); + *temp_dirname = "/tmp/caffe_test.XXXXXX"; + char* temp_dirname_cstr = new char[temp_dirname->size() + 1]; + // NOLINT_NEXT_LINE(runtime/printf) + strcpy(temp_dirname_cstr, temp_dirname->c_str()); + char* mkdtemp_result = 
mkdtemp(temp_dirname_cstr); + CHECK(mkdtemp_result != NULL) + << "Failed to create a temporary directory at: " << *temp_dirname; + *temp_dirname = temp_dirname_cstr; + delete[] temp_dirname_cstr; +} + +bool ReadProtoFromTextFile(const char* filename, Message* proto); + +inline bool ReadProtoFromTextFile(const string& filename, Message* proto) { + return ReadProtoFromTextFile(filename.c_str(), proto); +} + +inline void ReadProtoFromTextFileOrDie(const char* filename, Message* proto) { + CHECK(ReadProtoFromTextFile(filename, proto)); +} + +inline void ReadProtoFromTextFileOrDie(const string& filename, Message* proto) { + ReadProtoFromTextFileOrDie(filename.c_str(), proto); +} + +void WriteProtoToTextFile(const Message& proto, const char* filename); +inline void WriteProtoToTextFile(const Message& proto, const string& filename) { + WriteProtoToTextFile(proto, filename.c_str()); +} + +bool ReadProtoFromBinaryFile(const char* filename, Message* proto); + +inline bool ReadProtoFromBinaryFile(const string& filename, Message* proto) { + return ReadProtoFromBinaryFile(filename.c_str(), proto); +} + +inline void ReadProtoFromBinaryFileOrDie(const char* filename, Message* proto) { + CHECK(ReadProtoFromBinaryFile(filename, proto)); +} + +inline void ReadProtoFromBinaryFileOrDie(const string& filename, + Message* proto) { + ReadProtoFromBinaryFileOrDie(filename.c_str(), proto); +} + + +void WriteProtoToBinaryFile(const Message& proto, const char* filename); +inline void WriteProtoToBinaryFile( + const Message& proto, const string& filename) { + WriteProtoToBinaryFile(proto, filename.c_str()); +} + +bool ReadFileToDatum(const string& filename, const int label, Datum* datum); + +inline bool ReadFileToDatum(const string& filename, Datum* datum) { + return ReadFileToDatum(filename, -1, datum); +} + +bool ReadImageToDatum(const string& filename, const int label, + const int height, const int width, const bool is_color, Datum* datum); + +inline bool ReadImageToDatum(const string& 
filename, const int label, + const int height, const int width, Datum* datum) { + return ReadImageToDatum(filename, label, height, width, true, datum); +} + +inline bool ReadImageToDatum(const string& filename, const int label, + const bool is_color, Datum* datum) { + return ReadImageToDatum(filename, label, 0, 0, is_color, datum); +} + +inline bool ReadImageToDatum(const string& filename, const int label, + Datum* datum) { + return ReadImageToDatum(filename, label, 0, 0, true, datum); +} + +bool DecodeDatum(const int height, const int width, const bool is_color, + Datum* datum); + +inline bool DecodeDatum(const int height, const int width, Datum* datum) { + return DecodeDatum(height, width, true, datum); +} + +inline bool DecodeDatum(const bool is_color, Datum* datum) { + return DecodeDatum(0, 0, is_color, datum); +} + +inline bool DecodeDatum(Datum* datum) { + return DecodeDatum(0, 0, true, datum); +} + +#ifndef OSX +cv::Mat ReadImageToCVMat(const string& filename, + const int height, const int width, const bool is_color); + +inline cv::Mat ReadImageToCVMat(const string& filename, + const int height, const int width) { + return ReadImageToCVMat(filename, height, width, true); +} + +inline cv::Mat ReadImageToCVMat(const string& filename, + const bool is_color) { + return ReadImageToCVMat(filename, 0, 0, is_color); +} + +inline cv::Mat ReadImageToCVMat(const string& filename) { + return ReadImageToCVMat(filename, 0, 0, true); +} + +cv::Mat DecodeDatumToCVMat(const Datum& datum, + const int height, const int width, const bool is_color); + +inline cv::Mat DecodeDatumToCVMat(const Datum& datum, + const int height, const int width) { + return DecodeDatumToCVMat(datum, height, width, true); +} + +inline cv::Mat DecodeDatumToCVMat(const Datum& datum, + const bool is_color) { + return DecodeDatumToCVMat(datum, 0, 0, is_color); +} + +inline cv::Mat DecodeDatumToCVMat(const Datum& datum) { + return DecodeDatumToCVMat(datum, 0, 0, true); +} + +void CVMatToDatum(const 
cv::Mat& cv_img, Datum* datum); +#endif + +template +void hdf5_load_nd_dataset_helper( + hid_t file_id, const char* dataset_name_, int min_dim, int max_dim, + Blob* blob); + +template +void hdf5_load_nd_dataset( + hid_t file_id, const char* dataset_name_, int min_dim, int max_dim, + Blob* blob); + +template +void hdf5_save_nd_dataset( + const hid_t file_id, const string dataset_name, const Blob& blob); + +} // namespace caffe + +#endif // CAFFE_UTIL_IO_H_ diff --git a/caffe-crfrnn/include/caffe/util/math_functions.hpp b/caffe-crfrnn/include/caffe/util/math_functions.hpp new file mode 100644 index 00000000..d3ecf587 --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/math_functions.hpp @@ -0,0 +1,274 @@ +#ifndef CAFFE_UTIL_MATH_FUNCTIONS_H_ +#define CAFFE_UTIL_MATH_FUNCTIONS_H_ + +#include +#include // for std::fabs and std::signbit + +#include "glog/logging.h" + +#include "caffe/common.hpp" +#include "caffe/util/device_alternate.hpp" +#include "caffe/util/mkl_alternate.hpp" + +namespace caffe { + +// Decaf gemm provides a simpler interface to the gemm functions, with the +// limitation that the data has to be contiguous in memory. 
+template +void caffe_cpu_gemm(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, + const Dtype alpha, const Dtype* A, const Dtype* B, const Dtype beta, + Dtype* C); + +template +void caffe_cpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N, + const Dtype alpha, const Dtype* A, const Dtype* x, const Dtype beta, + Dtype* y); + +template +void caffe_axpy(const int N, const Dtype alpha, const Dtype* X, + Dtype* Y); + +template +void caffe_cpu_axpby(const int N, const Dtype alpha, const Dtype* X, + const Dtype beta, Dtype* Y); + +template +void caffe_copy(const int N, const Dtype *X, Dtype *Y); + +template +void caffe_set(const int N, const Dtype alpha, Dtype *X); + +inline void caffe_memset(const size_t N, const int alpha, void* X) { + memset(X, alpha, N); // NOLINT(caffe/alt_fn) +} + +template +void caffe_add_scalar(const int N, const Dtype alpha, Dtype *X); + +template +void caffe_scal(const int N, const Dtype alpha, Dtype *X); + +template +void caffe_sqr(const int N, const Dtype* a, Dtype* y); + +template +void caffe_add(const int N, const Dtype* a, const Dtype* b, Dtype* y); + +template +void caffe_sub(const int N, const Dtype* a, const Dtype* b, Dtype* y); + +template +void caffe_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y); + +template +void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y); + +template +void caffe_powx(const int n, const Dtype* a, const Dtype b, Dtype* y); + +unsigned int caffe_rng_rand(); + +template +Dtype caffe_nextafter(const Dtype b); + +template +void caffe_rng_uniform(const int n, const Dtype a, const Dtype b, Dtype* r); + +template +void caffe_rng_gaussian(const int n, const Dtype mu, const Dtype sigma, + Dtype* r); + +template +void caffe_rng_bernoulli(const int n, const Dtype p, int* r); + +template +void caffe_rng_bernoulli(const int n, const Dtype p, unsigned int* r); + +template +void caffe_exp(const int n, const Dtype* a, Dtype* y); + 
+template +void caffe_abs(const int n, const Dtype* a, Dtype* y); + +template +Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y); + +template +Dtype caffe_cpu_strided_dot(const int n, const Dtype* x, const int incx, + const Dtype* y, const int incy); + +template +int caffe_cpu_hamming_distance(const int n, const Dtype* x, const Dtype* y); + +// Returns the sum of the absolute values of the elements of vector x +template +Dtype caffe_cpu_asum(const int n, const Dtype* x); + +// the branchless, type-safe version from +// http://stackoverflow.com/questions/1903954/is-there-a-standard-sign-function-signum-sgn-in-c-c +template +inline int8_t caffe_sign(Dtype val) { + return (Dtype(0) < val) - (val < Dtype(0)); +} + +// The following two macros are modifications of DEFINE_VSL_UNARY_FUNC +// in include/caffe/util/mkl_alternate.hpp authored by @Rowland Depp. +// Please refer to commit 7e8ef25c7 of the boost-eigen branch. +// Git cherry picking that commit caused a conflict hard to resolve and +// copying that file is inconvenient for code reviewing. +// So they have to be pasted here temporarily. +#define DEFINE_CAFFE_CPU_UNARY_FUNC(name, operation) \ + template \ + void caffe_cpu_##name(const int n, const Dtype* x, Dtype* y) { \ + CHECK_GT(n, 0); CHECK(x); CHECK(y); \ + for (int i = 0; i < n; ++i) { \ + operation; \ + } \ + } + +// output is 1 for the positives, 0 for zero, and -1 for the negatives +DEFINE_CAFFE_CPU_UNARY_FUNC(sign, y[i] = caffe_sign(x[i])); + +// This returns a nonzero value if the input has its sign bit set. +// The name sgnbit is meant to avoid conflicts with std::signbit in the macro. +// The extra parens are needed because CUDA < 6.5 defines signbit as a macro, +// and we don't want that to expand here when CUDA headers are also included. 
+DEFINE_CAFFE_CPU_UNARY_FUNC(sgnbit, \ + y[i] = static_cast((std::signbit)(x[i]))); + +DEFINE_CAFFE_CPU_UNARY_FUNC(fabs, y[i] = std::fabs(x[i])); + +template +void caffe_cpu_scale(const int n, const Dtype alpha, const Dtype *x, Dtype* y); + +#ifndef CPU_ONLY // GPU + +// Decaf gpu gemm provides an interface that is almost the same as the cpu +// gemm function - following the c convention and calling the fortran-order +// gpu code under the hood. +template +void caffe_gpu_gemm(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, + const Dtype alpha, const Dtype* A, const Dtype* B, const Dtype beta, + Dtype* C); + +template +void caffe_gpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N, + const Dtype alpha, const Dtype* A, const Dtype* x, const Dtype beta, + Dtype* y); + +template +void caffe_gpu_axpy(const int N, const Dtype alpha, const Dtype* X, + Dtype* Y); + +template +void caffe_gpu_axpby(const int N, const Dtype alpha, const Dtype* X, + const Dtype beta, Dtype* Y); + +void caffe_gpu_memcpy(const size_t N, const void *X, void *Y); + +template +void caffe_gpu_set(const int N, const Dtype alpha, Dtype *X); + +inline void caffe_gpu_memset(const size_t N, const int alpha, void* X) { +#ifndef CPU_ONLY + CUDA_CHECK(cudaMemset(X, alpha, N)); // NOLINT(caffe/alt_fn) +#else + NO_GPU; +#endif +} + +template +void caffe_gpu_add_scalar(const int N, const Dtype alpha, Dtype *X); + +template +void caffe_gpu_scal(const int N, const Dtype alpha, Dtype *X); + +template +void caffe_gpu_add(const int N, const Dtype* a, const Dtype* b, Dtype* y); + +template +void caffe_gpu_sub(const int N, const Dtype* a, const Dtype* b, Dtype* y); + +template +void caffe_gpu_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y); + +template +void caffe_gpu_div(const int N, const Dtype* a, const Dtype* b, Dtype* y); + +template +void caffe_gpu_abs(const int n, const Dtype* a, Dtype* y); + +template +void caffe_gpu_exp(const int n, 
const Dtype* a, Dtype* y); + +template +void caffe_gpu_powx(const int n, const Dtype* a, const Dtype b, Dtype* y); + +// caffe_gpu_rng_uniform with two arguments generates integers in the range +// [0, UINT_MAX]. +void caffe_gpu_rng_uniform(const int n, unsigned int* r); + +// caffe_gpu_rng_uniform with four arguments generates floats in the range +// (a, b] (strictly greater than a, less than or equal to b) due to the +// specification of curandGenerateUniform. With a = 0, b = 1, just calls +// curandGenerateUniform; with other limits will shift and scale the outputs +// appropriately after calling curandGenerateUniform. +template +void caffe_gpu_rng_uniform(const int n, const Dtype a, const Dtype b, Dtype* r); + +template +void caffe_gpu_rng_gaussian(const int n, const Dtype mu, const Dtype sigma, + Dtype* r); + +template +void caffe_gpu_rng_bernoulli(const int n, const Dtype p, int* r); + +template +void caffe_gpu_dot(const int n, const Dtype* x, const Dtype* y, Dtype* out); + +template +uint32_t caffe_gpu_hamming_distance(const int n, const Dtype* x, + const Dtype* y); + +template +void caffe_gpu_asum(const int n, const Dtype* x, Dtype* y); + +template +void caffe_gpu_sign(const int n, const Dtype* x, Dtype* y); + +template +void caffe_gpu_sgnbit(const int n, const Dtype* x, Dtype* y); + +template +void caffe_gpu_fabs(const int n, const Dtype* x, Dtype* y); + +template +void caffe_gpu_scale(const int n, const Dtype alpha, const Dtype *x, Dtype* y); + +#define DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(name, operation) \ +template \ +__global__ void name##_kernel(const int n, const Dtype* x, Dtype* y) { \ + CUDA_KERNEL_LOOP(index, n) { \ + operation; \ + } \ +} \ +template <> \ +void caffe_gpu_##name(const int n, const float* x, float* y) { \ + /* NOLINT_NEXT_LINE(whitespace/operators) */ \ + name##_kernel<<>>( \ + n, x, y); \ +} \ +template <> \ +void caffe_gpu_##name(const int n, const double* x, double* y) { \ + /* NOLINT_NEXT_LINE(whitespace/operators) */ \ + 
name##_kernel<<>>( \ + n, x, y); \ +} + +#endif // !CPU_ONLY + +} // namespace caffe + +#endif // CAFFE_UTIL_MATH_FUNCTIONS_H_ diff --git a/caffe-crfrnn/include/caffe/util/mkl_alternate.hpp b/caffe-crfrnn/include/caffe/util/mkl_alternate.hpp new file mode 100644 index 00000000..32fdbf79 --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/mkl_alternate.hpp @@ -0,0 +1,96 @@ +#ifndef CAFFE_UTIL_MKL_ALTERNATE_H_ +#define CAFFE_UTIL_MKL_ALTERNATE_H_ + +#ifdef USE_MKL + +#include + +#else // If use MKL, simply include the MKL header + +extern "C" { +#include +} +#include + +// Functions that caffe uses but are not present if MKL is not linked. + +// A simple way to define the vsl unary functions. The operation should +// be in the form e.g. y[i] = sqrt(a[i]) +#define DEFINE_VSL_UNARY_FUNC(name, operation) \ + template \ + void v##name(const int n, const Dtype* a, Dtype* y) { \ + CHECK_GT(n, 0); CHECK(a); CHECK(y); \ + for (int i = 0; i < n; ++i) { operation; } \ + } \ + inline void vs##name( \ + const int n, const float* a, float* y) { \ + v##name(n, a, y); \ + } \ + inline void vd##name( \ + const int n, const double* a, double* y) { \ + v##name(n, a, y); \ + } + +DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]); +DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i])); +DEFINE_VSL_UNARY_FUNC(Abs, y[i] = fabs(a[i])); + +// A simple way to define the vsl unary functions with singular parameter b. +// The operation should be in the form e.g. 
y[i] = pow(a[i], b) +#define DEFINE_VSL_UNARY_FUNC_WITH_PARAM(name, operation) \ + template \ + void v##name(const int n, const Dtype* a, const Dtype b, Dtype* y) { \ + CHECK_GT(n, 0); CHECK(a); CHECK(y); \ + for (int i = 0; i < n; ++i) { operation; } \ + } \ + inline void vs##name( \ + const int n, const float* a, const float b, float* y) { \ + v##name(n, a, b, y); \ + } \ + inline void vd##name( \ + const int n, const double* a, const double b, double* y) { \ + v##name(n, a, b, y); \ + } + +DEFINE_VSL_UNARY_FUNC_WITH_PARAM(Powx, y[i] = pow(a[i], b)); + +// A simple way to define the vsl binary functions. The operation should +// be in the form e.g. y[i] = a[i] + b[i] +#define DEFINE_VSL_BINARY_FUNC(name, operation) \ + template \ + void v##name(const int n, const Dtype* a, const Dtype* b, Dtype* y) { \ + CHECK_GT(n, 0); CHECK(a); CHECK(b); CHECK(y); \ + for (int i = 0; i < n; ++i) { operation; } \ + } \ + inline void vs##name( \ + const int n, const float* a, const float* b, float* y) { \ + v##name(n, a, b, y); \ + } \ + inline void vd##name( \ + const int n, const double* a, const double* b, double* y) { \ + v##name(n, a, b, y); \ + } + +DEFINE_VSL_BINARY_FUNC(Add, y[i] = a[i] + b[i]); +DEFINE_VSL_BINARY_FUNC(Sub, y[i] = a[i] - b[i]); +DEFINE_VSL_BINARY_FUNC(Mul, y[i] = a[i] * b[i]); +DEFINE_VSL_BINARY_FUNC(Div, y[i] = a[i] / b[i]); + +// In addition, MKL comes with an additional function axpby that is not present +// in standard blas. We will simply use a two-step (inefficient, of course) way +// to mimic that. 
+inline void cblas_saxpby(const int N, const float alpha, const float* X, + const int incX, const float beta, float* Y, + const int incY) { + cblas_sscal(N, beta, Y, incY); + cblas_saxpy(N, alpha, X, incX, Y, incY); +} +inline void cblas_daxpby(const int N, const double alpha, const double* X, + const int incX, const double beta, double* Y, + const int incY) { + cblas_dscal(N, beta, Y, incY); + cblas_daxpy(N, alpha, X, incX, Y, incY); +} + +#endif // USE_MKL +#endif // CAFFE_UTIL_MKL_ALTERNATE_H_ diff --git a/caffe-crfrnn/include/caffe/util/modified_permutohedral.hpp b/caffe-crfrnn/include/caffe/util/modified_permutohedral.hpp new file mode 100755 index 00000000..1ed48173 --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/modified_permutohedral.hpp @@ -0,0 +1,40 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +/************************************************/ +/*** ModifiedPermutohedral Lattice ***/ +/************************************************/ +namespace caffe { + +class ModifiedPermutohedral +{ +protected: + struct Neighbors{ + int n1, n2; + Neighbors( int n1=0, int n2=0 ):n1(n1),n2(n2){ + } + }; + std::vector offset_, rank_; + std::vector barycentric_; + std::vector blur_neighbors_; + // Number of elements, size of sparse discretized space, dimension of features + int N_, M_, d_; + + void sseCompute(float* out, const float* in, int value_size, bool reverse = false, bool add = false) const; + void sseCompute(double* out, const double* in, int value_size, bool reverse = false, bool add = false) const; + + void seqCompute(float* out, const float* in, int value_size, bool reverse = false, bool add = false) const; + void seqCompute(double* out, const double* in, int value_size, bool reverse = false, bool add = false) const; + +public: + ModifiedPermutohedral(); + void init (const float* features, int num_dimensions, int num_points); + void compute(float* out, const float* in, int value_size, bool reverse = false, bool add = 
false) const; + void compute(double* out, const double* in, int value_size, bool reverse = false, bool add = false) const; +}; +} diff --git a/caffe-crfrnn/include/caffe/util/rng.hpp b/caffe-crfrnn/include/caffe/util/rng.hpp new file mode 100644 index 00000000..8f1cf0d1 --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/rng.hpp @@ -0,0 +1,43 @@ +#ifndef CAFFE_RNG_CPP_HPP_ +#define CAFFE_RNG_CPP_HPP_ + +#include +#include + +#include "boost/random/mersenne_twister.hpp" +#include "boost/random/uniform_int.hpp" + +#include "caffe/common.hpp" + +namespace caffe { + +typedef boost::mt19937 rng_t; + +inline rng_t* caffe_rng() { + return static_cast(Caffe::rng_stream().generator()); +} + +// Fisher–Yates algorithm +template +inline void shuffle(RandomAccessIterator begin, RandomAccessIterator end, + RandomGenerator* gen) { + typedef typename std::iterator_traits::difference_type + difference_type; + typedef typename boost::uniform_int dist_type; + + difference_type length = std::distance(begin, end); + if (length <= 0) return; + + for (difference_type i = length - 1; i > 0; --i) { + dist_type dist(0, i); + std::iter_swap(begin + i, begin + dist(*gen)); + } +} + +template +inline void shuffle(RandomAccessIterator begin, RandomAccessIterator end) { + shuffle(begin, end, caffe_rng()); +} +} // namespace caffe + +#endif // CAFFE_RNG_CPP_HPP_ diff --git a/caffe-crfrnn/include/caffe/util/thread.hpp b/caffe-crfrnn/include/caffe/util/thread.hpp new file mode 100644 index 00000000..7251402c --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/thread.hpp @@ -0,0 +1,25 @@ +#ifndef CAFFE_THREAD_CPP_HPP_ +#define CAFFE_THREAD_CPP_HPP_ + +#include +#include "caffe/common.hpp" +#include "caffe/internal_thread.hpp" + +namespace caffe { + +template +Thread::Thread(Callable func, A1 a1) { + this->thread_ = new boost::thread(func, a1); +} + +void Thread::join() { + static_cast(this->thread_)->join(); +} + +bool Thread::joinable() { + return static_cast(this->thread_)->joinable(); +} + +} // 
namespace caffe + +#endif diff --git a/caffe-crfrnn/include/caffe/util/tvg_util.hpp b/caffe-crfrnn/include/caffe/util/tvg_util.hpp new file mode 100644 index 00000000..c8bcfcb3 --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/tvg_util.hpp @@ -0,0 +1,108 @@ +#include +#include +#include +#include +#include "caffe/blob.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void PrintBlob(const Blob* blob, bool print_diff = false, const char* info = 0) { + + const Dtype* data = print_diff ? blob->cpu_diff() : blob->cpu_data(); + + if (info != 0) { + printf("%s: \n", info); + } + + for (int n = 0; n < blob->num(); n++) { + for (int c = 0; c < blob->channels(); c++) { + for (int h = 0; h < blob->height(); h++) { + for (int w = 0; w < blob->width(); w++) { + int offset = ((n * blob->channels() + c) * blob->height() + h) * blob->width() + w; + printf("%11.6f ", *(data + offset)); + } + printf("\n"); + } + printf("\n"); + } + // printf("=================\n"); + } + + printf("-- End of Blob --\n\n"); +} + +template void PrintBlob(const Blob* blob, bool print_diff = false, const char* info = 0); + +template void PrintBlob(const Blob* blob, bool print_diff = false, const char* info = 0); + + +template +void FillWithMax(Blob* blob, float max_value = 1) { + + srand(2000); + + for (int i = 0; i < blob->count(); ++i) { + blob->mutable_cpu_data()[i] = ((double) rand() / RAND_MAX) * max_value; + } +} +template void FillWithMax(Blob* const blob, float max_value = 1); +template void FillWithMax(Blob* const blob, float max_value = 1); + + +template +void FillAsRGB(Blob* blob) { + + srand(2000); + + for (int i = 0; i < blob->count(); ++i) { + blob->mutable_cpu_data()[i] = rand() % 256; + } +} +template void FillAsRGB(Blob* const blob); +template void FillAsRGB(Blob* const blob); + +template +void FillAsProb(Blob* blob) { + + srand(1000);//time(NULL)); + + for (int i = 0; i < blob->count(); ++i) { + double num = (double) rand() / (double) RAND_MAX; + 
blob->mutable_cpu_data()[i] = static_cast((num != 0) ? num : 0.0002); + } + + for (int n = 0; n < blob->num(); ++n) { + for (int h = 0; h < blob->height(); ++h) { + for (int w = 0; w < blob->width(); ++w) { + + Dtype total = 0; + + for (int c = 0; c < blob->channels(); ++c) { + total += blob->data_at(n, c, h, w); + } + + for (int c = 0; c < blob->channels(); ++c) { + blob->mutable_cpu_data()[blob->offset(n, c, h, w)] = blob->data_at(n, c, h, w) / total; + } + } + } + } +} +template void FillAsProb(Blob* const blob); +template void FillAsProb(Blob* const blob); + + +template +void FillAsLogProb(Blob* blob) { + FillAsProb(blob); + + for (int i = 0; i < blob->count(); ++i) { + blob->mutable_cpu_data()[i] = log(blob->cpu_data()[i]); + } +} +template void FillAsLogProb(Blob* const blob); +template void FillAsLogProb(Blob* const blob); + +} diff --git a/caffe-crfrnn/include/caffe/util/upgrade_proto.hpp b/caffe-crfrnn/include/caffe/util/upgrade_proto.hpp new file mode 100644 index 00000000..45483685 --- /dev/null +++ b/caffe-crfrnn/include/caffe/util/upgrade_proto.hpp @@ -0,0 +1,55 @@ +#ifndef CAFFE_UTIL_UPGRADE_PROTO_H_ +#define CAFFE_UTIL_UPGRADE_PROTO_H_ + +#include + +#include "caffe/proto/caffe.pb.h" +#include "caffe/proto/caffe_pretty_print.pb.h" + +namespace caffe { + +// Return true iff any layer contains parameters specified using +// deprecated V0LayerParameter. +bool NetNeedsUpgrade(const NetParameter& net_param); + +// Perform all necessary transformations to upgrade a V0NetParameter into a +// NetParameter (including upgrading padding layers and LayerParameters). +bool UpgradeV0Net(const NetParameter& v0_net_param, NetParameter* net_param); + +// Upgrade NetParameter with padding layers to pad-aware conv layers. +// For any padding layer, remove it and put its pad parameter in any layers +// taking its top blob as input. +// It is an error if any of those layers is not a conv layer. 
+void UpgradeV0PaddingLayers(const NetParameter& param, + NetParameter* param_upgraded_pad); + +// Upgrade a single V0LayerConnection to the new LayerParameter format. +bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection, + LayerParameter* layer_param); + +LayerParameter_LayerType UpgradeV0LayerType(const string& type); + +// Return true iff any layer contains deprecated data transformation parameters. +bool NetNeedsDataUpgrade(const NetParameter& net_param); + +// Perform all necessary transformations to upgrade old transformation fields +// into a TransformationParameter. +void UpgradeNetDataTransformation(NetParameter* net_param); + +// Convert a NetParameter to NetParameterPrettyPrint used for dumping to +// proto text files. +void NetParameterToPrettyPrint(const NetParameter& param, + NetParameterPrettyPrint* pretty_param); + +// Check for deprecations and upgrade the NetParameter as needed. +void UpgradeNetAsNeeded(NetParameter* param); + +// Read parameters from a file into a NetParameter proto message. 
+void ReadNetParamsFromTextFileOrDie(const string& param_file, + NetParameter* param); +void ReadNetParamsFromBinaryFileOrDie(const string& param_file, + NetParameter* param); + +} // namespace caffe + +#endif // CAFFE_UTIL_UPGRADE_PROTO_H_ diff --git a/caffe-crfrnn/include/caffe/vision_layers.hpp b/caffe-crfrnn/include/caffe/vision_layers.hpp new file mode 100755 index 00000000..4c402a9f --- /dev/null +++ b/caffe-crfrnn/include/caffe/vision_layers.hpp @@ -0,0 +1,658 @@ +#ifndef CAFFE_VISION_LAYERS_HPP_ +#define CAFFE_VISION_LAYERS_HPP_ + +#include +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/common_layers.hpp" +#include "caffe/data_layers.hpp" +#include "caffe/layer.hpp" +#include "caffe/loss_layers.hpp" +#include "caffe/neuron_layers.hpp" +#include "caffe/proto/caffe.pb.h" + +#include "caffe/util/modified_permutohedral.hpp" +#include + +namespace caffe { + +template +class BaseConvolutionLayer : public Layer { + public: + explicit BaseConvolutionLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline int MinBottomBlobs() const { return 1; } + virtual inline int MinTopBlobs() const { return 1; } + virtual inline bool EqualNumBottomTopBlobs() const { return true; } + + protected: + // Helper functions that abstract away the column buffer and gemm arguments. + // The last argument in forward_cpu_gemm is so that we can skip the im2col if + // we just called weight_cpu_gemm with the same input. 
+ void forward_cpu_gemm(const Dtype* input, const Dtype* weights, + Dtype* output, bool skip_im2col = false); + void forward_cpu_bias(Dtype* output, const Dtype* bias); + void backward_cpu_gemm(const Dtype* input, const Dtype* weights, + Dtype* output); + void weight_cpu_gemm(const Dtype* input, const Dtype* output, Dtype* + weights); + void backward_cpu_bias(Dtype* bias, const Dtype* input); + +#ifndef CPU_ONLY + void forward_gpu_gemm(const Dtype* col_input, const Dtype* weights, + Dtype* output, bool skip_im2col = false); + void forward_gpu_bias(Dtype* output, const Dtype* bias); + void backward_gpu_gemm(const Dtype* input, const Dtype* weights, + Dtype* col_output); + void weight_gpu_gemm(const Dtype* col_input, const Dtype* output, Dtype* + weights); + void backward_gpu_bias(Dtype* bias, const Dtype* input); +#endif + + // reverse_dimensions should return true iff we are implementing deconv, so + // that conv helpers know which dimensions are which. + virtual bool reverse_dimensions() = 0; + // Compute height_out_ and width_out_ from other parameters. 
+ virtual void compute_output_shape() = 0; + + int kernel_h_, kernel_w_; + int stride_h_, stride_w_; + int num_; + int channels_; + int pad_h_, pad_w_; + int height_, width_; + int group_; + int num_output_; + int height_out_, width_out_; + bool bias_term_; + bool is_1x1_; + + private: + // wrap im2col/col2im so we don't have to remember the (long) argument lists + inline void conv_im2col_cpu(const Dtype* data, Dtype* col_buff) { + im2col_cpu(data, conv_in_channels_, conv_in_height_, conv_in_width_, + kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff); + } + inline void conv_col2im_cpu(const Dtype* col_buff, Dtype* data) { + col2im_cpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_, + kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data); + } +#ifndef CPU_ONLY + inline void conv_im2col_gpu(const Dtype* data, Dtype* col_buff) { + im2col_gpu(data, conv_in_channels_, conv_in_height_, conv_in_width_, + kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff); + } + inline void conv_col2im_gpu(const Dtype* col_buff, Dtype* data) { + col2im_gpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_, + kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data); + } +#endif + + int conv_out_channels_; + int conv_in_channels_; + int conv_out_spatial_dim_; + int conv_in_height_; + int conv_in_width_; + int kernel_dim_; + int weight_offset_; + int col_offset_; + int output_offset_; + + //memory reduced change before + //Blob col_buffer_; + //memory reduced change before end + static Blob col_buffer_; + Blob bias_multiplier_; +}; + +/** + * @brief Convolves the input image with a bank of learned filters, + * and (optionally) adds biases. + * + * Caffe convolves by reduction to matrix multiplication. This achieves + * high-throughput and generality of input and filter dimensions but comes at + * the cost of memory for matrices. This makes use of efficiency in BLAS. 
+ * + * The input is "im2col" transformed to a channel K' x H x W data matrix + * for multiplication with the N x K' x H x W filter matrix to yield a + * N' x H x W output matrix that is then "col2im" restored. K' is the + * input channel * kernel height * kernel width dimension of the unrolled + * inputs so that the im2col matrix has a column for each input region to + * be filtered. col2im restores the output spatial structure by rolling up + * the output channel N' columns of the output matrix. + */ +template +class ConvolutionLayer : public BaseConvolutionLayer { + public: + /** + * @param param provides ConvolutionParameter convolution_param, + * with ConvolutionLayer options: + * - num_output. The number of filters. + * - kernel_size / kernel_h / kernel_w. The filter dimensions, given by + * kernel_size for square filters or kernel_h and kernel_w for rectangular + * filters. + * - stride / stride_h / stride_w (\b optional, default 1). The filter + * stride, given by stride for equal dimensions or stride_h and stride_w + * for different strides. By default the convolution is dense with stride 1. + * - pad / pad_h / pad_w (\b optional, default 0). The zero-padding for + * convolution, given by pad for equal dimensions or pad_h and pad_w for + * different padding. Input padding is computed implicitly instead of + * actually padding. + * - group (\b optional, default 1). The number of filter groups. Group + * convolution is a method for reducing parameterization by selectively + * connecting input and output channels. The input and output channel dimensions must be divisible + * by the number of groups. For group @f$ \geq 1 @f$, the + * convolutional filters' input and output channels are separated s.t. each + * group takes 1 / group of the input channels and makes 1 / group of the + * output channels. 
Concretely 4 input channels, 8 output channels, and + * 2 groups separate input channels 1-2 and output channels 1-4 into the + * first group and input channels 3-4 and output channels 5-8 into the second + * group. + * - bias_term (\b optional, default true). Whether to have a bias. + * - engine: convolution has CAFFE (matrix multiplication) and CUDNN (library + * kernels + stream parallelism) engines. + */ + explicit ConvolutionLayer(const LayerParameter& param) + : BaseConvolutionLayer(param) {} + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_CONVOLUTION; + } + virtual inline DiagonalAffineMap coord_map() { + return FilterMap(this->kernel_h_, this->kernel_w_, this->stride_h_, + this->stride_w_, this->pad_h_, this->pad_w_).inv(); + } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual inline bool reverse_dimensions() { return false; } + virtual void compute_output_shape(); +}; + +template +class DeconvolutionLayer : public BaseConvolutionLayer { + public: + explicit DeconvolutionLayer(const LayerParameter& param) + : BaseConvolutionLayer(param) {} + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_DECONVOLUTION; + } + virtual inline DiagonalAffineMap coord_map() { + return FilterMap(this->kernel_h_, this->kernel_w_, this->stride_h_, + this->stride_w_, this->pad_h_, this->pad_w_); + } + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + 
virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual inline bool reverse_dimensions() { return true; } + virtual void compute_output_shape(); +}; + +#ifdef USE_CUDNN +/* + * @brief cuDNN implementation of ConvolutionLayer. + * Fallback to ConvolutionLayer for CPU mode. + * + * cuDNN accelerates convolution through forward kernels for filtering and bias + * plus backward kernels for the gradient w.r.t. the filters, biases, and + * inputs. Caffe + cuDNN further speeds up the computation through forward + * parallelism across groups and backward parallelism across gradients. + * + * The CUDNN engine does not have memory overhead for matrix buffers. For many + * input and filter regimes the CUDNN engine is faster than the CAFFE engine, + * but for fully-convolutional models and large inputs the CAFFE engine can be + * faster as long as it fits in memory. +*/ +template +class CuDNNConvolutionLayer : public ConvolutionLayer { + public: + explicit CuDNNConvolutionLayer(const LayerParameter& param) + : ConvolutionLayer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + virtual ~CuDNNConvolutionLayer(); + + protected: + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + cudnnHandle_t* handle_; + cudaStream_t* stream_; + vector bottom_descs_, top_descs_; + cudnnTensorDescriptor_t bias_desc_; + cudnnFilterDescriptor_t filter_desc_; + vector conv_descs_; + int bottom_offset_, top_offset_, weight_offset_, bias_offset_; + size_t workspaceSizeInBytes; + void *workspace; +}; +#endif + +/*! 
+ * \brief A helper class for {@link MultiStageMeanfieldLayer} class, which is the Caffe layer that implements the + * CRF-RNN described in the paper: Conditional Random Fields as Recurrent Neural Networks. IEEE ICCV 2015. + * + * This class itself is not a proper Caffe layer although it behaves like one to some degree. + * + * \authors Sadeep Jayasumana, Bernardino Romera-Paredes, Shuai Zheng, Zhizhong Su. + * \version 1.0 + * \date 2015 + * \copyright Torr Vision Group, University of Oxford. + * \details If you use this code, please consider citing the paper: + * Shuai Zheng, Sadeep Jayasumana, Bernardino Romera-Paredes, Vibhav Vineet, Zhizhong Su, Dalong Du, + * Chang Huang, Philip H. S. Torr. Conditional Random Fields as Recurrent Neural Networks. IEEE ICCV 2015. + * + * For more information about CRF-RNN, please visit the project website http://crfasrnn.torr.vision. + */ +template +class MeanfieldIteration { + + public: + /** + * Must be invoked only once after the construction of the layer. + */ + void OneTimeSetUp( + Blob* const unary_terms, + Blob* const softmax_input, + Blob* const output_blob, + const shared_ptr spatial_lattice, + const Blob* const spatial_norm); + + /** + * Must be invoked before invoking {@link Forward_cpu()} + */ + virtual void PrePass( + const vector > >& parameters_to_copy_from, + const vector >* const bilateral_lattices, + const Blob* const bilateral_norms); + + /** + * Forward pass - to be called during inference. + */ + virtual void Forward_cpu(); + + /** + * Backward pass - to be called during training. + */ + virtual void Backward_cpu(); + + // A quick hack. This should be properly encapsulated. 
+ vector > >& blobs() { + return blobs_; + } + + protected: + vector > > blobs_; + + int count_; + int num_; + int channels_; + int height_; + int width_; + int num_pixels_; + + Blob spatial_out_blob_; + Blob bilateral_out_blob_; + Blob pairwise_; + Blob softmax_input_; + Blob prob_; + Blob message_passing_; + + vector*> softmax_top_vec_; + vector*> softmax_bottom_vec_; + vector*> sum_top_vec_; + vector*> sum_bottom_vec_; + + shared_ptr > softmax_layer_; + shared_ptr > sum_layer_; + + shared_ptr spatial_lattice_; + const vector >* bilateral_lattices_; + + const Blob* spatial_norm_; + const Blob* bilateral_norms_; + +}; + +/*! + * \brief The Caffe layer that implements the CRF-RNN described in the paper: + * Conditional Random Fields as Recurrent Neural Networks. IEEE ICCV 2015. + * + * \authors Sadeep Jayasumana, Bernardino Romera-Paredes, Shuai Zheng, Zhizhong Su. + * \version 1.0 + * \date 2015 + * \copyright Torr Vision Group, University of Oxford. + * \details If you use this code, please consider citing the paper: + * Shuai Zheng, Sadeep Jayasumana, Bernardino Romera-Paredes, Vibhav Vineet, Zhizhong Su, Dalong Du, + * Chang Huang, Philip H. S. Torr. Conditional Random Fields as Recurrent Neural Networks. IEEE ICCV 2015. + * + * For more information about CRF-RNN, please visit the project website http://crfasrnn.torr.vision. 
+ */ +template +class MultiStageMeanfieldLayer : public Layer { + + public: + explicit MultiStageMeanfieldLayer(const LayerParameter& param) : Layer(param) {} + + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_MULTI_STAGE_MEANFIELD; + } + virtual inline int ExactNumBottomBlobs() const { return 3; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + virtual void compute_spatial_kernel(float* const output_kernel); + virtual void compute_bilateral_kernel(const Blob* const rgb_blob, const int n, float* const output_kernel); + + int count_; + int num_; + int channels_; + int height_; + int width_; + int num_pixels_; + + Dtype theta_alpha_; + Dtype theta_beta_; + Dtype theta_gamma_; + int num_iterations_; + + boost::shared_array norm_feed_; + Blob spatial_norm_; + Blob bilateral_norms_; + + vector*> split_layer_bottom_vec_; + vector*> split_layer_top_vec_; + vector > > split_layer_out_blobs_; + vector > > iteration_output_blobs_; + vector > > meanfield_iterations_; + + shared_ptr > split_layer_; + + shared_ptr spatial_lattice_; + boost::shared_array bilateral_kernel_buffer_; + vector > bilateral_lattices_; +}; + + +/** + * @brief A helper for image operations that rearranges image regions into + * column vectors. Used by ConvolutionLayer to perform convolution + * by matrix multiplication. + * + * TODO(dox): thorough documentation for Forward, Backward, and proto params. 
+ */ +template +class Im2colLayer : public Layer { + public: + explicit Im2colLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_IM2COL; + } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + int kernel_h_, kernel_w_; + int stride_h_, stride_w_; + int channels_; + int height_, width_; + int pad_h_, pad_w_; +}; + +// Forward declare PoolingLayer and SplitLayer for use in LRNLayer. +template class PoolingLayer; +template class SplitLayer; + +/** + * @brief Normalize the input in a local region across or within feature maps. + * + * TODO(dox): thorough documentation for Forward, Backward, and proto params. 
+ */ +template +class LRNLayer : public Layer { + public: + explicit LRNLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_LRN; + } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap coord_map() { + return DiagonalAffineMap::identity(2); + } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + virtual void CrossChannelForward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void CrossChannelForward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void WithinChannelForward(const vector*>& bottom, + const vector*>& top); + virtual void CrossChannelBackward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void CrossChannelBackward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void WithinChannelBackward(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + int size_; + int pre_pad_; + Dtype alpha_; + Dtype beta_; + Dtype k_; + int num_; + int channels_; + int height_; + int width_; + + // Fields used for normalization ACROSS_CHANNELS + // scale_ stores the intermediate summing results + Blob scale_; + + // Fields used for normalization WITHIN_CHANNEL + shared_ptr > split_layer_; + vector*> split_top_vec_; + shared_ptr > square_layer_; + Blob square_input_; + Blob 
square_output_; + vector*> square_bottom_vec_; + vector*> square_top_vec_; + shared_ptr > pool_layer_; + Blob pool_output_; + vector*> pool_top_vec_; + shared_ptr > power_layer_; + Blob power_output_; + vector*> power_top_vec_; + shared_ptr > product_layer_; + Blob product_input_; + vector*> product_bottom_vec_; +}; + + +/** + * @brief Pools the input image by taking the max, average, etc. within regions. + * + * TODO(dox): thorough documentation for Forward, Backward, and proto params. + */ +template +class PoolingLayer : public Layer { + public: + explicit PoolingLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_POOLING; + } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int MinTopBlobs() const { return 1; } + // MAX POOL layers can output an extra top blob for the mask; + // others can only output the pooled inputs. + virtual inline int MaxTopBlobs() const { + return (this->layer_param_.pooling_param().pool() == + PoolingParameter_PoolMethod_MAX) ? 
2 : 1; + } + virtual inline DiagonalAffineMap coord_map() { + return FilterMap(kernel_h_, kernel_w_, stride_h_, stride_w_, + pad_h_, pad_w_).inv(); + } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + int kernel_h_, kernel_w_; + int stride_h_, stride_w_; + int pad_h_, pad_w_; + int channels_; + int height_, width_; + int pooled_height_, pooled_width_; + bool global_pooling_; + Blob rand_idx_; + Blob max_idx_; +}; + +#ifdef USE_CUDNN +/* + * @brief cuDNN implementation of PoolingLayer. + * Fallback to PoolingLayer for CPU mode. +*/ +template +class CuDNNPoolingLayer : public PoolingLayer { + public: + explicit CuDNNPoolingLayer(const LayerParameter& param) + : PoolingLayer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + virtual ~CuDNNPoolingLayer(); + // Currently, cuDNN does not support the extra top blob. 
+ virtual inline int MinTopBlobs() const { return -1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); + + cudnnHandle_t handle_; + cudnnTensorDescriptor_t bottom_desc_, top_desc_; + cudnnPoolingDescriptor_t pooling_desc_; + cudnnPoolingMode_t mode_; +}; +#endif + +template <typename Dtype> +class CropLayer : public Layer<Dtype> { + public: + explicit CropLayer(const LayerParameter& param) + : Layer<Dtype>(param) {} + virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + + virtual inline LayerParameter_LayerType type() const { + return LayerParameter_LayerType_CROP; + } + virtual inline int ExactNumBottomBlobs() const { return 2; } + virtual inline int ExactNumTopBlobs() const { return 1; } + virtual inline DiagonalAffineMap<Dtype> coord_map() { + vector<pair<Dtype, Dtype> > coefs; + coefs.push_back(make_pair(1, - crop_h_)); + coefs.push_back(make_pair(1, - crop_w_)); + return DiagonalAffineMap<Dtype>(coefs); + } + + protected: + virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); + virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); + + int crop_h_, crop_w_; +}; + +} // namespace caffe + +#endif // CAFFE_VISION_LAYERS_HPP_ diff --git a/caffe-crfrnn/matlab/CMakeLists.txt b/caffe-crfrnn/matlab/CMakeLists.txt new file mode 100644 index 00000000..f420df8d --- /dev/null +++ b/caffe-crfrnn/matlab/CMakeLists.txt @@ -0,0 +1,72 @@ +# Builds Matlab (or Octave) interface. In case of Matlab caffe must be +# compiled as shared library. 
Octave can link static or shared caffe library. +# To install octave run: sudo apt-get install liboctave-dev + +if(NOT BUILD_matlab) + return() +endif() + +if(HAVE_MATLAB AND Octave_compiler) + set(build_using ${Matlab_build_mex_using}) +elseif(HAVE_MATLAB AND NOT Octave_compiler) + set(build_using "Matlab") +elseif(NOT HAVE_MATLAB AND Octave_compiler) + set(build_using "Octave") +else() + return() +endif() + +if(NOT BUILD_SHARED_LIBS AND build_using MATCHES Matlab) + message(FATAL_ERROR "Matlab MEX interface (with default mex options file) can only be built if caffe is compiled as shared library. Please enable 'BUILD_SHARED_LIBS' in CMake. Alternatively, you can switch to the Octave compiler.") +endif() + +# helper function to set proper mex file extension +function(caffe_fetch_and_set_proper_mexext mexfile_variable) + execute_process(COMMAND ${Matlab_mexext} OUTPUT_STRIP_TRAILING_WHITESPACE RESULT_VARIABLE res OUTPUT_VARIABLE ext) + if(res MATCHES 0) + get_filename_component(folder ${${mexfile_variable}} PATH) + get_filename_component(name_we ${${mexfile_variable}} NAME_WE) + set(${mexfile_variable} ${folder}/${name_we}.${ext} PARENT_SCOPE) + endif() +endfunction() + +# global settings +file(GLOB Matlab_srcs +caffe/private/caffe_.cpp) +set(Matlab_caffe_mex ${PROJECT_SOURCE_DIR}/matlab/+caffe/private/caffe_.mex) + +caffe_get_current_cflags(cflags) +caffe_parse_linker_libs(Caffe_LINKER_LIBS folders libflags macos_frameworks) +set(folders $<TARGET_LINKER_FILE_DIR:caffe> ${folders}) + +# prepare linker flag lists +string(REPLACE ";" ";-L" link_folders "-L${folders}") +string(REPLACE ";" ":" rpath_folders "${folders}") + +if(build_using MATCHES "Matlab") + set(libflags -lcaffe${Caffe_POSTFIX} ${libflags}) # Matlab R2014a complains about -Wl,--whole-archive + + caffe_fetch_and_set_proper_mexext(Matlab_caffe_mex) + add_custom_command(OUTPUT ${Matlab_caffe_mex} COMMAND ${Matlab_mex} + ARGS -output ${Matlab_caffe_mex} ${Matlab_srcs} ${cflags} ${link_folders} ${libflags} + DEPENDS caffe COMMENT "Building Matlab 
interface: ${Matlab_caffe_mex}" VERBATIM) + add_custom_target(matlab ALL DEPENDS ${Matlab_caffe_mex} SOURCES ${Matlab_srcs}) + +elseif(build_using MATCHES "Octave") + + if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + set(libflags -Wl,-force_load,$<TARGET_LINKER_FILE:caffe> ${libflags}) + elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + set(libflags -Wl,--whole-archive -lcaffe${Caffe_POSTFIX} -Wl,--no-whole-archive ${libflags}) + endif() + + add_custom_command(OUTPUT ${Matlab_caffe_mex} COMMAND ${Octave_compiler} + ARGS --mex -o ${Matlab_caffe_mex} ${Matlab_srcs} ${cflags} ${link_folders} ${libflags} -Wl,-rpath,${rpath_folders} + DEPENDS caffe COMMENT "Building Octave interface: ${Matlab_caffe_mex}" VERBATIM) + + add_custom_target(octave ALL DEPENDS ${Matlab_caffe_mex} SOURCES ${Matlab_srcs}) +endif() + +# ---[ Install +file(GLOB mfiles caffe/*.m) +install(FILES ${mfiles} ${Matlab_caffe_mex} DESTINATION matlab) + diff --git a/caffe-crfrnn/matlab/caffe/.gitignore b/caffe-crfrnn/matlab/caffe/.gitignore new file mode 100644 index 00000000..56c01d95 --- /dev/null +++ b/caffe-crfrnn/matlab/caffe/.gitignore @@ -0,0 +1 @@ +/caffe.mexa64 diff --git a/caffe-crfrnn/matlab/caffe/ilsvrc_2012_mean.mat b/caffe-crfrnn/matlab/caffe/ilsvrc_2012_mean.mat new file mode 100644 index 00000000..f1da25c8 Binary files /dev/null and b/caffe-crfrnn/matlab/caffe/ilsvrc_2012_mean.mat differ diff --git a/caffe-crfrnn/matlab/caffe/matcaffe.cpp b/caffe-crfrnn/matlab/caffe/matcaffe.cpp new file mode 100644 index 00000000..3de0f02e --- /dev/null +++ b/caffe-crfrnn/matlab/caffe/matcaffe.cpp @@ -0,0 +1,420 @@ +// +// matcaffe.cpp provides a wrapper of the caffe::Net class as well as some +// caffe::Caffe functions so that one could easily call it from matlab. +// Note that for matlab, we will simply use float as the data type. 
+ +#include <sstream> +#include <string> +#include <vector> + +#include "mex.h" + +#include "caffe/caffe.hpp" + +#define MEX_ARGS int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs + +// Log and throw a Mex error +inline void mex_error(const std::string &msg) { + LOG(ERROR) << msg; + mexErrMsgTxt(msg.c_str()); +} + +using namespace caffe; // NOLINT(build/namespaces) + +// The pointer to the internal caffe::Net instance +static shared_ptr<Net<float> > net_; +static int init_key = -2; + +// Five things to be aware of: +// caffe uses row-major order +// matlab uses column-major order +// caffe uses BGR color channel order +// matlab uses RGB color channel order +// images need to have the data mean subtracted +// +// Data coming in from matlab needs to be in the order +// [width, height, channels, images] +// where width is the fastest dimension. +// Here is the rough matlab for putting image data into the correct +// format: +// % convert from uint8 to single +// im = single(im); +// % reshape to a fixed size (e.g., 227x227) +// im = imresize(im, [IMAGE_DIM IMAGE_DIM], 'bilinear'); +// % permute from RGB to BGR and subtract the data mean (already in BGR) +// im = im(:,:,[3 2 1]) - data_mean; +// % flip width and height to make width the fastest dimension +// im = permute(im, [2 1 3]); +// +// If you have multiple images, cat them with cat(4, ...) +// +// The actual forward function. It takes in a cell array of 4-D arrays as +// input and outputs a cell array. 
+ +static mxArray* do_forward(const mxArray* const bottom) { + vector<Blob<float>*>& input_blobs = net_->input_blobs(); + if (static_cast<unsigned int>(mxGetDimensions(bottom)[0]) != + input_blobs.size()) { + mex_error("Invalid input size"); + } + for (unsigned int i = 0; i < input_blobs.size(); ++i) { + const mxArray* const elem = mxGetCell(bottom, i); + if (!mxIsSingle(elem)) { + mex_error("MatCaffe requires single-precision floating-point data"); + } + if (mxGetNumberOfElements(elem) != input_blobs[i]->count()) { + std::string error_msg; + error_msg += "MatCaffe input size does not match the input size "; + error_msg += "of the network"; + mex_error(error_msg); + } + + const float* const data_ptr = + reinterpret_cast<const float* const>(mxGetPr(elem)); + switch (Caffe::mode()) { + case Caffe::CPU: + caffe_copy(input_blobs[i]->count(), data_ptr, + input_blobs[i]->mutable_cpu_data()); + break; + case Caffe::GPU: + caffe_copy(input_blobs[i]->count(), data_ptr, + input_blobs[i]->mutable_gpu_data()); + break; + default: + mex_error("Unknown Caffe mode"); + } // switch (Caffe::mode()) + } + const vector<Blob<float>*>& output_blobs = net_->ForwardPrefilled(); + mxArray* mx_out = mxCreateCellMatrix(output_blobs.size(), 1); + for (unsigned int i = 0; i < output_blobs.size(); ++i) { + // internally data is stored as (width, height, channels, num) + // where width is the fastest dimension + mwSize dims[4] = {output_blobs[i]->width(), output_blobs[i]->height(), + output_blobs[i]->channels(), output_blobs[i]->num()}; + mxArray* mx_blob = mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL); + mxSetCell(mx_out, i, mx_blob); + float* data_ptr = reinterpret_cast<float*>(mxGetPr(mx_blob)); + switch (Caffe::mode()) { + case Caffe::CPU: + caffe_copy(output_blobs[i]->count(), output_blobs[i]->cpu_data(), + data_ptr); + break; + case Caffe::GPU: + caffe_copy(output_blobs[i]->count(), output_blobs[i]->gpu_data(), + data_ptr); + break; + default: + mex_error("Unknown Caffe mode"); + } // switch (Caffe::mode()) + } + + return mx_out; +} + +static 
mxArray* do_backward(const mxArray* const top_diff) { + vector<Blob<float>*>& output_blobs = net_->output_blobs(); + vector<Blob<float>*>& input_blobs = net_->input_blobs(); + if (static_cast<unsigned int>(mxGetDimensions(top_diff)[0]) != + output_blobs.size()) { + mex_error("Invalid input size"); + } + // First, copy the output diff + for (unsigned int i = 0; i < output_blobs.size(); ++i) { + const mxArray* const elem = mxGetCell(top_diff, i); + const float* const data_ptr = + reinterpret_cast<const float* const>(mxGetPr(elem)); + switch (Caffe::mode()) { + case Caffe::CPU: + caffe_copy(output_blobs[i]->count(), data_ptr, + output_blobs[i]->mutable_cpu_diff()); + break; + case Caffe::GPU: + caffe_copy(output_blobs[i]->count(), data_ptr, + output_blobs[i]->mutable_gpu_diff()); + break; + default: + mex_error("Unknown Caffe mode"); + } // switch (Caffe::mode()) + } + // LOG(INFO) << "Start"; + net_->Backward(); + // LOG(INFO) << "End"; + mxArray* mx_out = mxCreateCellMatrix(input_blobs.size(), 1); + for (unsigned int i = 0; i < input_blobs.size(); ++i) { + // internally data is stored as (width, height, channels, num) + // where width is the fastest dimension + mwSize dims[4] = {input_blobs[i]->width(), input_blobs[i]->height(), + input_blobs[i]->channels(), input_blobs[i]->num()}; + mxArray* mx_blob = mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL); + mxSetCell(mx_out, i, mx_blob); + float* data_ptr = reinterpret_cast<float*>(mxGetPr(mx_blob)); + switch (Caffe::mode()) { + case Caffe::CPU: + caffe_copy(input_blobs[i]->count(), input_blobs[i]->cpu_diff(), data_ptr); + break; + case Caffe::GPU: + caffe_copy(input_blobs[i]->count(), input_blobs[i]->gpu_diff(), data_ptr); + break; + default: + mex_error("Unknown Caffe mode"); + } // switch (Caffe::mode()) + } + + return mx_out; +} + +static mxArray* do_get_weights() { + const vector<shared_ptr<Layer<float> > >& layers = net_->layers(); + const vector<string>& layer_names = net_->layer_names(); + + // Step 1: count the number of layers with weights + int num_layers = 0; + { + string prev_layer_name = ""; + for 
(unsigned int i = 0; i < layers.size(); ++i) { + vector<shared_ptr<Blob<float> > >& layer_blobs = layers[i]->blobs(); + if (layer_blobs.size() == 0) { + continue; + } + if (layer_names[i] != prev_layer_name) { + prev_layer_name = layer_names[i]; + num_layers++; + } + } + } + + // Step 2: prepare output array of structures + mxArray* mx_layers; + { + const mwSize dims[2] = {num_layers, 1}; + const char* fnames[2] = {"weights", "layer_names"}; + mx_layers = mxCreateStructArray(2, dims, 2, fnames); + } + + // Step 3: copy weights into output + { + string prev_layer_name = ""; + int mx_layer_index = 0; + for (unsigned int i = 0; i < layers.size(); ++i) { + vector<shared_ptr<Blob<float> > >& layer_blobs = layers[i]->blobs(); + if (layer_blobs.size() == 0) { + continue; + } + + mxArray* mx_layer_cells = NULL; + if (layer_names[i] != prev_layer_name) { + prev_layer_name = layer_names[i]; + const mwSize dims[2] = {static_cast<mwSize>(layer_blobs.size()), 1}; + mx_layer_cells = mxCreateCellArray(2, dims); + mxSetField(mx_layers, mx_layer_index, "weights", mx_layer_cells); + mxSetField(mx_layers, mx_layer_index, "layer_names", + mxCreateString(layer_names[i].c_str())); + mx_layer_index++; + } + + for (unsigned int j = 0; j < layer_blobs.size(); ++j) { + // internally data is stored as (width, height, channels, num) + // where width is the fastest dimension + mwSize dims[4] = {layer_blobs[j]->width(), layer_blobs[j]->height(), + layer_blobs[j]->channels(), layer_blobs[j]->num()}; + + mxArray* mx_weights = + mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL); + mxSetCell(mx_layer_cells, j, mx_weights); + float* weights_ptr = reinterpret_cast<float*>(mxGetPr(mx_weights)); + + switch (Caffe::mode()) { + case Caffe::CPU: + caffe_copy(layer_blobs[j]->count(), layer_blobs[j]->cpu_data(), + weights_ptr); + break; + case Caffe::GPU: + caffe_copy(layer_blobs[j]->count(), layer_blobs[j]->gpu_data(), + weights_ptr); + break; + default: + mex_error("Unknown Caffe mode"); + } + } + } + } + + return mx_layers; +} + +static void 
get_weights(MEX_ARGS) { + plhs[0] = do_get_weights(); +} + +static void set_mode_cpu(MEX_ARGS) { + Caffe::set_mode(Caffe::CPU); +} + +static void set_mode_gpu(MEX_ARGS) { + Caffe::set_mode(Caffe::GPU); +} + +static void set_phase_train(MEX_ARGS) { + Caffe::set_phase(Caffe::TRAIN); +} + +static void set_phase_test(MEX_ARGS) { + Caffe::set_phase(Caffe::TEST); +} + +static void set_device(MEX_ARGS) { + if (nrhs != 1) { + ostringstream error_msg; + error_msg << "Expected 1 argument, got " << nrhs; + mex_error(error_msg.str()); + } + + int device_id = static_cast<int>(mxGetScalar(prhs[0])); + Caffe::SetDevice(device_id); +} + +static void get_init_key(MEX_ARGS) { + plhs[0] = mxCreateDoubleScalar(init_key); +} + +static void init(MEX_ARGS) { + if (nrhs != 2) { + ostringstream error_msg; + error_msg << "Expected 2 arguments, got " << nrhs; + mex_error(error_msg.str()); + } + + char* param_file = mxArrayToString(prhs[0]); + char* model_file = mxArrayToString(prhs[1]); + + net_.reset(new Net<float>(string(param_file))); + net_->CopyTrainedLayersFrom(string(model_file)); + + mxFree(param_file); + mxFree(model_file); + + init_key = random(); // NOLINT(caffe/random_fn) + + if (nlhs == 1) { + plhs[0] = mxCreateDoubleScalar(init_key); + } +} + +static void reset(MEX_ARGS) { + if (net_) { + net_.reset(); + init_key = -2; + LOG(INFO) << "Network reset, call init before using it again"; + } +} + +static void forward(MEX_ARGS) { + if (nrhs != 1) { + ostringstream error_msg; + error_msg << "Expected 1 argument, got " << nrhs; + mex_error(error_msg.str()); + } + + plhs[0] = do_forward(prhs[0]); +} + +static void backward(MEX_ARGS) { + if (nrhs != 1) { + ostringstream error_msg; + error_msg << "Expected 1 argument, got " << nrhs; + mex_error(error_msg.str()); + } + + plhs[0] = do_backward(prhs[0]); +} + +static void is_initialized(MEX_ARGS) { + if (!net_) { + plhs[0] = mxCreateDoubleScalar(0); + } else { + plhs[0] = mxCreateDoubleScalar(1); + } +} + +static void read_mean(MEX_ARGS) { + if (nrhs != 
1) { + mexErrMsgTxt("Usage: caffe('read_mean', 'path_to_binary_mean_file')"); + return; + } + const string& mean_file = mxArrayToString(prhs[0]); + Blob<float> data_mean; + LOG(INFO) << "Loading mean file from " << mean_file; + BlobProto blob_proto; + bool result = ReadProtoFromBinaryFile(mean_file.c_str(), &blob_proto); + if (!result) { + mexErrMsgTxt("Couldn't read the file"); + return; + } + data_mean.FromProto(blob_proto); + mwSize dims[4] = {data_mean.width(), data_mean.height(), + data_mean.channels(), data_mean.num() }; + mxArray* mx_blob = mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL); + float* data_ptr = reinterpret_cast<float*>(mxGetPr(mx_blob)); + caffe_copy(data_mean.count(), data_mean.cpu_data(), data_ptr); + mexWarnMsgTxt("Remember that Caffe saves in [width, height, channels]" + " format and channels are also BGR!"); + plhs[0] = mx_blob; +} + +/** ----------------------------------------------------------------- + ** Available commands. + **/ +struct handler_registry { + string cmd; + void (*func)(MEX_ARGS); +}; + +static handler_registry handlers[] = { + // Public API functions + { "forward", forward }, + { "backward", backward }, + { "init", init }, + { "is_initialized", is_initialized }, + { "set_mode_cpu", set_mode_cpu }, + { "set_mode_gpu", set_mode_gpu }, + { "set_phase_train", set_phase_train }, + { "set_phase_test", set_phase_test }, + { "set_device", set_device }, + { "get_weights", get_weights }, + { "get_init_key", get_init_key }, + { "reset", reset }, + { "read_mean", read_mean }, + // The end. + { "END", NULL }, +}; + + +/** ----------------------------------------------------------------- + ** matlab entry point: caffe(api_command, arg1, arg2, ...) + **/ +void mexFunction(MEX_ARGS) { + mexLock(); // Avoid clearing the mex file. 
+ if (nrhs == 0) { + mex_error("No API command given"); + return; + } + + { // Handle input command + char *cmd = mxArrayToString(prhs[0]); + bool dispatched = false; + // Dispatch to cmd handler + for (int i = 0; handlers[i].func != NULL; i++) { + if (handlers[i].cmd.compare(cmd) == 0) { + handlers[i].func(nlhs, plhs, nrhs-1, prhs+1); + dispatched = true; + break; + } + } + if (!dispatched) { + ostringstream error_msg; + error_msg << "Unknown command '" << cmd << "'"; + mex_error(error_msg.str()); + } + mxFree(cmd); + } +} diff --git a/caffe-crfrnn/matlab/caffe/matcaffe_batch.m b/caffe-crfrnn/matlab/caffe/matcaffe_batch.m new file mode 100644 index 00000000..f6d1aa83 --- /dev/null +++ b/caffe-crfrnn/matlab/caffe/matcaffe_batch.m @@ -0,0 +1,75 @@ +function [scores,list_im] = matcaffe_batch(list_im, use_gpu) +% scores = matcaffe_batch(list_im, use_gpu) +% +% Demo of the matlab wrapper using the ILSVRC network. +% +% input +% list_im list of image files +% use_gpu 1 to use the GPU, 0 to use the CPU +% +% output +% scores 1000 x num_images ILSVRC output vector +% +% You may need to do the following before you start matlab: +% $ export LD_LIBRARY_PATH=/opt/intel/mkl/lib/intel64:/usr/local/cuda/lib64 +% $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6 +% Or the equivalent based on where things are installed on your system +% +% Usage: +% scores = matcaffe_batch({'peppers.png','onion.png'}); +% scores = matcaffe_batch('list_images.txt', 1); +if nargin < 1 + % For test purposes + list_im = {'peppers.png','onion.png'}; +end +if ischar(list_im) + % Assume it is a file containing the list of images + filename = list_im; + list_im = read_cell(filename); +end +% Adjust the batch size and dim to match with models/bvlc_reference_caffenet/deploy.prototxt +batch_size = 10; +dim = 1000; +disp(list_im) +if mod(length(list_im),batch_size) + warning(['Assuming batches of ' num2str(batch_size) ' images; the rest will be filled with zeros']) +end + +% init caffe network (spews 
logging info) +if exist('use_gpu', 'var') + matcaffe_init(use_gpu); +else + matcaffe_init(); +end + +d = load('ilsvrc_2012_mean'); +IMAGE_MEAN = d.image_mean; + +% prepare input + +num_images = length(list_im); +scores = zeros(dim,num_images,'single'); +num_batches = ceil(length(list_im)/batch_size) +initic=tic; +for bb = 1 : num_batches + batchtic = tic; + range = 1+batch_size*(bb-1):min(num_images,batch_size * bb); + tic + input_data = prepare_batch(list_im(range),IMAGE_MEAN,batch_size); + toc, tic + fprintf('Batch %d out of %d %.2f%% Complete ETA %.2f seconds\n',... + bb,num_batches,bb/num_batches*100,toc(initic)/bb*(num_batches-bb)); + output_data = caffe('forward', {input_data}); + toc + output_data = squeeze(output_data{1}); + scores(:,range) = output_data(:,mod(range-1,batch_size)+1); + toc(batchtic) +end +toc(initic); + +if exist('filename', 'var') + save([filename '.probs.mat'],'list_im','scores','-v7.3'); +end + + + diff --git a/caffe-crfrnn/matlab/caffe/matcaffe_demo.m b/caffe-crfrnn/matlab/caffe/matcaffe_demo.m new file mode 100644 index 00000000..a931f910 --- /dev/null +++ b/caffe-crfrnn/matlab/caffe/matcaffe_demo.m @@ -0,0 +1,110 @@ +function [scores, maxlabel] = matcaffe_demo(im, use_gpu) +% scores = matcaffe_demo(im, use_gpu) +% +% Demo of the matlab wrapper using the ILSVRC network. 
+% +% input +% im color image as uint8 HxWx3 +% use_gpu 1 to use the GPU, 0 to use the CPU +% +% output +% scores 1000-dimensional ILSVRC score vector +% +% You may need to do the following before you start matlab: +% $ export LD_LIBRARY_PATH=/opt/intel/mkl/lib/intel64:/usr/local/cuda-5.5/lib64 +% $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6 +% Or the equivalent based on where things are installed on your system +% +% Usage: +% im = imread('../../examples/images/cat.jpg'); +% scores = matcaffe_demo(im, 1); +% [score, class] = max(scores); +% Five things to be aware of: +% caffe uses row-major order +% matlab uses column-major order +% caffe uses BGR color channel order +% matlab uses RGB color channel order +% images need to have the data mean subtracted + +% Data coming in from matlab needs to be in the order +% [width, height, channels, images] +% where width is the fastest dimension. +% Here is the rough matlab for putting image data into the correct +% format: +% % convert from uint8 to single +% im = single(im); +% % reshape to a fixed size (e.g., 227x227) +% im = imresize(im, [IMAGE_DIM IMAGE_DIM], 'bilinear'); +% % permute from RGB to BGR and subtract the data mean (already in BGR) +% im = im(:,:,[3 2 1]) - data_mean; +% % flip width and height to make width the fastest dimension +% im = permute(im, [2 1 3]); + +% If you have multiple images, cat them with cat(4, ...) + +% The actual forward function. It takes in a cell array of 4-D arrays as +% input and outputs a cell array. 
+ + +% init caffe network (spews logging info) +if exist('use_gpu', 'var') + matcaffe_init(use_gpu); +else + matcaffe_init(); +end + +if nargin < 1 + % For demo purposes we will use the peppers image + im = imread('peppers.png'); +end + +% prepare oversampled input +% input_data is Height x Width x Channel x Num +tic; +input_data = {prepare_image(im)}; +toc; + +% do forward pass to get scores +% scores are now Width x Height x Channels x Num +tic; +scores = caffe('forward', input_data); +toc; + +scores = scores{1}; +size(scores) +scores = squeeze(scores); +scores = mean(scores,2); + +[~,maxlabel] = max(scores); + +% ------------------------------------------------------------------------ +function images = prepare_image(im) +% ------------------------------------------------------------------------ +d = load('ilsvrc_2012_mean'); +IMAGE_MEAN = d.image_mean; +IMAGE_DIM = 256; +CROPPED_DIM = 227; + +% resize to fixed input size +im = single(im); +im = imresize(im, [IMAGE_DIM IMAGE_DIM], 'bilinear'); +% permute from RGB to BGR (IMAGE_MEAN is already BGR) +im = im(:,:,[3 2 1]) - IMAGE_MEAN; + +% oversample (4 corners, center, and their x-axis flips) +images = zeros(CROPPED_DIM, CROPPED_DIM, 3, 10, 'single'); +indices = [0 IMAGE_DIM-CROPPED_DIM] + 1; +curr = 1; +for i = indices + for j = indices + images(:, :, :, curr) = ... + permute(im(i:i+CROPPED_DIM-1, j:j+CROPPED_DIM-1, :), [2 1 3]); + images(:, :, :, curr+5) = images(end:-1:1, :, :, curr); + curr = curr + 1; + end +end +center = floor(indices(2) / 2)+1; +images(:,:,:,5) = ... + permute(im(center:center+CROPPED_DIM-1,center:center+CROPPED_DIM-1,:), ... 
+ [2 1 3]); +images(:,:,:,10) = images(end:-1:1, :, :, curr); diff --git a/caffe-crfrnn/matlab/caffe/matcaffe_demo_vgg.m b/caffe-crfrnn/matlab/caffe/matcaffe_demo_vgg.m new file mode 100644 index 00000000..4e5a98eb --- /dev/null +++ b/caffe-crfrnn/matlab/caffe/matcaffe_demo_vgg.m @@ -0,0 +1,96 @@ +function scores = matcaffe_demo_vgg(im, use_gpu, model_def_file, model_file, mean_file) +% scores = matcaffe_demo_vgg(im, use_gpu, model_def_file, model_file, mean_file) +% +% Demo of the matlab wrapper using the networks described in the BMVC-2014 paper "Return of the Devil in the Details: Delving Deep into Convolutional Nets" +% +% INPUT +% im - color image as uint8 HxWx3 +% use_gpu - 1 to use the GPU, 0 to use the CPU +% model_def_file - network configuration (.prototxt file) +% model_file - network weights (.caffemodel file) +% mean_file - mean BGR image as uint8 HxWx3 (.mat file) +% +% OUTPUT +% scores 1000-dimensional ILSVRC score vector +% +% EXAMPLE USAGE +% model_def_file = 'zoo/VGG_CNN_F_deploy.prototxt'; +% model_file = 'zoo/VGG_CNN_F.caffemodel'; +% mean_file = 'zoo/VGG_mean.mat'; +% use_gpu = true; +% im = imread('../../examples/images/cat.jpg'); +% scores = matcaffe_demo_vgg(im, use_gpu, model_def_file, model_file, mean_file); +% +% NOTES +% the image crops are prepared as described in the paper (the aspect ratio is preserved) +% +% PREREQUISITES +% You may need to do the following before you start matlab: +% $ export LD_LIBRARY_PATH=/opt/intel/mkl/lib/intel64:/usr/local/cuda/lib64 +% $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6 +% Or the equivalent based on where things are installed on your system + +% init caffe network (spews logging info) +matcaffe_init(use_gpu, model_def_file, model_file); + +% prepare oversampled input +% input_data is Height x Width x Channel x Num +tic; +input_data = {prepare_image(im, mean_file)}; +toc; + +% do forward pass to get scores +% scores are now Width x Height x Channels x Num +tic; +scores = 
caffe('forward', input_data); +toc; + +scores = scores{1}; +% size(scores) +scores = squeeze(scores); +% scores = mean(scores,2); + +% [~,maxlabel] = max(scores); + +% ------------------------------------------------------------------------ +function images = prepare_image(im, mean_file) +% ------------------------------------------------------------------------ +IMAGE_DIM = 256; +CROPPED_DIM = 224; + +d = load(mean_file); +IMAGE_MEAN = d.image_mean; + +% resize to fixed input size +im = single(im); + +if size(im, 1) < size(im, 2) + im = imresize(im, [IMAGE_DIM NaN]); +else + im = imresize(im, [NaN IMAGE_DIM]); +end + +% RGB -> BGR +im = im(:, :, [3 2 1]); + +% oversample (4 corners, center, and their x-axis flips) +images = zeros(CROPPED_DIM, CROPPED_DIM, 3, 10, 'single'); + +indices_y = [0 size(im,1)-CROPPED_DIM] + 1; +indices_x = [0 size(im,2)-CROPPED_DIM] + 1; +center_y = floor(indices_y(2) / 2)+1; +center_x = floor(indices_x(2) / 2)+1; + +curr = 1; +for i = indices_y + for j = indices_x + images(:, :, :, curr) = ... + permute(im(i:i+CROPPED_DIM-1, j:j+CROPPED_DIM-1, :)-IMAGE_MEAN, [2 1 3]); + images(:, :, :, curr+5) = images(end:-1:1, :, :, curr); + curr = curr + 1; + end +end +images(:,:,:,5) = ... + permute(im(center_y:center_y+CROPPED_DIM-1,center_x:center_x+CROPPED_DIM-1,:)-IMAGE_MEAN, ... + [2 1 3]); +images(:,:,:,10) = images(end:-1:1, :, :, curr); diff --git a/caffe-crfrnn/matlab/caffe/matcaffe_demo_vgg_mean_pix.m b/caffe-crfrnn/matlab/caffe/matcaffe_demo_vgg_mean_pix.m new file mode 100644 index 00000000..5f7898a7 --- /dev/null +++ b/caffe-crfrnn/matlab/caffe/matcaffe_demo_vgg_mean_pix.m @@ -0,0 +1,102 @@ +function scores = matcaffe_demo_vgg_mean_pix(im, use_gpu, model_def_file, model_file) +% scores = matcaffe_demo_vgg(im, use_gpu, model_def_file, model_file) +% +% Demo of the matlab wrapper based on the networks used for the "VGG" entry +% in the ILSVRC-2014 competition and described in the tech. 
report +% "Very Deep Convolutional Networks for Large-Scale Image Recognition" +% http://arxiv.org/abs/1409.1556/ +% +% INPUT +% im - color image as uint8 HxWx3 +% use_gpu - 1 to use the GPU, 0 to use the CPU +% model_def_file - network configuration (.prototxt file) +% model_file - network weights (.caffemodel file) +% +% OUTPUT +% scores 1000-dimensional ILSVRC score vector +% +% EXAMPLE USAGE +% model_def_file = 'zoo/deploy.prototxt'; +% model_file = 'zoo/model.caffemodel'; +% use_gpu = true; +% im = imread('../../examples/images/cat.jpg'); +% scores = matcaffe_demo_vgg_mean_pix(im, use_gpu, model_def_file, model_file); +% +% NOTES +% mean pixel subtraction is used instead of the mean image subtraction +% +% PREREQUISITES +% You may need to do the following before you start matlab: +% $ export LD_LIBRARY_PATH=/opt/intel/mkl/lib/intel64:/usr/local/cuda/lib64 +% $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6 +% Or the equivalent based on where things are installed on your system + +% init caffe network (spews logging info) +matcaffe_init(use_gpu, model_def_file, model_file); + +% mean BGR pixel +mean_pix = [103.939, 116.779, 123.68]; + +% prepare oversampled input +% input_data is Height x Width x Channel x Num +tic; +input_data = {prepare_image(im, mean_pix)}; +toc; + +% do forward pass to get scores +% scores are now Width x Height x Channels x Num +tic; +scores = caffe('forward', input_data); +toc; + +scores = scores{1}; +% size(scores) +scores = squeeze(scores); +% scores = mean(scores,2); + +% [~,maxlabel] = max(scores); + +% ------------------------------------------------------------------------ +function images = prepare_image(im, mean_pix) +% ------------------------------------------------------------------------ +IMAGE_DIM = 256; +CROPPED_DIM = 224; + +% resize to fixed input size +im = single(im); + +if size(im, 1) < size(im, 2) + im = imresize(im, [IMAGE_DIM NaN]); +else + im = imresize(im, [NaN IMAGE_DIM]); +end + +% RGB -> BGR +im = im(:, :, [3 
2 1]); + +% oversample (4 corners, center, and their x-axis flips) +images = zeros(CROPPED_DIM, CROPPED_DIM, 3, 10, 'single'); + +indices_y = [0 size(im,1)-CROPPED_DIM] + 1; +indices_x = [0 size(im,2)-CROPPED_DIM] + 1; +center_y = floor(indices_y(2) / 2)+1; +center_x = floor(indices_x(2) / 2)+1; + +curr = 1; +for i = indices_y + for j = indices_x + images(:, :, :, curr) = ... + permute(im(i:i+CROPPED_DIM-1, j:j+CROPPED_DIM-1, :), [2 1 3]); + images(:, :, :, curr+5) = images(end:-1:1, :, :, curr); + curr = curr + 1; + end +end +images(:,:,:,5) = ... + permute(im(center_y:center_y+CROPPED_DIM-1,center_x:center_x+CROPPED_DIM-1,:), ... + [2 1 3]); +images(:,:,:,10) = images(end:-1:1, :, :, curr); + +% mean BGR pixel subtraction +for c = 1:3 + images(:, :, c, :) = images(:, :, c, :) - mean_pix(c); +end diff --git a/caffe-crfrnn/matlab/caffe/matcaffe_init.m b/caffe-crfrnn/matlab/caffe/matcaffe_init.m new file mode 100644 index 00000000..7cc69357 --- /dev/null +++ b/caffe-crfrnn/matlab/caffe/matcaffe_init.m @@ -0,0 +1,44 @@ +function matcaffe_init(use_gpu, model_def_file, model_file) +% matcaffe_init(use_gpu, model_def_file, model_file) +% Initialize the matcaffe wrapper + +if nargin < 1 + % By default use CPU + use_gpu = 0; +end +if nargin < 2 || isempty(model_def_file) + % By default use imagenet_deploy + model_def_file = '../../models/bvlc_reference_caffenet/deploy.prototxt'; +end +if nargin < 3 || isempty(model_file) + % By default use caffe reference model + model_file = '../../models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'; +end + + +if caffe('is_initialized') == 0 + if exist(model_file, 'file') == 0 + % NOTE: you'll have to get the pre-trained ILSVRC network + error('You need a network model file'); + end + if ~exist(model_def_file,'file') + % NOTE: you'll have to get network definition + error('You need the network prototxt definition'); + end + caffe('init', model_def_file, model_file) +end +fprintf('Done with init\n'); + +% set to use GPU or CPU 
+if use_gpu + fprintf('Using GPU Mode\n'); + caffe('set_mode_gpu'); +else + fprintf('Using CPU Mode\n'); + caffe('set_mode_cpu'); +end +fprintf('Done with set_mode\n'); + +% put into test mode +caffe('set_phase_test'); +fprintf('Done with set_phase_test\n'); diff --git a/caffe-crfrnn/matlab/caffe/matcaffe_netsurgery.cpp b/caffe-crfrnn/matlab/caffe/matcaffe_netsurgery.cpp new file mode 100644 index 00000000..71fb1aeb --- /dev/null +++ b/caffe-crfrnn/matlab/caffe/matcaffe_netsurgery.cpp @@ -0,0 +1,600 @@ +// +// matcaffe.cpp provides a wrapper of the caffe::Net class as well as some +// caffe::Caffe functions so that one could easily call it from matlab. +// Note that for matlab, we will simply use float as the data type. + +#include +#include +#include + +#include "mex.h" + +#include "caffe/caffe.hpp" + +#define MEX_ARGS int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs + +// Log and throw a Mex error +inline void mex_error(const std::string &msg) { + LOG(ERROR) << msg; + mexErrMsgTxt(msg.c_str()); +} + +// Log and print Warning message +inline void mex_warn(const std::string &msg) { + LOG(INFO) << msg; + mexWarnMsgTxt(msg.c_str()); +} + +using namespace caffe; // NOLINT(build/namespaces) + +// The pointer to the internal caffe::Net instance +static shared_ptr > net_; +static int init_key = -2; + +// Five things to be aware of: +// caffe uses row-major order +// matlab uses column-major order +// caffe uses BGR color channel order +// matlab uses RGB color channel order +// images need to have the data mean subtracted +// +// Data coming in from matlab needs to be in the order +// [width, height, channels, images] +// where width is the fastest dimension. 
+// Here is the rough matlab for putting image data into the correct +// format: +// % convert from uint8 to single +// im = single(im); +// % reshape to a fixed size (e.g., 227x227) +// im = imresize(im, [IMAGE_DIM IMAGE_DIM], 'bilinear'); +// % permute from RGB to BGR and subtract the data mean (already in BGR) +// im = im(:,:,[3 2 1]) - data_mean; +// % flip width and height to make width the fastest dimension +// im = permute(im, [2 1 3]); +// +// If you have multiple images, cat them with cat(4, ...) +// +// The actual forward function. It takes in a cell array of 4-D arrays as +// input and outputs a cell array. + +static mxArray* do_forward(const mxArray* const bottom, + const bool is_conv = false) { + const vector*>& input_blobs = net_->input_blobs(); + if (static_cast(mxGetDimensions(bottom)[0]) != + input_blobs.size()) { + mex_error("Invalid input size"); + } + for (unsigned int i = 0; i < input_blobs.size(); ++i) { + const mxArray* const elem = mxGetCell(bottom, i); + if (!mxIsSingle(elem)) { + mex_error("MatCaffe require single-precision float point data"); + } + if (!is_conv && mxGetNumberOfElements(elem) != input_blobs[i]->count()) { + std::string error_msg; + error_msg += "MatCaffe input size does not match the input size "; + error_msg += "of the network"; + mex_error(error_msg); + } + if (is_conv) { + // allow dynamic input size, when the net is fully convolutional. + const int num_dims = mxGetNumberOfDimensions(elem); + if (num_dims > 4) { + ostringstream error_msg; + error_msg << "Expected input blob has at most 4 dimensions, got " + << num_dims; + mex_error(error_msg.str()); + } + const mwSize* dim = mxGetDimensions(elem); + int width = dim[0]; // width in caffe is the fastest dimension + int height = dim[1]; + int channels = num_dims > 2 ? dim[2] : 1; + int num = num_dims > 3 ? 
dim[3] : 1; + if (input_blobs[i]->width() != width + || input_blobs[i]->height() != height + || input_blobs[i]->channels() != channels + || input_blobs[i]->num() != num) { + input_blobs[i]->Reshape(num, channels, height, width); + // The shape of other layers will be reshaped when calling forward. + } + } + const float* const data_ptr = + reinterpret_cast(mxGetPr(elem)); + switch (Caffe::mode()) { + case Caffe::CPU: + caffe_copy(input_blobs[i]->count(), data_ptr, + input_blobs[i]->mutable_cpu_data()); + break; + case Caffe::GPU: + caffe_copy(input_blobs[i]->count(), data_ptr, + input_blobs[i]->mutable_gpu_data()); + break; + default: + mex_error("Unknown Caffe mode"); + } // switch (Caffe::mode()) + } + const vector*>& output_blobs = net_->ForwardPrefilled(); + mxArray* mx_out = mxCreateCellMatrix(output_blobs.size(), 1); + for (unsigned int i = 0; i < output_blobs.size(); ++i) { + // internally data is stored as (width, height, channels, num) + // where width is the fastest dimension + mwSize dims[4] = {output_blobs[i]->width(), output_blobs[i]->height(), + output_blobs[i]->channels(), output_blobs[i]->num()}; + mxArray* mx_blob = mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL); + mxSetCell(mx_out, i, mx_blob); + float* data_ptr = reinterpret_cast(mxGetPr(mx_blob)); + switch (Caffe::mode()) { + case Caffe::CPU: + caffe_copy(output_blobs[i]->count(), output_blobs[i]->cpu_data(), + data_ptr); + break; + case Caffe::GPU: + caffe_copy(output_blobs[i]->count(), output_blobs[i]->gpu_data(), + data_ptr); + break; + default: + mex_error("Unknown Caffe mode"); + } // switch (Caffe::mode()) + } + + return mx_out; +} + +static mxArray* do_backward(const mxArray* const top_diff) { + const vector*>& output_blobs = net_->output_blobs(); + const vector*>& input_blobs = net_->input_blobs(); + if (static_cast(mxGetDimensions(top_diff)[0]) != + output_blobs.size()) { + mex_error("Invalid input size"); + } + // First, copy the output diff + for (unsigned int i = 0; i < 
output_blobs.size(); ++i) { + const mxArray* const elem = mxGetCell(top_diff, i); + const float* const data_ptr = + reinterpret_cast(mxGetPr(elem)); + switch (Caffe::mode()) { + case Caffe::CPU: + caffe_copy(output_blobs[i]->count(), data_ptr, + output_blobs[i]->mutable_cpu_diff()); + break; + case Caffe::GPU: + caffe_copy(output_blobs[i]->count(), data_ptr, + output_blobs[i]->mutable_gpu_diff()); + break; + default: + mex_error("Unknown Caffe mode"); + } // switch (Caffe::mode()) + } + // LOG(INFO) << "Start"; + net_->Backward(); + // LOG(INFO) << "End"; + mxArray* mx_out = mxCreateCellMatrix(input_blobs.size(), 1); + for (unsigned int i = 0; i < input_blobs.size(); ++i) { + // internally data is stored as (width, height, channels, num) + // where width is the fastest dimension + mwSize dims[4] = {input_blobs[i]->width(), input_blobs[i]->height(), + input_blobs[i]->channels(), input_blobs[i]->num()}; + mxArray* mx_blob = mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL); + mxSetCell(mx_out, i, mx_blob); + float* data_ptr = reinterpret_cast(mxGetPr(mx_blob)); + switch (Caffe::mode()) { + case Caffe::CPU: + caffe_copy(input_blobs[i]->count(), input_blobs[i]->cpu_diff(), data_ptr); + break; + case Caffe::GPU: + caffe_copy(input_blobs[i]->count(), input_blobs[i]->gpu_diff(), data_ptr); + break; + default: + mex_error("Unknown Caffe mode"); + } // switch (Caffe::mode()) + } + + return mx_out; +} + +static mxArray* do_get_weights() { + const vector > >& layers = net_->layers(); + const vector& layer_names = net_->layer_names(); + + // Step 1: count the number of layers with weights + int num_layers = 0; + { + string prev_layer_name = ""; + for (unsigned int i = 0; i < layers.size(); ++i) { + vector > >& layer_blobs = layers[i]->blobs(); + if (layer_blobs.size() == 0) { + continue; + } + if (layer_names[i] != prev_layer_name) { + prev_layer_name = layer_names[i]; + num_layers++; + } + } + } + + // Step 2: prepare output array of structures + mxArray* mx_layers; + { + 
const mwSize dims[2] = {num_layers, 1}; + const char* fnames[2] = {"weights", "layer_names"}; + mx_layers = mxCreateStructArray(2, dims, 2, fnames); + } + + // Step 3: copy weights into output + { + string prev_layer_name = ""; + int mx_layer_index = 0; + for (unsigned int i = 0; i < layers.size(); ++i) { + vector > >& layer_blobs = layers[i]->blobs(); + if (layer_blobs.size() == 0) { + continue; + } + + mxArray* mx_layer_cells = NULL; + if (layer_names[i] != prev_layer_name) { + prev_layer_name = layer_names[i]; + const mwSize dims[2] = {static_cast(layer_blobs.size()), 1}; + mx_layer_cells = mxCreateCellArray(2, dims); + mxSetField(mx_layers, mx_layer_index, "weights", mx_layer_cells); + mxSetField(mx_layers, mx_layer_index, "layer_names", + mxCreateString(layer_names[i].c_str())); + mx_layer_index++; + } + + for (unsigned int j = 0; j < layer_blobs.size(); ++j) { + // internally data is stored as (width, height, channels, num) + // where width is the fastest dimension + mwSize dims[4] = {layer_blobs[j]->width(), layer_blobs[j]->height(), + layer_blobs[j]->channels(), layer_blobs[j]->num()}; + + mxArray* mx_weights = + mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL); + mxSetCell(mx_layer_cells, j, mx_weights); + float* weights_ptr = reinterpret_cast(mxGetPr(mx_weights)); + + switch (Caffe::mode()) { + case Caffe::CPU: + caffe_copy(layer_blobs[j]->count(), layer_blobs[j]->cpu_data(), + weights_ptr); + break; + case Caffe::GPU: + caffe_copy(layer_blobs[j]->count(), layer_blobs[j]->gpu_data(), + weights_ptr); + break; + default: + mex_error("Unknown Caffe mode"); + } + } + } + } + + return mx_layers; +} + +static void do_set_weights(const mxArray* const mx_layers) { + // check input + if (!mxIsStruct(mx_layers)) { + mex_error("Expected input structure array with \"weights\" " + "and \"layer_names\" fields"); + } + const int num_layers = mxGetNumberOfElements(mx_layers); + const vector > >& layers = net_->layers(); + const vector& layer_names = 
net_->layer_names(); + for (int i = 0; i < num_layers; ++i) { + char* c_l_name = mxArrayToString(mxGetField(mx_layers, i, "layer_names")); + const mxArray* mx_l_weights = mxGetField(mx_layers, i, "weights"); + if (!c_l_name || !mx_l_weights || !mxIsCell(mx_l_weights)) { + // fail to find corresponding field + mex_error("Expected field \"weights\" contain " + "cells of single-precision number, " + "and field \"layer_names\" contain string"); + } + const string l_name(c_l_name); + mxFree(static_cast(c_l_name)); + // Step 1: find corresponding layer in the net_ + unsigned int ln; + for (ln = 0; ln < layer_names.size(); ++ln) { + if (l_name == layer_names[ln]) break; + } + if (ln >= layer_names.size()) { + mex_warn("Ignoring source layer " + l_name); + continue; + } + // Step 2: set layer weights + const vector > >& layer_blobs = layers[ln]->blobs(); + if (layer_blobs.size() != mxGetNumberOfElements(mx_l_weights)) { + ostringstream error_msg; + error_msg << "Layer " << l_name << " expected " + << layer_blobs.size() << " blob(s), got " + << mxGetNumberOfElements(mx_l_weights); + mex_error(error_msg.str()); + } + for (unsigned int j = 0; j < layer_blobs.size(); ++j) { + // internally data is stored as (width, height, channels, num) + // where width is the fastest dimension + const mxArray* mx_weights = mxGetCell(mx_l_weights, j); + if (!mxIsSingle(mx_weights)) { + mex_error("MatCaffe require single-precision float point data"); + } + const int num_dims = mxGetNumberOfDimensions(mx_weights); + if (num_dims > 4) { + ostringstream error_msg; + error_msg << "Expected input blob has at most 4 dimensions, got " + << num_dims; + mex_error(error_msg.str()); + } + const mwSize *dims = mxGetDimensions(mx_weights); + const int width = dims[0]; + const int height = dims[1]; + const int channels = num_dims > 2 ? dims[2] : 1; + const int num = num_dims > 3 ? 
dims[3] : 1; + if (layer_blobs[j]->width() != width) { + ostringstream error_msg; + error_msg << "Expected blob " << j << " in layer " << l_name + << " has width = " << layer_blobs[j]->width() + << ", got " << width; + mex_error(error_msg.str()); + } + if (layer_blobs[j]->height() != height) { + ostringstream error_msg; + error_msg << "Expected blob " << j << " in layer " << l_name + << " has height = " << layer_blobs[j]->height() + << ", got " << height; + mex_error(error_msg.str()); + } + if (layer_blobs[j]->channels() != channels) { + ostringstream error_msg; + error_msg << "Expected blob " << j << " in layer " << l_name + << " has channels = " << layer_blobs[j]->channels() + << ", got " << channels; + mex_error(error_msg.str()); + } + if (layer_blobs[j]->num() != num) { + ostringstream error_msg; + error_msg << "Expected blob " << j << " in layer " << l_name + << " has num = " << layer_blobs[j]->num() + << ", got " << num; + mex_error(error_msg.str()); + } + + const float* weights_ptr = + reinterpret_cast<const float*>(mxGetPr(mx_weights)); + switch (Caffe::mode()) { + case Caffe::CPU: + caffe_copy(layer_blobs[j]->count(), weights_ptr, + layer_blobs[j]->mutable_cpu_data()); + break; + case Caffe::GPU: + caffe_copy(layer_blobs[j]->count(), weights_ptr, + layer_blobs[j]->mutable_gpu_data()); + break; + default: + mex_error("Unknown Caffe mode"); + } + } + } +} + +static void get_weights(MEX_ARGS) { + if (!net_) { + mex_error("Init net before get weights"); + } + plhs[0] = do_get_weights(); +} + +static void set_weights(MEX_ARGS) { + if (nrhs != 1) { + ostringstream error_msg; + error_msg << "Expected 1 argument, got " << nrhs; + mex_error(error_msg.str()); + } + do_set_weights(prhs[0]); +} + +static void set_mode_cpu(MEX_ARGS) { + Caffe::set_mode(Caffe::CPU); +} + +static void set_mode_gpu(MEX_ARGS) { + Caffe::set_mode(Caffe::GPU); +} + +static void set_phase_train(MEX_ARGS) { + Caffe::set_phase(Caffe::TRAIN); +} + +static void set_phase_test(MEX_ARGS) { + 
Caffe::set_phase(Caffe::TEST); +} + +static void set_device(MEX_ARGS) { + if (nrhs != 1) { + ostringstream error_msg; + error_msg << "Expected 1 argument, got " << nrhs; + mex_error(error_msg.str()); + } + + int device_id = static_cast(mxGetScalar(prhs[0])); + Caffe::SetDevice(device_id); +} + +static void get_init_key(MEX_ARGS) { + plhs[0] = mxCreateDoubleScalar(init_key); +} + +static void init(MEX_ARGS) { + if (nrhs != 2) { + ostringstream error_msg; + error_msg << "Expected 2 arguments, got " << nrhs; + mex_error(error_msg.str()); + } + + char* param_file = mxArrayToString(prhs[0]); + char* model_file = mxArrayToString(prhs[1]); + + net_.reset(new Net(string(param_file))); + net_->CopyTrainedLayersFrom(string(model_file)); + + mxFree(param_file); + mxFree(model_file); + + init_key = random(); // NOLINT(caffe/random_fn) + + if (nlhs == 1) { + plhs[0] = mxCreateDoubleScalar(init_key); + } +} + +static void reset(MEX_ARGS) { + if (net_) { + net_.reset(); + init_key = -2; + LOG(INFO) << "Network reset, call init before use it again"; + } +} + +// save the network weights to binary proto +static void save(MEX_ARGS) { + if (nrhs != 1) { + ostringstream error_msg; + error_msg << "Expected 1 argument, got " << nrhs; + mex_error(error_msg.str()); + } + if (!net_) { + mex_error("Init net before save it"); + } + char* c_model_file = mxArrayToString(prhs[0]); + if (!c_model_file) { + mex_error("Expected string input for model name"); + } + string model_file(c_model_file); + mxFree(static_cast(c_model_file)); + + NetParameter net_param; + net_->ToProto(&net_param, false); + WriteProtoToBinaryFile(net_param, model_file); +} + +static void forward(MEX_ARGS) { + if (nrhs != 1) { + ostringstream error_msg; + error_msg << "Expected 1 argument, got " << nrhs; + mex_error(error_msg.str()); + } + + plhs[0] = do_forward(prhs[0]); +} + +static void conv_forward(MEX_ARGS) { + if (nrhs != 1) { + ostringstream error_msg; + error_msg << "Expected 1 argument, got " << nrhs; + 
mex_error(error_msg.str()); + } + + plhs[0] = do_forward(prhs[0], true); +} + +static void backward(MEX_ARGS) { + if (nrhs != 1) { + ostringstream error_msg; + error_msg << "Expected 1 argument, got " << nrhs; + mex_error(error_msg.str()); + } + + plhs[0] = do_backward(prhs[0]); +} + +static void is_initialized(MEX_ARGS) { + if (!net_) { + plhs[0] = mxCreateDoubleScalar(0); + } else { + plhs[0] = mxCreateDoubleScalar(1); + } +} + +static void read_mean(MEX_ARGS) { + if (nrhs != 1) { + mexErrMsgTxt("Usage: caffe('read_mean', 'path_to_binary_mean_file')"); + return; + } + const string& mean_file = mxArrayToString(prhs[0]); + Blob<float> data_mean; + LOG(INFO) << "Loading mean file from " << mean_file; + BlobProto blob_proto; + bool result = ReadProtoFromBinaryFile(mean_file.c_str(), &blob_proto); + if (!result) { + mexErrMsgTxt("Couldn't read the file"); + return; + } + data_mean.FromProto(blob_proto); + mwSize dims[4] = {data_mean.width(), data_mean.height(), + data_mean.channels(), data_mean.num() }; + mxArray* mx_blob = mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL); + float* data_ptr = reinterpret_cast<float*>(mxGetPr(mx_blob)); + caffe_copy(data_mean.count(), data_mean.cpu_data(), data_ptr); + mexWarnMsgTxt("Remember that Caffe saves in [width, height, channels]" + " format and channels are also BGR!"); + plhs[0] = mx_blob; +} + +/** ----------------------------------------------------------------- + ** Available commands. 
+ **/ +struct handler_registry { + string cmd; + void (*func)(MEX_ARGS); +}; + +static handler_registry handlers[] = { + // Public API functions + { "forward", forward }, + { "conv_forward", conv_forward }, + { "backward", backward }, + { "init", init }, + { "is_initialized", is_initialized }, + { "set_mode_cpu", set_mode_cpu }, + { "set_mode_gpu", set_mode_gpu }, + { "set_phase_train", set_phase_train }, + { "set_phase_test", set_phase_test }, + { "set_device", set_device }, + { "get_weights", get_weights }, + { "set_weights", set_weights }, + { "get_init_key", get_init_key }, + { "reset", reset }, + { "save", save }, + { "read_mean", read_mean }, + // The end. + { "END", NULL }, +}; + + +/** ----------------------------------------------------------------- + ** matlab entry point: caffe(api_command, arg1, arg2, ...) + **/ +void mexFunction(MEX_ARGS) { + mexLock(); // Avoid clearing the mex file. + if (nrhs == 0) { + mex_error("No API command given"); + return; + } + + { // Handle input command + char *cmd = mxArrayToString(prhs[0]); + bool dispatched = false; + // Dispatch to cmd handler + for (int i = 0; handlers[i].func != NULL; i++) { + if (handlers[i].cmd.compare(cmd) == 0) { + handlers[i].func(nlhs, plhs, nrhs-1, prhs+1); + dispatched = true; + break; + } + } + if (!dispatched) { + ostringstream error_msg; + error_msg << "Unknown command '" << cmd << "'"; + mex_error(error_msg.str()); + } + mxFree(cmd); + } +} + diff --git a/caffe-crfrnn/matlab/caffe/prepare_batch.m b/caffe-crfrnn/matlab/caffe/prepare_batch.m new file mode 100644 index 00000000..345c8eb5 --- /dev/null +++ b/caffe-crfrnn/matlab/caffe/prepare_batch.m @@ -0,0 +1,41 @@ +% ------------------------------------------------------------------------ +function images = prepare_batch(image_files,IMAGE_MEAN,batch_size) +% ------------------------------------------------------------------------ +if nargin < 2 + d = load('ilsvrc_2012_mean'); + IMAGE_MEAN = d.image_mean; +end +num_images = 
length(image_files); +if nargin < 3 + batch_size = num_images; +end + +IMAGE_DIM = 256; +CROPPED_DIM = 227; +indices = [0 IMAGE_DIM-CROPPED_DIM] + 1; +center = floor(indices(2) / 2)+1; + +num_images = length(image_files); +images = zeros(CROPPED_DIM,CROPPED_DIM,3,batch_size,'single'); + +parfor i=1:num_images + % read file + fprintf('%c Preparing %s\n',13,image_files{i}); + try + im = imread(image_files{i}); + % resize to fixed input size + im = single(im); + im = imresize(im, [IMAGE_DIM IMAGE_DIM], 'bilinear'); + % Transform GRAY to RGB + if size(im,3) == 1 + im = cat(3,im,im,im); + end + % permute from RGB to BGR (IMAGE_MEAN is already BGR) + im = im(:,:,[3 2 1]) - IMAGE_MEAN; + % Crop the center of the image + images(:,:,:,i) = permute(im(center:center+CROPPED_DIM-1,... + center:center+CROPPED_DIM-1,:),[2 1 3]); + catch + warning('Problems with file',image_files{i}); + end +end \ No newline at end of file diff --git a/caffe-crfrnn/matlab/caffe/print_cell.m b/caffe-crfrnn/matlab/caffe/print_cell.m new file mode 100644 index 00000000..864340d4 --- /dev/null +++ b/caffe-crfrnn/matlab/caffe/print_cell.m @@ -0,0 +1,42 @@ +function res=print_cell(input,file,linesep,cellsep) +assert(iscell(input),'The input should be a cell') +if nargin < 4 + cellsep = '\t'; +end +if nargin < 3 + linesep = '\n'; +end +if exist('file','var') && ~isempty(file) + %% + fid = fopen(file,'w'); + for l=1:length(input) + if iscell(input{l}) + for i=1:length(input{l}) + fprintf(fid,['%s' cellsep],input{l}{i}); + end + fprintf(fid,linesep); + else + if size(input,2) > 1 + for i=1:size(input,2) + fprintf(fid,'%s ',input{l,i}); + end + fprintf(fid,linesep); + else + fprintf(fid,['%s' linesep],input{l}); + end + end + end + fclose(fid); +else + res = ''; + for l=1:length(input) + if iscell(input{l}) + for i=1:length(input{l}) + res = [res sprintf([cellsep{1} '%s' cellsep{2}],input{l}{i})]; + end + res = [res sprintf(linesep)]; + else + res = [res sprintf(['%s' linesep],input{l}(:))]; + end + end 
+end \ No newline at end of file diff --git a/caffe-crfrnn/matlab/caffe/read_cell.m b/caffe-crfrnn/matlab/caffe/read_cell.m new file mode 100644 index 00000000..19831167 --- /dev/null +++ b/caffe-crfrnn/matlab/caffe/read_cell.m @@ -0,0 +1,21 @@ +function res=read_cell(filename,linesep,cellsep) +if nargin < 2, linesep='\n'; end +if nargin < 3, cellsep = '\t'; end +if exist(filename,'file') + fid = fopen(filename); +else + % Assume that filename is either a file identifier or a string + fid = filename; +end + +fileLines = textscan(fid,'%s','delimiter',linesep,'BufSize',100000); + +fileLines = fileLines{1}; + +if regexp(fileLines{1},cellsep,'once') + fileLines = regexprep(fileLines,['^' cellsep '|' cellsep '$'],''); + res = regexp(fileLines,cellsep,'split'); + res = cell2matcell(res); +else + res = fileLines; +end diff --git a/caffe-crfrnn/matlab/caffe/tvg_matcaffe_init.m b/caffe-crfrnn/matlab/caffe/tvg_matcaffe_init.m new file mode 100644 index 00000000..8d508417 --- /dev/null +++ b/caffe-crfrnn/matlab/caffe/tvg_matcaffe_init.m @@ -0,0 +1,41 @@ +function tvg_matcaffe_init(use_gpu, model_def_file, model_file) +% tvg_matcaffe_init(use_gpu, model_def_file, model_file) +% Initialize the matcaffe wrapper + +if nargin < 1 + error('Missing argument use_gpu'); +end + +if nargin < 2 || isempty(model_def_file) + error('Missing argument model_def_file'); +end + +if nargin < 3 || isempty(model_file) + error('Missing argument model_file'); +end + + +if caffe('is_initialized') == 0 + if exist(model_file, 'file') ~= 2 + error('You need a network model file'); + end + if exist(model_def_file,'file') ~= 2 + error('You need the network prototxt definition'); + end + caffe('init', model_def_file, model_file) +end +fprintf('Done with init\n'); + +% set to use GPU or CPU +if use_gpu + fprintf('Using GPU Mode\n'); + caffe('set_mode_gpu'); +else + fprintf('Using CPU Mode\n'); + caffe('set_mode_cpu'); +end +fprintf('Done with set_mode\n'); + +% put into test mode +caffe('set_phase_test'); +fprintf('Done 
with set_phase_test\n'); diff --git a/caffe-crfrnn/matlab/caffe/tvg_prepare_image_fixed.m b/caffe-crfrnn/matlab/caffe/tvg_prepare_image_fixed.m new file mode 100644 index 00000000..6fecf7f5 --- /dev/null +++ b/caffe-crfrnn/matlab/caffe/tvg_prepare_image_fixed.m @@ -0,0 +1,20 @@ +function images = tvg_prepare_image_fixed(im) + +INPUT_DIM = 500; + +% mean BGR pixel +mean_pix = [103.939, 116.779, 123.68]; + +im = single(im); +% RGB -> BGR +im = im(:, :, [3 2 1]); + +% mean BGR pixel subtraction +for c = 1:3 + im(:, :, c) = im(:, :, c) - mean_pix(c); +end + +images = zeros(INPUT_DIM, INPUT_DIM, 3, 1, 'single'); +[h, w, ~] = size(im); + +images(1:w, 1:h, :, 1) = permute(im, [2 1 3]); diff --git a/caffe-crfrnn/models/bvlc_alexnet/deploy.prototxt b/caffe-crfrnn/models/bvlc_alexnet/deploy.prototxt new file mode 100644 index 00000000..d010753f --- /dev/null +++ b/caffe-crfrnn/models/bvlc_alexnet/deploy.prototxt @@ -0,0 +1,244 @@ +name: "AlexNet" +input: "data" +input_dim: 10 +input_dim: 3 +input_dim: 227 +input_dim: 227 +layers { + name: "conv1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 96 + kernel_size: 11 + stride: 4 + } + bottom: "data" + top: "conv1" +} +layers { + name: "relu1" + type: RELU + bottom: "conv1" + top: "conv1" +} +layers { + name: "norm1" + type: LRN + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } + bottom: "conv1" + top: "norm1" +} +layers { + name: "pool1" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } + bottom: "norm1" + top: "pool1" +} +layers { + name: "conv2" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 2 + kernel_size: 5 + group: 2 + } + bottom: "pool1" + top: "conv2" +} +layers { + name: "relu2" + type: RELU + bottom: "conv2" + top: "conv2" +} +layers { + name: "norm2" + type: LRN + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + 
} + bottom: "conv2" + top: "norm2" +} +layers { + name: "pool2" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } + bottom: "norm2" + top: "pool2" +} +layers { + name: "conv3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + } + bottom: "pool2" + top: "conv3" +} +layers { + name: "relu3" + type: RELU + bottom: "conv3" + top: "conv3" +} +layers { + name: "conv4" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + group: 2 + } + bottom: "conv3" + top: "conv4" +} +layers { + name: "relu4" + type: RELU + bottom: "conv4" + top: "conv4" +} +layers { + name: "conv5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + group: 2 + } + bottom: "conv4" + top: "conv5" +} +layers { + name: "relu5" + type: RELU + bottom: "conv5" + top: "conv5" +} +layers { + name: "pool5" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } + bottom: "conv5" + top: "pool5" +} +layers { + name: "fc6" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + } + bottom: "pool5" + top: "fc6" +} +layers { + name: "relu6" + type: RELU + bottom: "fc6" + top: "fc6" +} +layers { + name: "drop6" + type: DROPOUT + dropout_param { + dropout_ratio: 0.5 + } + bottom: "fc6" + top: "fc6" +} +layers { + name: "fc7" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + } + bottom: "fc6" + top: "fc7" +} +layers { + name: "relu7" + type: RELU + bottom: "fc7" + top: "fc7" +} +layers { + name: "drop7" + type: DROPOUT + dropout_param { + dropout_ratio: 0.5 + } + bottom: "fc7" + top: "fc7" +} +layers { + name: "fc8" + 
type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 1000 + } + bottom: "fc7" + top: "fc8" +} +layers { + name: "prob" + type: SOFTMAX + bottom: "fc8" + top: "prob" +} diff --git a/caffe-crfrnn/models/bvlc_alexnet/readme.md b/caffe-crfrnn/models/bvlc_alexnet/readme.md new file mode 100644 index 00000000..c25fd4f8 --- /dev/null +++ b/caffe-crfrnn/models/bvlc_alexnet/readme.md @@ -0,0 +1,27 @@ +--- +name: BVLC AlexNet Model +caffemodel: bvlc_alexnet.caffemodel +caffemodel_url: http://dl.caffe.berkeleyvision.org/bvlc_alexnet.caffemodel +license: non-commercial +sha1: 9116a64c0fbe4459d18f4bb6b56d647b63920377 +caffe_commit: 709dc15af4a06bebda027c1eb2b3f3e3375d5077 +--- + +This model is a replication of the model described in the [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks) publication. + +Differences: +- not training with the relighting data-augmentation; +- initializing non-zero biases to 0.1 instead of 1 (found necessary for training, as initialization to 1 gave flat loss). + +The bundled model is the iteration 360,000 snapshot. +The best validation performance during training was at iteration 358,000, with validation accuracy 57.258% and loss 1.83948. +This model obtains a top-1 accuracy of 57.1% and a top-5 accuracy of 80.2% on the validation set, using just the center crop. +(Using the average of 10 crops, (4 + 1 center) * 2 mirror, should obtain a bit higher accuracy.) + +This model was trained by Evan Shelhamer @shelhamer + +## License + +The data used to train this model comes from the ImageNet project, which distributes its database to researchers who agree to the following term of access: +"Researcher shall use the Database only for non-commercial research and educational purposes." +Accordingly, this model is distributed under a non-commercial license. 
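The 10-crop evaluation mentioned in the readme above, (4 + 1 center) * 2 mirror, can be sketched roughly as follows. This is a minimal illustration in plain Python, not Caffe's own implementation: the function names `ten_crop_boxes` and `average_scores` are hypothetical, and the 256x256 source image size in the usage note is an assumption. Only the 227 crop size comes from the prototxt files in this change.

```python
def ten_crop_boxes(height, width, crop):
    """Return 10 (top, left, mirrored) crop specs.

    Four corner crops plus one center crop, each taken once as-is
    and once horizontally mirrored. Hypothetical helper for illustration.
    """
    if crop > height or crop > width:
        raise ValueError("crop must fit inside the image")
    bottom = height - crop   # largest valid top offset
    right = width - crop     # largest valid left offset
    corners = [(0, 0), (0, right), (bottom, 0), (bottom, right)]
    center = (bottom // 2, right // 2)
    boxes = []
    for mirrored in (False, True):
        for top, left in corners + [center]:
            boxes.append((top, left, mirrored))
    return boxes


def average_scores(per_crop_scores):
    """Average class scores element-wise across crops.

    per_crop_scores is a list with one score list per crop,
    all of the same length (one entry per class).
    """
    n = len(per_crop_scores)
    return [sum(column) / n for column in zip(*per_crop_scores)]
```

For an assumed 256x256 input, `ten_crop_boxes(256, 256, 227)` yields ten crop specifications. Each crop would be run through the network separately, and the resulting `prob` outputs combined with `average_scores`.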
diff --git a/caffe-crfrnn/models/bvlc_alexnet/solver.prototxt b/caffe-crfrnn/models/bvlc_alexnet/solver.prototxt new file mode 100644 index 00000000..129265e6 --- /dev/null +++ b/caffe-crfrnn/models/bvlc_alexnet/solver.prototxt @@ -0,0 +1,14 @@ +net: "models/bvlc_alexnet/train_val.prototxt" +test_iter: 1000 +test_interval: 1000 +base_lr: 0.01 +lr_policy: "step" +gamma: 0.1 +stepsize: 100000 +display: 20 +max_iter: 450000 +momentum: 0.9 +weight_decay: 0.0005 +snapshot: 10000 +snapshot_prefix: "models/bvlc_alexnet/caffe_alexnet_train" +solver_mode: GPU diff --git a/caffe-crfrnn/models/bvlc_alexnet/train_val.prototxt b/caffe-crfrnn/models/bvlc_alexnet/train_val.prototxt new file mode 100644 index 00000000..717b6fa4 --- /dev/null +++ b/caffe-crfrnn/models/bvlc_alexnet/train_val.prototxt @@ -0,0 +1,346 @@ +name: "AlexNet" +layers { + name: "data" + type: DATA + top: "data" + top: "label" + data_param { + source: "examples/imagenet/ilsvrc12_train_lmdb" + backend: LMDB + batch_size: 256 + } + transform_param { + crop_size: 227 + mean_file: "data/ilsvrc12/imagenet_mean.binaryproto" + mirror: true + } + include: { phase: TRAIN } +} +layers { + name: "data" + type: DATA + top: "data" + top: "label" + data_param { + source: "examples/imagenet/ilsvrc12_val_lmdb" + backend: LMDB + batch_size: 50 + } + transform_param { + crop_size: 227 + mean_file: "data/ilsvrc12/imagenet_mean.binaryproto" + mirror: false + } + include: { phase: TEST } +} +layers { + name: "conv1" + type: CONVOLUTION + bottom: "data" + top: "conv1" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 96 + kernel_size: 11 + stride: 4 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "relu1" + type: RELU + bottom: "conv1" + top: "conv1" +} +layers { + name: "norm1" + type: LRN + bottom: "conv1" + top: "norm1" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + 
name: "pool1" + type: POOLING + bottom: "norm1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "conv2" + type: CONVOLUTION + bottom: "pool1" + top: "conv2" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 2 + kernel_size: 5 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0.1 + } + } +} +layers { + name: "relu2" + type: RELU + bottom: "conv2" + top: "conv2" +} +layers { + name: "norm2" + type: LRN + bottom: "conv2" + top: "norm2" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "pool2" + type: POOLING + bottom: "norm2" + top: "pool2" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "conv3" + type: CONVOLUTION + bottom: "pool2" + top: "conv3" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "relu3" + type: RELU + bottom: "conv3" + top: "conv3" +} +layers { + name: "conv4" + type: CONVOLUTION + bottom: "conv3" + top: "conv4" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0.1 + } + } +} +layers { + name: "relu4" + type: RELU + bottom: "conv4" + top: "conv4" +} +layers { + name: "conv5" + type: CONVOLUTION + bottom: "conv4" + top: "conv5" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0.1 + } + } +} +layers { + name: "relu5" + type: RELU + bottom: 
"conv5" + top: "conv5" +} +layers { + name: "pool5" + type: POOLING + bottom: "conv5" + top: "pool5" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "fc6" + type: INNER_PRODUCT + bottom: "pool5" + top: "fc6" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 0.1 + } + } +} +layers { + name: "relu6" + type: RELU + bottom: "fc6" + top: "fc6" +} +layers { + name: "drop6" + type: DROPOUT + bottom: "fc6" + top: "fc6" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc7" + type: INNER_PRODUCT + bottom: "fc6" + top: "fc7" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 0.1 + } + } +} +layers { + name: "relu7" + type: RELU + bottom: "fc7" + top: "fc7" +} +layers { + name: "drop7" + type: DROPOUT + bottom: "fc7" + top: "fc7" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc8" + type: INNER_PRODUCT + bottom: "fc7" + top: "fc8" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 1000 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "accuracy" + type: ACCURACY + bottom: "fc8" + bottom: "label" + top: "accuracy" + include: { phase: TEST } +} +layers { + name: "loss" + type: SOFTMAX_LOSS + bottom: "fc8" + bottom: "label" + top: "loss" +} diff --git a/caffe-crfrnn/models/bvlc_googlenet/deploy.prototxt b/caffe-crfrnn/models/bvlc_googlenet/deploy.prototxt new file mode 100644 index 00000000..e31a4c9c --- /dev/null +++ b/caffe-crfrnn/models/bvlc_googlenet/deploy.prototxt @@ -0,0 +1,1924 @@ +name: "GoogleNet" +input: "data" +input_dim: 10 +input_dim: 3 +input_dim: 224 +input_dim: 224 
+layers { + bottom: "data" + top: "conv1/7x7_s2" + name: "conv1/7x7_s2" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + pad: 3 + kernel_size: 7 + stride: 2 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "conv1/7x7_s2" + top: "conv1/7x7_s2" + name: "conv1/relu_7x7" + type: RELU +} +layers { + bottom: "conv1/7x7_s2" + top: "pool1/3x3_s2" + name: "pool1/3x3_s2" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + bottom: "pool1/3x3_s2" + top: "pool1/norm1" + name: "pool1/norm1" + type: LRN + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + bottom: "pool1/norm1" + top: "conv2/3x3_reduce" + name: "conv2/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "conv2/3x3_reduce" + top: "conv2/3x3_reduce" + name: "conv2/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "conv2/3x3_reduce" + top: "conv2/3x3" + name: "conv2/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 192 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "conv2/3x3" + top: "conv2/3x3" + name: "conv2/relu_3x3" + type: RELU +} +layers { + bottom: "conv2/3x3" + top: "conv2/norm2" + name: "conv2/norm2" + type: LRN + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + bottom: "conv2/norm2" + top: "pool2/3x3_s2" + name: "pool2/3x3_s2" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + bottom: "pool2/3x3_s2" + top: 
"inception_3a/1x1" + name: "inception_3a/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3a/1x1" + top: "inception_3a/1x1" + name: "inception_3a/relu_1x1" + type: RELU +} +layers { + bottom: "pool2/3x3_s2" + top: "inception_3a/3x3_reduce" + name: "inception_3a/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 96 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3a/3x3_reduce" + top: "inception_3a/3x3_reduce" + name: "inception_3a/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_3a/3x3_reduce" + top: "inception_3a/3x3" + name: "inception_3a/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3a/3x3" + top: "inception_3a/3x3" + name: "inception_3a/relu_3x3" + type: RELU +} +layers { + bottom: "pool2/3x3_s2" + top: "inception_3a/5x5_reduce" + name: "inception_3a/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 16 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3a/5x5_reduce" + top: "inception_3a/5x5_reduce" + name: "inception_3a/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_3a/5x5_reduce" + top: "inception_3a/5x5" + name: "inception_3a/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 
+ weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 32 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3a/5x5" + top: "inception_3a/5x5" + name: "inception_3a/relu_5x5" + type: RELU +} +layers { + bottom: "pool2/3x3_s2" + top: "inception_3a/pool" + name: "inception_3a/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_3a/pool" + top: "inception_3a/pool_proj" + name: "inception_3a/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 32 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3a/pool_proj" + top: "inception_3a/pool_proj" + name: "inception_3a/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_3a/1x1" + bottom: "inception_3a/3x3" + bottom: "inception_3a/5x5" + bottom: "inception_3a/pool_proj" + top: "inception_3a/output" + name: "inception_3a/output" + type: CONCAT +} +layers { + bottom: "inception_3a/output" + top: "inception_3b/1x1" + name: "inception_3b/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3b/1x1" + top: "inception_3b/1x1" + name: "inception_3b/relu_1x1" + type: RELU +} +layers { + bottom: "inception_3a/output" + top: "inception_3b/3x3_reduce" + name: "inception_3b/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + 
value: 0.2 + } + } +} +layers { + bottom: "inception_3b/3x3_reduce" + top: "inception_3b/3x3_reduce" + name: "inception_3b/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_3b/3x3_reduce" + top: "inception_3b/3x3" + name: "inception_3b/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 192 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3b/3x3" + top: "inception_3b/3x3" + name: "inception_3b/relu_3x3" + type: RELU +} +layers { + bottom: "inception_3a/output" + top: "inception_3b/5x5_reduce" + name: "inception_3b/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 32 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3b/5x5_reduce" + top: "inception_3b/5x5_reduce" + name: "inception_3b/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_3b/5x5_reduce" + top: "inception_3b/5x5" + name: "inception_3b/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 96 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3b/5x5" + top: "inception_3b/5x5" + name: "inception_3b/relu_5x5" + type: RELU +} +layers { + bottom: "inception_3a/output" + top: "inception_3b/pool" + name: "inception_3b/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_3b/pool" + top: "inception_3b/pool_proj" + name: "inception_3b/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + 
num_output: 64 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3b/pool_proj" + top: "inception_3b/pool_proj" + name: "inception_3b/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_3b/1x1" + bottom: "inception_3b/3x3" + bottom: "inception_3b/5x5" + bottom: "inception_3b/pool_proj" + top: "inception_3b/output" + name: "inception_3b/output" + type: CONCAT +} +layers { + bottom: "inception_3b/output" + top: "pool3/3x3_s2" + name: "pool3/3x3_s2" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + bottom: "pool3/3x3_s2" + top: "inception_4a/1x1" + name: "inception_4a/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 192 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4a/1x1" + top: "inception_4a/1x1" + name: "inception_4a/relu_1x1" + type: RELU +} +layers { + bottom: "pool3/3x3_s2" + top: "inception_4a/3x3_reduce" + name: "inception_4a/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 96 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4a/3x3_reduce" + top: "inception_4a/3x3_reduce" + name: "inception_4a/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_4a/3x3_reduce" + top: "inception_4a/3x3" + name: "inception_4a/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 208 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4a/3x3" + top: 
"inception_4a/3x3" + name: "inception_4a/relu_3x3" + type: RELU +} +layers { + bottom: "pool3/3x3_s2" + top: "inception_4a/5x5_reduce" + name: "inception_4a/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 16 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4a/5x5_reduce" + top: "inception_4a/5x5_reduce" + name: "inception_4a/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_4a/5x5_reduce" + top: "inception_4a/5x5" + name: "inception_4a/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 48 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4a/5x5" + top: "inception_4a/5x5" + name: "inception_4a/relu_5x5" + type: RELU +} +layers { + bottom: "pool3/3x3_s2" + top: "inception_4a/pool" + name: "inception_4a/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_4a/pool" + top: "inception_4a/pool_proj" + name: "inception_4a/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4a/pool_proj" + top: "inception_4a/pool_proj" + name: "inception_4a/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_4a/1x1" + bottom: "inception_4a/3x3" + bottom: "inception_4a/5x5" + bottom: "inception_4a/pool_proj" + top: "inception_4a/output" + name: "inception_4a/output" + type: CONCAT +} +layers { + bottom: "inception_4a/output" + top: "inception_4b/1x1" + name: "inception_4b/1x1" + 
type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 160 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4b/1x1" + top: "inception_4b/1x1" + name: "inception_4b/relu_1x1" + type: RELU +} +layers { + bottom: "inception_4a/output" + top: "inception_4b/3x3_reduce" + name: "inception_4b/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 112 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4b/3x3_reduce" + top: "inception_4b/3x3_reduce" + name: "inception_4b/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_4b/3x3_reduce" + top: "inception_4b/3x3" + name: "inception_4b/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 224 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4b/3x3" + top: "inception_4b/3x3" + name: "inception_4b/relu_3x3" + type: RELU +} +layers { + bottom: "inception_4a/output" + top: "inception_4b/5x5_reduce" + name: "inception_4b/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 24 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4b/5x5_reduce" + top: "inception_4b/5x5_reduce" + name: "inception_4b/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_4b/5x5_reduce" + top: "inception_4b/5x5" + name: "inception_4b/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + 
weight_decay: 0 + convolution_param { + num_output: 64 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4b/5x5" + top: "inception_4b/5x5" + name: "inception_4b/relu_5x5" + type: RELU +} +layers { + bottom: "inception_4a/output" + top: "inception_4b/pool" + name: "inception_4b/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_4b/pool" + top: "inception_4b/pool_proj" + name: "inception_4b/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4b/pool_proj" + top: "inception_4b/pool_proj" + name: "inception_4b/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_4b/1x1" + bottom: "inception_4b/3x3" + bottom: "inception_4b/5x5" + bottom: "inception_4b/pool_proj" + top: "inception_4b/output" + name: "inception_4b/output" + type: CONCAT +} +layers { + bottom: "inception_4b/output" + top: "inception_4c/1x1" + name: "inception_4c/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4c/1x1" + top: "inception_4c/1x1" + name: "inception_4c/relu_1x1" + type: RELU +} +layers { + bottom: "inception_4b/output" + top: "inception_4c/3x3_reduce" + name: "inception_4c/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + 
} +} +layers { + bottom: "inception_4c/3x3_reduce" + top: "inception_4c/3x3_reduce" + name: "inception_4c/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_4c/3x3_reduce" + top: "inception_4c/3x3" + name: "inception_4c/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4c/3x3" + top: "inception_4c/3x3" + name: "inception_4c/relu_3x3" + type: RELU +} +layers { + bottom: "inception_4b/output" + top: "inception_4c/5x5_reduce" + name: "inception_4c/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 24 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4c/5x5_reduce" + top: "inception_4c/5x5_reduce" + name: "inception_4c/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_4c/5x5_reduce" + top: "inception_4c/5x5" + name: "inception_4c/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4c/5x5" + top: "inception_4c/5x5" + name: "inception_4c/relu_5x5" + type: RELU +} +layers { + bottom: "inception_4b/output" + top: "inception_4c/pool" + name: "inception_4c/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_4c/pool" + top: "inception_4c/pool_proj" + name: "inception_4c/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + 
kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4c/pool_proj" + top: "inception_4c/pool_proj" + name: "inception_4c/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_4c/1x1" + bottom: "inception_4c/3x3" + bottom: "inception_4c/5x5" + bottom: "inception_4c/pool_proj" + top: "inception_4c/output" + name: "inception_4c/output" + type: CONCAT +} +layers { + bottom: "inception_4c/output" + top: "inception_4d/1x1" + name: "inception_4d/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 112 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4d/1x1" + top: "inception_4d/1x1" + name: "inception_4d/relu_1x1" + type: RELU +} +layers { + bottom: "inception_4c/output" + top: "inception_4d/3x3_reduce" + name: "inception_4d/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 144 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4d/3x3_reduce" + top: "inception_4d/3x3_reduce" + name: "inception_4d/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_4d/3x3_reduce" + top: "inception_4d/3x3" + name: "inception_4d/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 288 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4d/3x3" + top: "inception_4d/3x3" + name: "inception_4d/relu_3x3" + type: RELU +} +layers { + bottom: "inception_4c/output" + top: "inception_4d/5x5_reduce" + name: "inception_4d/5x5_reduce" + 
type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 32 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4d/5x5_reduce" + top: "inception_4d/5x5_reduce" + name: "inception_4d/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_4d/5x5_reduce" + top: "inception_4d/5x5" + name: "inception_4d/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4d/5x5" + top: "inception_4d/5x5" + name: "inception_4d/relu_5x5" + type: RELU +} +layers { + bottom: "inception_4c/output" + top: "inception_4d/pool" + name: "inception_4d/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_4d/pool" + top: "inception_4d/pool_proj" + name: "inception_4d/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4d/pool_proj" + top: "inception_4d/pool_proj" + name: "inception_4d/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_4d/1x1" + bottom: "inception_4d/3x3" + bottom: "inception_4d/5x5" + bottom: "inception_4d/pool_proj" + top: "inception_4d/output" + name: "inception_4d/output" + type: CONCAT +} +layers { + bottom: "inception_4d/output" + top: "inception_4e/1x1" + name: "inception_4e/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + kernel_size: 1 + weight_filler { + type: 
"xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4e/1x1" + top: "inception_4e/1x1" + name: "inception_4e/relu_1x1" + type: RELU +} +layers { + bottom: "inception_4d/output" + top: "inception_4e/3x3_reduce" + name: "inception_4e/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 160 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4e/3x3_reduce" + top: "inception_4e/3x3_reduce" + name: "inception_4e/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_4e/3x3_reduce" + top: "inception_4e/3x3" + name: "inception_4e/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 320 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4e/3x3" + top: "inception_4e/3x3" + name: "inception_4e/relu_3x3" + type: RELU +} +layers { + bottom: "inception_4d/output" + top: "inception_4e/5x5_reduce" + name: "inception_4e/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 32 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4e/5x5_reduce" + top: "inception_4e/5x5_reduce" + name: "inception_4e/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_4e/5x5_reduce" + top: "inception_4e/5x5" + name: "inception_4e/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 
0.2 + } + } +} +layers { + bottom: "inception_4e/5x5" + top: "inception_4e/5x5" + name: "inception_4e/relu_5x5" + type: RELU +} +layers { + bottom: "inception_4d/output" + top: "inception_4e/pool" + name: "inception_4e/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_4e/pool" + top: "inception_4e/pool_proj" + name: "inception_4e/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4e/pool_proj" + top: "inception_4e/pool_proj" + name: "inception_4e/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_4e/1x1" + bottom: "inception_4e/3x3" + bottom: "inception_4e/5x5" + bottom: "inception_4e/pool_proj" + top: "inception_4e/output" + name: "inception_4e/output" + type: CONCAT +} +layers { + bottom: "inception_4e/output" + top: "pool4/3x3_s2" + name: "pool4/3x3_s2" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + bottom: "pool4/3x3_s2" + top: "inception_5a/1x1" + name: "inception_5a/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5a/1x1" + top: "inception_5a/1x1" + name: "inception_5a/relu_1x1" + type: RELU +} +layers { + bottom: "pool4/3x3_s2" + top: "inception_5a/3x3_reduce" + name: "inception_5a/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 160 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + 
bottom: "inception_5a/3x3_reduce" + top: "inception_5a/3x3_reduce" + name: "inception_5a/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_5a/3x3_reduce" + top: "inception_5a/3x3" + name: "inception_5a/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 320 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5a/3x3" + top: "inception_5a/3x3" + name: "inception_5a/relu_3x3" + type: RELU +} +layers { + bottom: "pool4/3x3_s2" + top: "inception_5a/5x5_reduce" + name: "inception_5a/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 32 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5a/5x5_reduce" + top: "inception_5a/5x5_reduce" + name: "inception_5a/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_5a/5x5_reduce" + top: "inception_5a/5x5" + name: "inception_5a/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5a/5x5" + top: "inception_5a/5x5" + name: "inception_5a/relu_5x5" + type: RELU +} +layers { + bottom: "pool4/3x3_s2" + top: "inception_5a/pool" + name: "inception_5a/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_5a/pool" + top: "inception_5a/pool_proj" + name: "inception_5a/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + 
type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5a/pool_proj" + top: "inception_5a/pool_proj" + name: "inception_5a/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_5a/1x1" + bottom: "inception_5a/3x3" + bottom: "inception_5a/5x5" + bottom: "inception_5a/pool_proj" + top: "inception_5a/output" + name: "inception_5a/output" + type: CONCAT +} +layers { + bottom: "inception_5a/output" + top: "inception_5b/1x1" + name: "inception_5b/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5b/1x1" + top: "inception_5b/1x1" + name: "inception_5b/relu_1x1" + type: RELU +} +layers { + bottom: "inception_5a/output" + top: "inception_5b/3x3_reduce" + name: "inception_5b/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 192 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5b/3x3_reduce" + top: "inception_5b/3x3_reduce" + name: "inception_5b/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_5b/3x3_reduce" + top: "inception_5b/3x3" + name: "inception_5b/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5b/3x3" + top: "inception_5b/3x3" + name: "inception_5b/relu_3x3" + type: RELU +} +layers { + bottom: "inception_5a/output" + top: "inception_5b/5x5_reduce" + name: "inception_5b/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + 
blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 48 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5b/5x5_reduce" + top: "inception_5b/5x5_reduce" + name: "inception_5b/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_5b/5x5_reduce" + top: "inception_5b/5x5" + name: "inception_5b/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5b/5x5" + top: "inception_5b/5x5" + name: "inception_5b/relu_5x5" + type: RELU +} +layers { + bottom: "inception_5a/output" + top: "inception_5b/pool" + name: "inception_5b/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_5b/pool" + top: "inception_5b/pool_proj" + name: "inception_5b/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5b/pool_proj" + top: "inception_5b/pool_proj" + name: "inception_5b/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_5b/1x1" + bottom: "inception_5b/3x3" + bottom: "inception_5b/5x5" + bottom: "inception_5b/pool_proj" + top: "inception_5b/output" + name: "inception_5b/output" + type: CONCAT +} +layers { + bottom: "inception_5b/output" + top: "pool5/7x7_s1" + name: "pool5/7x7_s1" + type: POOLING + pooling_param { + pool: AVE + kernel_size: 7 + stride: 1 + } +} +layers { + bottom: "pool5/7x7_s1" + top: "pool5/7x7_s1" + name: "pool5/drop_7x7_s1" + type: DROPOUT + dropout_param { + 
dropout_ratio: 0.4 + } +} +layers { + bottom: "pool5/7x7_s1" + top: "loss3/classifier" + name: "loss3/classifier" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 1000 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "prob" + type: SOFTMAX + bottom: "loss3/classifier" + top: "prob" +} \ No newline at end of file diff --git a/caffe-crfrnn/models/bvlc_googlenet/quick_solver.prototxt b/caffe-crfrnn/models/bvlc_googlenet/quick_solver.prototxt new file mode 100644 index 00000000..5d2f7ee7 --- /dev/null +++ b/caffe-crfrnn/models/bvlc_googlenet/quick_solver.prototxt @@ -0,0 +1,15 @@ +net: "models/bvlc_googlenet/train_val.prototxt" +test_iter: 1000 +test_interval: 4000 +test_initialization: false +display: 40 +average_loss: 40 +base_lr: 0.01 +lr_policy: "poly" +power: 0.5 +max_iter: 2400000 +momentum: 0.9 +weight_decay: 0.0002 +snapshot: 40000 +snapshot_prefix: "models/bvlc_googlenet/bvlc_googlenet_quick" +solver_mode: GPU diff --git a/caffe-crfrnn/models/bvlc_googlenet/readme.md b/caffe-crfrnn/models/bvlc_googlenet/readme.md new file mode 100644 index 00000000..8a3bbec4 --- /dev/null +++ b/caffe-crfrnn/models/bvlc_googlenet/readme.md @@ -0,0 +1,35 @@ +--- +name: BVLC GoogleNet Model +caffemodel: bvlc_googlenet.caffemodel +caffemodel_url: http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel +license: non-commercial +sha1: 405fc5acd08a3bb12de8ee5e23a96bec22f08204 +caffe_commit: bc614d1bd91896e3faceaf40b23b72dab47d44f5 +gist_id: 866e2aa1fd707b89b913 +--- + +This model is a replication of the model described in the [GoogleNet](http://arxiv.org/abs/1409.4842) publication. We would like to thank Christian Szegedy for all his help in the replication of GoogleNet model. 
+ +Differences: +- not training with the relighting data-augmentation; +- not training with the scale or aspect-ratio data-augmentation; +- uses "xavier" to initialize the weights instead of "gaussian"; +- quick_solver.prototxt uses a different learning rate decay policy than the original solver.prototxt, which allows much faster training (60 epochs vs 250 epochs). + +The bundled model is the iteration 2,400,000 snapshot (60 epochs) trained using quick_solver.prototxt. + +This bundled model obtains a top-1 accuracy of 68.7% (31.3% error) and a top-5 accuracy of 88.9% (11.1% error) on the validation set, using just the center crop. +(Using the average of 10 crops, (4 corners + 1 center) * 2 mirrors, should obtain slightly higher accuracy.) + +Timings for bvlc_googlenet with cuDNN using batch_size:128 on a K40c: + - Average Forward pass: 562.841 ms. + - Average Backward pass: 1123.84 ms. + - Average Forward-Backward: 1688.8 ms. + +This model was trained by Sergio Guadarrama @sguada + +## License + +The data used to train this model comes from the ImageNet project, which distributes its database to researchers who agree to the following term of access: +"Researcher shall use the Database only for non-commercial research and educational purposes." +Accordingly, this model is distributed under a non-commercial license. 
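The readme's "60 epochs vs 250 epochs" comparison can be checked against the two solver files in this diff. The sketch below is not part of the commit. It assumes Caffe's documented formulas for the `poly` and `step` learning rate policies, and it assumes an ILSVRC12 training set of 1,281,167 images, a figure that does not appear anywhere in the diff. The function names are hypothetical.

```python
# Solver values copied from quick_solver.prototxt and solver.prototxt in this diff.
BASE_LR = 0.01
QUICK_MAX_ITER = 2_400_000   # quick_solver.prototxt: max_iter
QUICK_POWER = 0.5            # quick_solver.prototxt: power
STEP_SIZE = 320_000          # solver.prototxt: stepsize
STEP_GAMMA = 0.96            # solver.prototxt: gamma
STEP_MAX_ITER = 10_000_000   # solver.prototxt: max_iter
TRAIN_BATCH_SIZE = 32        # train_val.prototxt: TRAIN-phase batch_size

# Assumption: ILSVRC12 training set size; not stated in the diff.
TRAIN_IMAGES = 1_281_167


def poly_lr(iteration):
    """lr_policy "poly": base_lr * (1 - iter / max_iter) ^ power."""
    return BASE_LR * (1.0 - iteration / QUICK_MAX_ITER) ** QUICK_POWER


def step_lr(iteration):
    """lr_policy "step": base_lr * gamma ^ floor(iter / stepsize)."""
    return BASE_LR * STEP_GAMMA ** (iteration // STEP_SIZE)


def epochs(iterations):
    """Passes over the training set after the given number of iterations."""
    return iterations * TRAIN_BATCH_SIZE / TRAIN_IMAGES


if __name__ == "__main__":
    print("quick solver epochs:", round(epochs(QUICK_MAX_ITER)))
    print("original solver epochs:", round(epochs(STEP_MAX_ITER)))
    print("poly lr at 75% of training:", poly_lr(1_800_000))
    print("step lr after one stepsize:", step_lr(STEP_SIZE))
```

Under these assumptions the iteration counts round to the 60 and 250 epochs quoted in the readme. The `poly` schedule reaches zero exactly at `max_iter`, whereas the `step` schedule only shrinks by 4% every 320,000 iterations, which is why the original solver needs a much longer run.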
diff --git a/caffe-crfrnn/models/bvlc_googlenet/solver.prototxt b/caffe-crfrnn/models/bvlc_googlenet/solver.prototxt new file mode 100644 index 00000000..d7d17881 --- /dev/null +++ b/caffe-crfrnn/models/bvlc_googlenet/solver.prototxt @@ -0,0 +1,16 @@ +net: "models/bvlc_googlenet/train_val.prototxt" +test_iter: 1000 +test_interval: 4000 +test_initialization: false +display: 40 +average_loss: 40 +base_lr: 0.01 +lr_policy: "step" +stepsize: 320000 +gamma: 0.96 +max_iter: 10000000 +momentum: 0.9 +weight_decay: 0.0002 +snapshot: 40000 +snapshot_prefix: "models/bvlc_googlenet/bvlc_googlenet" +solver_mode: GPU diff --git a/caffe-crfrnn/models/bvlc_googlenet/train_val.prototxt b/caffe-crfrnn/models/bvlc_googlenet/train_val.prototxt new file mode 100644 index 00000000..cd8f38ab --- /dev/null +++ b/caffe-crfrnn/models/bvlc_googlenet/train_val.prototxt @@ -0,0 +1,2240 @@ +name: "GoogleNet" +layers { + top: "data" + top: "label" + name: "data" + type: DATA + data_param { + source: "examples/imagenet/ilsvrc12_train_lmdb" + batch_size: 32 + backend: LMDB + } + include { + phase: TRAIN + } + transform_param { + mirror: true + crop_size: 224 + mean_value: 104 + mean_value: 117 + mean_value: 123 + } +} +layers { + top: "data" + top: "label" + name: "data" + type: DATA + data_param { + source: "examples/imagenet/ilsvrc12_val_lmdb" + batch_size: 50 + backend: LMDB + } + include { + phase: TEST + } + transform_param { + mirror: false + crop_size: 224 + mean_value: 104 + mean_value: 117 + mean_value: 123 + } +} +layers { + bottom: "data" + top: "conv1/7x7_s2" + name: "conv1/7x7_s2" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + pad: 3 + kernel_size: 7 + stride: 2 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "conv1/7x7_s2" + top: "conv1/7x7_s2" + name: "conv1/relu_7x7" + type: RELU +} +layers { + bottom: "conv1/7x7_s2" + top: 
"pool1/3x3_s2" + name: "pool1/3x3_s2" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + bottom: "pool1/3x3_s2" + top: "pool1/norm1" + name: "pool1/norm1" + type: LRN + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + bottom: "pool1/norm1" + top: "conv2/3x3_reduce" + name: "conv2/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "conv2/3x3_reduce" + top: "conv2/3x3_reduce" + name: "conv2/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "conv2/3x3_reduce" + top: "conv2/3x3" + name: "conv2/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 192 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "conv2/3x3" + top: "conv2/3x3" + name: "conv2/relu_3x3" + type: RELU +} +layers { + bottom: "conv2/3x3" + top: "conv2/norm2" + name: "conv2/norm2" + type: LRN + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + bottom: "conv2/norm2" + top: "pool2/3x3_s2" + name: "pool2/3x3_s2" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + bottom: "pool2/3x3_s2" + top: "inception_3a/1x1" + name: "inception_3a/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3a/1x1" + top: "inception_3a/1x1" + name: "inception_3a/relu_1x1" + type: RELU +} +layers { + bottom: "pool2/3x3_s2" + top: "inception_3a/3x3_reduce" + name: 
"inception_3a/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 96 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3a/3x3_reduce" + top: "inception_3a/3x3_reduce" + name: "inception_3a/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_3a/3x3_reduce" + top: "inception_3a/3x3" + name: "inception_3a/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3a/3x3" + top: "inception_3a/3x3" + name: "inception_3a/relu_3x3" + type: RELU +} +layers { + bottom: "pool2/3x3_s2" + top: "inception_3a/5x5_reduce" + name: "inception_3a/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 16 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3a/5x5_reduce" + top: "inception_3a/5x5_reduce" + name: "inception_3a/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_3a/5x5_reduce" + top: "inception_3a/5x5" + name: "inception_3a/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 32 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3a/5x5" + top: "inception_3a/5x5" + name: "inception_3a/relu_5x5" + type: RELU +} +layers { + bottom: "pool2/3x3_s2" + top: "inception_3a/pool" + name: "inception_3a/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + 
stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_3a/pool" + top: "inception_3a/pool_proj" + name: "inception_3a/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 32 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3a/pool_proj" + top: "inception_3a/pool_proj" + name: "inception_3a/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_3a/1x1" + bottom: "inception_3a/3x3" + bottom: "inception_3a/5x5" + bottom: "inception_3a/pool_proj" + top: "inception_3a/output" + name: "inception_3a/output" + type: CONCAT +} +layers { + bottom: "inception_3a/output" + top: "inception_3b/1x1" + name: "inception_3b/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3b/1x1" + top: "inception_3b/1x1" + name: "inception_3b/relu_1x1" + type: RELU +} +layers { + bottom: "inception_3a/output" + top: "inception_3b/3x3_reduce" + name: "inception_3b/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3b/3x3_reduce" + top: "inception_3b/3x3_reduce" + name: "inception_3b/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_3b/3x3_reduce" + top: "inception_3b/3x3" + name: "inception_3b/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 192 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: 
"constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3b/3x3" + top: "inception_3b/3x3" + name: "inception_3b/relu_3x3" + type: RELU +} +layers { + bottom: "inception_3a/output" + top: "inception_3b/5x5_reduce" + name: "inception_3b/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 32 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3b/5x5_reduce" + top: "inception_3b/5x5_reduce" + name: "inception_3b/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_3b/5x5_reduce" + top: "inception_3b/5x5" + name: "inception_3b/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 96 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3b/5x5" + top: "inception_3b/5x5" + name: "inception_3b/relu_5x5" + type: RELU +} +layers { + bottom: "inception_3a/output" + top: "inception_3b/pool" + name: "inception_3b/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_3b/pool" + top: "inception_3b/pool_proj" + name: "inception_3b/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_3b/pool_proj" + top: "inception_3b/pool_proj" + name: "inception_3b/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_3b/1x1" + bottom: "inception_3b/3x3" + bottom: "inception_3b/5x5" + bottom: "inception_3b/pool_proj" + top: "inception_3b/output" + name: "inception_3b/output" + type: CONCAT +} 
+layers { + bottom: "inception_3b/output" + top: "pool3/3x3_s2" + name: "pool3/3x3_s2" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + bottom: "pool3/3x3_s2" + top: "inception_4a/1x1" + name: "inception_4a/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 192 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4a/1x1" + top: "inception_4a/1x1" + name: "inception_4a/relu_1x1" + type: RELU +} +layers { + bottom: "pool3/3x3_s2" + top: "inception_4a/3x3_reduce" + name: "inception_4a/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 96 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4a/3x3_reduce" + top: "inception_4a/3x3_reduce" + name: "inception_4a/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_4a/3x3_reduce" + top: "inception_4a/3x3" + name: "inception_4a/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 208 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4a/3x3" + top: "inception_4a/3x3" + name: "inception_4a/relu_3x3" + type: RELU +} +layers { + bottom: "pool3/3x3_s2" + top: "inception_4a/5x5_reduce" + name: "inception_4a/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 16 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4a/5x5_reduce" + top: 
"inception_4a/5x5_reduce" + name: "inception_4a/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_4a/5x5_reduce" + top: "inception_4a/5x5" + name: "inception_4a/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 48 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4a/5x5" + top: "inception_4a/5x5" + name: "inception_4a/relu_5x5" + type: RELU +} +layers { + bottom: "pool3/3x3_s2" + top: "inception_4a/pool" + name: "inception_4a/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_4a/pool" + top: "inception_4a/pool_proj" + name: "inception_4a/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4a/pool_proj" + top: "inception_4a/pool_proj" + name: "inception_4a/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_4a/1x1" + bottom: "inception_4a/3x3" + bottom: "inception_4a/5x5" + bottom: "inception_4a/pool_proj" + top: "inception_4a/output" + name: "inception_4a/output" + type: CONCAT +} +layers { + bottom: "inception_4a/output" + top: "loss1/ave_pool" + name: "loss1/ave_pool" + type: POOLING + pooling_param { + pool: AVE + kernel_size: 5 + stride: 3 + } +} +layers { + bottom: "loss1/ave_pool" + top: "loss1/conv" + name: "loss1/conv" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.08 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "loss1/conv" + top: "loss1/conv" + name: 
"loss1/relu_conv" + type: RELU +} +layers { + bottom: "loss1/conv" + top: "loss1/fc" + name: "loss1/fc" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 1024 + weight_filler { + type: "xavier" + std: 0.02 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "loss1/fc" + top: "loss1/fc" + name: "loss1/relu_fc" + type: RELU +} +layers { + bottom: "loss1/fc" + top: "loss1/fc" + name: "loss1/drop_fc" + type: DROPOUT + dropout_param { + dropout_ratio: 0.7 + } +} +layers { + bottom: "loss1/fc" + top: "loss1/classifier" + name: "loss1/classifier" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 1000 + weight_filler { + type: "xavier" + std: 0.0009765625 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + bottom: "loss1/classifier" + bottom: "label" + top: "loss1/loss1" + name: "loss1/loss" + type: SOFTMAX_LOSS + loss_weight: 0.3 +} +layers { + bottom: "loss1/classifier" + bottom: "label" + top: "loss1/top-1" + name: "loss1/top-1" + type: ACCURACY + include { + phase: TEST + } +} +layers { + bottom: "loss1/classifier" + bottom: "label" + top: "loss1/top-5" + name: "loss1/top-5" + type: ACCURACY + accuracy_param { + top_k: 5 + } + include { + phase: TEST + } +} +layers { + bottom: "inception_4a/output" + top: "inception_4b/1x1" + name: "inception_4b/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 160 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4b/1x1" + top: "inception_4b/1x1" + name: "inception_4b/relu_1x1" + type: RELU +} +layers { + bottom: "inception_4a/output" + top: "inception_4b/3x3_reduce" + name: "inception_4b/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + 
weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 112 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4b/3x3_reduce" + top: "inception_4b/3x3_reduce" + name: "inception_4b/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_4b/3x3_reduce" + top: "inception_4b/3x3" + name: "inception_4b/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 224 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4b/3x3" + top: "inception_4b/3x3" + name: "inception_4b/relu_3x3" + type: RELU +} +layers { + bottom: "inception_4a/output" + top: "inception_4b/5x5_reduce" + name: "inception_4b/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 24 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4b/5x5_reduce" + top: "inception_4b/5x5_reduce" + name: "inception_4b/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_4b/5x5_reduce" + top: "inception_4b/5x5" + name: "inception_4b/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4b/5x5" + top: "inception_4b/5x5" + name: "inception_4b/relu_5x5" + type: RELU +} +layers { + bottom: "inception_4a/output" + top: "inception_4b/pool" + name: "inception_4b/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_4b/pool" + 
top: "inception_4b/pool_proj" + name: "inception_4b/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4b/pool_proj" + top: "inception_4b/pool_proj" + name: "inception_4b/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_4b/1x1" + bottom: "inception_4b/3x3" + bottom: "inception_4b/5x5" + bottom: "inception_4b/pool_proj" + top: "inception_4b/output" + name: "inception_4b/output" + type: CONCAT +} +layers { + bottom: "inception_4b/output" + top: "inception_4c/1x1" + name: "inception_4c/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4c/1x1" + top: "inception_4c/1x1" + name: "inception_4c/relu_1x1" + type: RELU +} +layers { + bottom: "inception_4b/output" + top: "inception_4c/3x3_reduce" + name: "inception_4c/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4c/3x3_reduce" + top: "inception_4c/3x3_reduce" + name: "inception_4c/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_4c/3x3_reduce" + top: "inception_4c/3x3" + name: "inception_4c/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: 
"inception_4c/3x3" + top: "inception_4c/3x3" + name: "inception_4c/relu_3x3" + type: RELU +} +layers { + bottom: "inception_4b/output" + top: "inception_4c/5x5_reduce" + name: "inception_4c/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 24 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4c/5x5_reduce" + top: "inception_4c/5x5_reduce" + name: "inception_4c/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_4c/5x5_reduce" + top: "inception_4c/5x5" + name: "inception_4c/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4c/5x5" + top: "inception_4c/5x5" + name: "inception_4c/relu_5x5" + type: RELU +} +layers { + bottom: "inception_4b/output" + top: "inception_4c/pool" + name: "inception_4c/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_4c/pool" + top: "inception_4c/pool_proj" + name: "inception_4c/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4c/pool_proj" + top: "inception_4c/pool_proj" + name: "inception_4c/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_4c/1x1" + bottom: "inception_4c/3x3" + bottom: "inception_4c/5x5" + bottom: "inception_4c/pool_proj" + top: "inception_4c/output" + name: "inception_4c/output" + type: CONCAT +} +layers { + bottom: "inception_4c/output" + top: 
"inception_4d/1x1" + name: "inception_4d/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 112 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4d/1x1" + top: "inception_4d/1x1" + name: "inception_4d/relu_1x1" + type: RELU +} +layers { + bottom: "inception_4c/output" + top: "inception_4d/3x3_reduce" + name: "inception_4d/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 144 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4d/3x3_reduce" + top: "inception_4d/3x3_reduce" + name: "inception_4d/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_4d/3x3_reduce" + top: "inception_4d/3x3" + name: "inception_4d/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 288 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4d/3x3" + top: "inception_4d/3x3" + name: "inception_4d/relu_3x3" + type: RELU +} +layers { + bottom: "inception_4c/output" + top: "inception_4d/5x5_reduce" + name: "inception_4d/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 32 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4d/5x5_reduce" + top: "inception_4d/5x5_reduce" + name: "inception_4d/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_4d/5x5_reduce" + top: "inception_4d/5x5" + name: "inception_4d/5x5" + type: CONVOLUTION + blobs_lr: 
1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4d/5x5" + top: "inception_4d/5x5" + name: "inception_4d/relu_5x5" + type: RELU +} +layers { + bottom: "inception_4c/output" + top: "inception_4d/pool" + name: "inception_4d/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_4d/pool" + top: "inception_4d/pool_proj" + name: "inception_4d/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4d/pool_proj" + top: "inception_4d/pool_proj" + name: "inception_4d/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_4d/1x1" + bottom: "inception_4d/3x3" + bottom: "inception_4d/5x5" + bottom: "inception_4d/pool_proj" + top: "inception_4d/output" + name: "inception_4d/output" + type: CONCAT +} +layers { + bottom: "inception_4d/output" + top: "loss2/ave_pool" + name: "loss2/ave_pool" + type: POOLING + pooling_param { + pool: AVE + kernel_size: 5 + stride: 3 + } +} +layers { + bottom: "loss2/ave_pool" + top: "loss2/conv" + name: "loss2/conv" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.08 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "loss2/conv" + top: "loss2/conv" + name: "loss2/relu_conv" + type: RELU +} +layers { + bottom: "loss2/conv" + top: "loss2/fc" + name: "loss2/fc" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { 
+ num_output: 1024 + weight_filler { + type: "xavier" + std: 0.02 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "loss2/fc" + top: "loss2/fc" + name: "loss2/relu_fc" + type: RELU +} +layers { + bottom: "loss2/fc" + top: "loss2/fc" + name: "loss2/drop_fc" + type: DROPOUT + dropout_param { + dropout_ratio: 0.7 + } +} +layers { + bottom: "loss2/fc" + top: "loss2/classifier" + name: "loss2/classifier" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 1000 + weight_filler { + type: "xavier" + std: 0.0009765625 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + bottom: "loss2/classifier" + bottom: "label" + top: "loss2/loss1" + name: "loss2/loss" + type: SOFTMAX_LOSS + loss_weight: 0.3 +} +layers { + bottom: "loss2/classifier" + bottom: "label" + top: "loss2/top-1" + name: "loss2/top-1" + type: ACCURACY + include { + phase: TEST + } +} +layers { + bottom: "loss2/classifier" + bottom: "label" + top: "loss2/top-5" + name: "loss2/top-5" + type: ACCURACY + accuracy_param { + top_k: 5 + } + include { + phase: TEST + } +} +layers { + bottom: "inception_4d/output" + top: "inception_4e/1x1" + name: "inception_4e/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4e/1x1" + top: "inception_4e/1x1" + name: "inception_4e/relu_1x1" + type: RELU +} +layers { + bottom: "inception_4d/output" + top: "inception_4e/3x3_reduce" + name: "inception_4e/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 160 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: 
"inception_4e/3x3_reduce" + top: "inception_4e/3x3_reduce" + name: "inception_4e/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_4e/3x3_reduce" + top: "inception_4e/3x3" + name: "inception_4e/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 320 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4e/3x3" + top: "inception_4e/3x3" + name: "inception_4e/relu_3x3" + type: RELU +} +layers { + bottom: "inception_4d/output" + top: "inception_4e/5x5_reduce" + name: "inception_4e/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 32 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4e/5x5_reduce" + top: "inception_4e/5x5_reduce" + name: "inception_4e/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_4e/5x5_reduce" + top: "inception_4e/5x5" + name: "inception_4e/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4e/5x5" + top: "inception_4e/5x5" + name: "inception_4e/relu_5x5" + type: RELU +} +layers { + bottom: "inception_4d/output" + top: "inception_4e/pool" + name: "inception_4e/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_4e/pool" + top: "inception_4e/pool_proj" + name: "inception_4e/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler 
{ + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_4e/pool_proj" + top: "inception_4e/pool_proj" + name: "inception_4e/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_4e/1x1" + bottom: "inception_4e/3x3" + bottom: "inception_4e/5x5" + bottom: "inception_4e/pool_proj" + top: "inception_4e/output" + name: "inception_4e/output" + type: CONCAT +} +layers { + bottom: "inception_4e/output" + top: "pool4/3x3_s2" + name: "pool4/3x3_s2" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + bottom: "pool4/3x3_s2" + top: "inception_5a/1x1" + name: "inception_5a/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5a/1x1" + top: "inception_5a/1x1" + name: "inception_5a/relu_1x1" + type: RELU +} +layers { + bottom: "pool4/3x3_s2" + top: "inception_5a/3x3_reduce" + name: "inception_5a/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 160 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5a/3x3_reduce" + top: "inception_5a/3x3_reduce" + name: "inception_5a/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_5a/3x3_reduce" + top: "inception_5a/3x3" + name: "inception_5a/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 320 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5a/3x3" + top: "inception_5a/3x3" + name: "inception_5a/relu_3x3" + 
type: RELU +} +layers { + bottom: "pool4/3x3_s2" + top: "inception_5a/5x5_reduce" + name: "inception_5a/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 32 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5a/5x5_reduce" + top: "inception_5a/5x5_reduce" + name: "inception_5a/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_5a/5x5_reduce" + top: "inception_5a/5x5" + name: "inception_5a/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5a/5x5" + top: "inception_5a/5x5" + name: "inception_5a/relu_5x5" + type: RELU +} +layers { + bottom: "pool4/3x3_s2" + top: "inception_5a/pool" + name: "inception_5a/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_5a/pool" + top: "inception_5a/pool_proj" + name: "inception_5a/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5a/pool_proj" + top: "inception_5a/pool_proj" + name: "inception_5a/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_5a/1x1" + bottom: "inception_5a/3x3" + bottom: "inception_5a/5x5" + bottom: "inception_5a/pool_proj" + top: "inception_5a/output" + name: "inception_5a/output" + type: CONCAT +} +layers { + bottom: "inception_5a/output" + top: "inception_5b/1x1" + name: "inception_5b/1x1" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + 
weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5b/1x1" + top: "inception_5b/1x1" + name: "inception_5b/relu_1x1" + type: RELU +} +layers { + bottom: "inception_5a/output" + top: "inception_5b/3x3_reduce" + name: "inception_5b/3x3_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 192 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.09 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5b/3x3_reduce" + top: "inception_5b/3x3_reduce" + name: "inception_5b/relu_3x3_reduce" + type: RELU +} +layers { + bottom: "inception_5b/3x3_reduce" + top: "inception_5b/3x3" + name: "inception_5b/3x3" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5b/3x3" + top: "inception_5b/3x3" + name: "inception_5b/relu_3x3" + type: RELU +} +layers { + bottom: "inception_5a/output" + top: "inception_5b/5x5_reduce" + name: "inception_5b/5x5_reduce" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 48 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.2 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5b/5x5_reduce" + top: "inception_5b/5x5_reduce" + name: "inception_5b/relu_5x5_reduce" + type: RELU +} +layers { + bottom: "inception_5b/5x5_reduce" + top: "inception_5b/5x5" + name: "inception_5b/5x5" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + 
pad: 2 + kernel_size: 5 + weight_filler { + type: "xavier" + std: 0.03 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5b/5x5" + top: "inception_5b/5x5" + name: "inception_5b/relu_5x5" + type: RELU +} +layers { + bottom: "inception_5a/output" + top: "inception_5b/pool" + name: "inception_5b/pool" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 3 + stride: 1 + pad: 1 + } +} +layers { + bottom: "inception_5b/pool" + top: "inception_5b/pool_proj" + name: "inception_5b/pool_proj" + type: CONVOLUTION + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 1 + weight_filler { + type: "xavier" + std: 0.1 + } + bias_filler { + type: "constant" + value: 0.2 + } + } +} +layers { + bottom: "inception_5b/pool_proj" + top: "inception_5b/pool_proj" + name: "inception_5b/relu_pool_proj" + type: RELU +} +layers { + bottom: "inception_5b/1x1" + bottom: "inception_5b/3x3" + bottom: "inception_5b/5x5" + bottom: "inception_5b/pool_proj" + top: "inception_5b/output" + name: "inception_5b/output" + type: CONCAT +} +layers { + bottom: "inception_5b/output" + top: "pool5/7x7_s1" + name: "pool5/7x7_s1" + type: POOLING + pooling_param { + pool: AVE + kernel_size: 7 + stride: 1 + } +} +layers { + bottom: "pool5/7x7_s1" + top: "pool5/7x7_s1" + name: "pool5/drop_7x7_s1" + type: DROPOUT + dropout_param { + dropout_ratio: 0.4 + } +} +layers { + bottom: "pool5/7x7_s1" + top: "loss3/classifier" + name: "loss3/classifier" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 1000 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + bottom: "loss3/classifier" + bottom: "label" + top: "loss3/loss3" + name: "loss3/loss3" + type: SOFTMAX_LOSS + loss_weight: 1 +} +layers { + bottom: "loss3/classifier" + bottom: "label" + top: "loss3/top-1" + name: "loss3/top-1" 
+ type: ACCURACY + include { + phase: TEST + } +} +layers { + bottom: "loss3/classifier" + bottom: "label" + top: "loss3/top-5" + name: "loss3/top-5" + type: ACCURACY + accuracy_param { + top_k: 5 + } + include { + phase: TEST + } +} diff --git a/caffe-crfrnn/models/bvlc_reference_caffenet/deploy.prototxt b/caffe-crfrnn/models/bvlc_reference_caffenet/deploy.prototxt new file mode 100644 index 00000000..4e494f42 --- /dev/null +++ b/caffe-crfrnn/models/bvlc_reference_caffenet/deploy.prototxt @@ -0,0 +1,212 @@ +name: "CaffeNet" +input: "data" +input_dim: 10 +input_dim: 3 +input_dim: 227 +input_dim: 227 +layers { + name: "conv1" + type: CONVOLUTION + bottom: "data" + top: "conv1" + convolution_param { + num_output: 96 + kernel_size: 11 + stride: 4 + } +} +layers { + name: "relu1" + type: RELU + bottom: "conv1" + top: "conv1" +} +layers { + name: "pool1" + type: POOLING + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm1" + type: LRN + bottom: "pool1" + top: "norm1" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv2" + type: CONVOLUTION + bottom: "norm1" + top: "conv2" + convolution_param { + num_output: 256 + pad: 2 + kernel_size: 5 + group: 2 + } +} +layers { + name: "relu2" + type: RELU + bottom: "conv2" + top: "conv2" +} +layers { + name: "pool2" + type: POOLING + bottom: "conv2" + top: "pool2" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm2" + type: LRN + bottom: "pool2" + top: "norm2" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv3" + type: CONVOLUTION + bottom: "norm2" + top: "conv3" + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + } +} +layers { + name: "relu3" + type: RELU + bottom: "conv3" + top: "conv3" +} +layers { + name: "conv4" + type: CONVOLUTION + bottom: "conv3" + top: "conv4" + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 
3 + group: 2 + } +} +layers { + name: "relu4" + type: RELU + bottom: "conv4" + top: "conv4" +} +layers { + name: "conv5" + type: CONVOLUTION + bottom: "conv4" + top: "conv5" + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + group: 2 + } +} +layers { + name: "relu5" + type: RELU + bottom: "conv5" + top: "conv5" +} +layers { + name: "pool5" + type: POOLING + bottom: "conv5" + top: "pool5" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "fc6" + type: INNER_PRODUCT + bottom: "pool5" + top: "fc6" + inner_product_param { + num_output: 4096 + } +} +layers { + name: "relu6" + type: RELU + bottom: "fc6" + top: "fc6" +} +layers { + name: "drop6" + type: DROPOUT + bottom: "fc6" + top: "fc6" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc7" + type: INNER_PRODUCT + bottom: "fc6" + top: "fc7" + inner_product_param { + num_output: 4096 + } +} +layers { + name: "relu7" + type: RELU + bottom: "fc7" + top: "fc7" +} +layers { + name: "drop7" + type: DROPOUT + bottom: "fc7" + top: "fc7" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc8" + type: INNER_PRODUCT + bottom: "fc7" + top: "fc8" + inner_product_param { + num_output: 1000 + } +} +layers { + name: "prob" + type: SOFTMAX + bottom: "fc8" + top: "prob" +} diff --git a/caffe-crfrnn/models/bvlc_reference_caffenet/readme.md b/caffe-crfrnn/models/bvlc_reference_caffenet/readme.md new file mode 100644 index 00000000..b867e738 --- /dev/null +++ b/caffe-crfrnn/models/bvlc_reference_caffenet/readme.md @@ -0,0 +1,27 @@ +--- +name: BVLC CaffeNet Model +caffemodel: bvlc_reference_caffenet.caffemodel +caffemodel_url: http://dl.caffe.berkeleyvision.org/bvlc_reference_caffenet.caffemodel +license: non-commercial +sha1: 4c8d77deb20ea792f84eb5e6d0a11ca0a8660a46 +caffe_commit: 709dc15af4a06bebda027c1eb2b3f3e3375d5077 +--- + +This model is the result of following the Caffe [ImageNet model training 
instructions](http://caffe.berkeleyvision.org/gathered/examples/imagenet.html). +It is a replication of the model described in the [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks) publication with some differences: + +- it is not trained with the relighting data augmentation; +- the order of pooling and normalization layers is switched (in CaffeNet, pooling is done before normalization). + +This model is a snapshot of iteration 310,000. +The best validation performance during training was at iteration 313,000, with a validation accuracy of 57.412% and a loss of 1.82328. +This model obtains a top-1 accuracy of 57.4% and a top-5 accuracy of 80.4% on the validation set, using just the center crop. +(Averaging over 10 crops, (4 corners + 1 center) * 2 mirrors, should yield slightly higher accuracy still.) + +This model was trained by Jeff Donahue @jeffdonahue + +## License + +The data used to train this model comes from the ImageNet project, which distributes its database to researchers who agree to the following term of access: +"Researcher shall use the Database only for non-commercial research and educational purposes." +Accordingly, this model is distributed under a non-commercial license. 
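The 10-crop evaluation mentioned in the CaffeNet readme above, (4 + 1 center) * 2 mirror, can be sketched in plain Python. This is only an illustration, not Caffe's own code: the function names `ten_crop_boxes` and `average_scores` are hypothetical, and the 256x256 input size is an assumption that the readme does not state. Only the 227 crop size comes from the prototxt files in this diff.

```python
def ten_crop_boxes(image_h, image_w, crop):
    """Return (top, left, mirrored) for 4 corners + center, with and without mirroring.

    Hypothetical helper for illustration; not part of Caffe.
    """
    if crop > image_h or crop > image_w:
        raise ValueError("crop must fit inside the image")
    max_top = image_h - crop
    max_left = image_w - crop
    offsets = [
        (0, 0),                         # top-left corner
        (0, max_left),                  # top-right corner
        (max_top, 0),                   # bottom-left corner
        (max_top, max_left),            # bottom-right corner
        (max_top // 2, max_left // 2),  # center
    ]
    # Each of the 5 positions is used twice: once as-is, once mirrored.
    return [(top, left, mirrored)
            for mirrored in (False, True)
            for top, left in offsets]


def average_scores(per_crop_scores):
    """Average class scores element-wise across crops."""
    n = len(per_crop_scores)
    return [sum(column) / n for column in zip(*per_crop_scores)]


# Assumed 256x256 input (not stated in the readme); 227 is crop_size in the prototxt.
boxes = ten_crop_boxes(256, 256, 227)
```

In use, each box would be cropped from the image (and flipped horizontally when `mirrored` is true), passed through the network, and the ten score vectors combined with `average_scores`.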
diff --git a/caffe-crfrnn/models/bvlc_reference_caffenet/solver.prototxt b/caffe-crfrnn/models/bvlc_reference_caffenet/solver.prototxt new file mode 100644 index 00000000..af1315ba --- /dev/null +++ b/caffe-crfrnn/models/bvlc_reference_caffenet/solver.prototxt @@ -0,0 +1,14 @@ +net: "models/bvlc_reference_caffenet/train_val.prototxt" +test_iter: 1000 +test_interval: 1000 +base_lr: 0.01 +lr_policy: "step" +gamma: 0.1 +stepsize: 100000 +display: 20 +max_iter: 450000 +momentum: 0.9 +weight_decay: 0.0005 +snapshot: 10000 +snapshot_prefix: "models/bvlc_reference_caffenet/caffenet_train" +solver_mode: GPU diff --git a/caffe-crfrnn/models/bvlc_reference_caffenet/train_val.prototxt b/caffe-crfrnn/models/bvlc_reference_caffenet/train_val.prototxt new file mode 100644 index 00000000..00fcc080 --- /dev/null +++ b/caffe-crfrnn/models/bvlc_reference_caffenet/train_val.prototxt @@ -0,0 +1,362 @@ +name: "CaffeNet" +layers { + name: "data" + type: DATA + top: "data" + top: "label" + data_param { + source: "examples/imagenet/ilsvrc12_train_lmdb" + backend: LMDB + batch_size: 256 + } + transform_param { + crop_size: 227 + mean_file: "data/ilsvrc12/imagenet_mean.binaryproto" + mirror: true + } +# mean pixel / channel-wise mean instead of mean image +# transform_param { +# crop_size: 227 +# mean_value: 104 +# mean_value: 117 +# mean_value: 123 +# mirror: true +# } + include: { phase: TRAIN } +} +layers { + name: "data" + type: DATA + top: "data" + top: "label" + data_param { + source: "examples/imagenet/ilsvrc12_val_lmdb" + backend: LMDB + batch_size: 50 + } + transform_param { + crop_size: 227 + mean_file: "data/ilsvrc12/imagenet_mean.binaryproto" + mirror: false + } +# mean pixel / channel-wise mean instead of mean image +# transform_param { +# crop_size: 227 +# mean_value: 104 +# mean_value: 117 +# mean_value: 123 +# mirror: true +# } + include: { phase: TEST } +} +layers { + name: "conv1" + type: CONVOLUTION + bottom: "data" + top: "conv1" + blobs_lr: 1 + blobs_lr: 2 + 
weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 96 + kernel_size: 11 + stride: 4 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "relu1" + type: RELU + bottom: "conv1" + top: "conv1" +} +layers { + name: "pool1" + type: POOLING + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm1" + type: LRN + bottom: "pool1" + top: "norm1" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv2" + type: CONVOLUTION + bottom: "norm1" + top: "conv2" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 2 + kernel_size: 5 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu2" + type: RELU + bottom: "conv2" + top: "conv2" +} +layers { + name: "pool2" + type: POOLING + bottom: "conv2" + top: "pool2" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm2" + type: LRN + bottom: "pool2" + top: "norm2" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv3" + type: CONVOLUTION + bottom: "norm2" + top: "conv3" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "relu3" + type: RELU + bottom: "conv3" + top: "conv3" +} +layers { + name: "conv4" + type: CONVOLUTION + bottom: "conv3" + top: "conv4" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu4" 
+ type: RELU + bottom: "conv4" + top: "conv4" +} +layers { + name: "conv5" + type: CONVOLUTION + bottom: "conv4" + top: "conv5" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu5" + type: RELU + bottom: "conv5" + top: "conv5" +} +layers { + name: "pool5" + type: POOLING + bottom: "conv5" + top: "pool5" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "fc6" + type: INNER_PRODUCT + bottom: "pool5" + top: "fc6" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu6" + type: RELU + bottom: "fc6" + top: "fc6" +} +layers { + name: "drop6" + type: DROPOUT + bottom: "fc6" + top: "fc6" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc7" + type: INNER_PRODUCT + bottom: "fc6" + top: "fc7" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu7" + type: RELU + bottom: "fc7" + top: "fc7" +} +layers { + name: "drop7" + type: DROPOUT + bottom: "fc7" + top: "fc7" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc8" + type: INNER_PRODUCT + bottom: "fc7" + top: "fc8" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 1000 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "accuracy" + type: ACCURACY + bottom: "fc8" + bottom: "label" + top: "accuracy" + include: { phase: TEST } +} +layers { + name: 
"loss" + type: SOFTMAX_LOSS + bottom: "fc8" + bottom: "label" + top: "loss" +} diff --git a/caffe-crfrnn/models/bvlc_reference_rcnn_ilsvrc13/deploy.prototxt b/caffe-crfrnn/models/bvlc_reference_rcnn_ilsvrc13/deploy.prototxt new file mode 100644 index 00000000..ef75a0a5 --- /dev/null +++ b/caffe-crfrnn/models/bvlc_reference_rcnn_ilsvrc13/deploy.prototxt @@ -0,0 +1,207 @@ +name: "R-CNN-ilsvrc13" +input: "data" +input_dim: 10 +input_dim: 3 +input_dim: 227 +input_dim: 227 +layers { + name: "conv1" + type: CONVOLUTION + bottom: "data" + top: "conv1" + convolution_param { + num_output: 96 + kernel_size: 11 + stride: 4 + } +} +layers { + name: "relu1" + type: RELU + bottom: "conv1" + top: "conv1" +} +layers { + name: "pool1" + type: POOLING + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm1" + type: LRN + bottom: "pool1" + top: "norm1" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv2" + type: CONVOLUTION + bottom: "norm1" + top: "conv2" + convolution_param { + num_output: 256 + pad: 2 + kernel_size: 5 + group: 2 + } +} +layers { + name: "relu2" + type: RELU + bottom: "conv2" + top: "conv2" +} +layers { + name: "pool2" + type: POOLING + bottom: "conv2" + top: "pool2" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm2" + type: LRN + bottom: "pool2" + top: "norm2" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv3" + type: CONVOLUTION + bottom: "norm2" + top: "conv3" + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + } +} +layers { + name: "relu3" + type: RELU + bottom: "conv3" + top: "conv3" +} +layers { + name: "conv4" + type: CONVOLUTION + bottom: "conv3" + top: "conv4" + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + group: 2 + } +} +layers { + name: "relu4" + type: RELU + bottom: "conv4" + top: "conv4" +} +layers { + name: "conv5" + type: 
CONVOLUTION + bottom: "conv4" + top: "conv5" + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + group: 2 + } +} +layers { + name: "relu5" + type: RELU + bottom: "conv5" + top: "conv5" +} +layers { + name: "pool5" + type: POOLING + bottom: "conv5" + top: "pool5" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "fc6" + type: INNER_PRODUCT + bottom: "pool5" + top: "fc6" + inner_product_param { + num_output: 4096 + } +} +layers { + name: "relu6" + type: RELU + bottom: "fc6" + top: "fc6" +} +layers { + name: "drop6" + type: DROPOUT + bottom: "fc6" + top: "fc6" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc7" + type: INNER_PRODUCT + bottom: "fc6" + top: "fc7" + inner_product_param { + num_output: 4096 + } +} +layers { + name: "relu7" + type: RELU + bottom: "fc7" + top: "fc7" +} +layers { + name: "drop7" + type: DROPOUT + bottom: "fc7" + top: "fc7" + dropout_param { + dropout_ratio: 0.5 + } +} +# R-CNN classification layer made from R-CNN ILSVRC13 SVMs. +layers { + name: "fc-rcnn" + type: INNER_PRODUCT + bottom: "fc7" + top: "fc-rcnn" + inner_product_param { + num_output: 200 + } +} diff --git a/caffe-crfrnn/models/bvlc_reference_rcnn_ilsvrc13/readme.md b/caffe-crfrnn/models/bvlc_reference_rcnn_ilsvrc13/readme.md new file mode 100644 index 00000000..5d4bc5af --- /dev/null +++ b/caffe-crfrnn/models/bvlc_reference_rcnn_ilsvrc13/readme.md @@ -0,0 +1,22 @@ +--- +name: BVLC Reference RCNN ILSVRC13 Model +caffemodel: bvlc_reference_rcnn_ilsvrc13.caffemodel +caffemodel_url: http://dl.caffe.berkeleyvision.org/bvlc_reference_rcnn_ilsvrc13.caffemodel +license: non-commercial +sha1: bdd8abb885819cba5e2fe1eb36235f2319477e64 +caffe_commit: a7e397abbda52c0b90323c23ab95bdeabee90a98 +--- + +The pure Caffe instantiation of the [R-CNN](https://github.com/rbgirshick/rcnn) model for ILSVRC13 detection. 
+This model was made by transplanting the R-CNN SVM classifiers into a `fc-rcnn` classification layer, provided here as an off-the-shelf Caffe detector. +Try the [detection example](http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/detection.ipynb) to see it in action. + +*N.B. For research purposes, make use of the official R-CNN package and not this example.* + +This model was trained by Ross Girshick @rbgirshick + +## License + +The data used to train this model comes from the ImageNet project, which distributes its database to researchers who agree to the following term of access: +"Researcher shall use the Database only for non-commercial research and educational purposes." +Accordingly, this model is distributed under a non-commercial license. diff --git a/caffe-crfrnn/models/finetune_flickr_style/deploy.prototxt b/caffe-crfrnn/models/finetune_flickr_style/deploy.prototxt new file mode 100644 index 00000000..aa2ad961 --- /dev/null +++ b/caffe-crfrnn/models/finetune_flickr_style/deploy.prototxt @@ -0,0 +1,310 @@ +name: "FlickrStyleCaffeNet" +input: "data" +input_dim: 10 +input_dim: 3 +input_dim: 227 +input_dim: 227 +layers { + name: "conv1" + type: CONVOLUTION + bottom: "data" + top: "conv1" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 96 + kernel_size: 11 + stride: 4 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "relu1" + type: RELU + bottom: "conv1" + top: "conv1" +} +layers { + name: "pool1" + type: POOLING + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm1" + type: LRN + bottom: "pool1" + top: "norm1" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv2" + type: CONVOLUTION + bottom: "norm1" + top: "conv2" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + 
num_output: 256 + pad: 2 + kernel_size: 5 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu2" + type: RELU + bottom: "conv2" + top: "conv2" +} +layers { + name: "pool2" + type: POOLING + bottom: "conv2" + top: "pool2" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm2" + type: LRN + bottom: "pool2" + top: "norm2" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv3" + type: CONVOLUTION + bottom: "norm2" + top: "conv3" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "relu3" + type: RELU + bottom: "conv3" + top: "conv3" +} +layers { + name: "conv4" + type: CONVOLUTION + bottom: "conv3" + top: "conv4" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu4" + type: RELU + bottom: "conv4" + top: "conv4" +} +layers { + name: "conv5" + type: CONVOLUTION + bottom: "conv4" + top: "conv5" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu5" + type: RELU + bottom: "conv5" + top: "conv5" +} +layers { + name: "pool5" + type: POOLING + bottom: "conv5" + top: "pool5" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "fc6" + type: INNER_PRODUCT + bottom: "pool5" + top: "fc6" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 
0 + inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu6" + type: RELU + bottom: "fc6" + top: "fc6" +} +layers { + name: "drop6" + type: DROPOUT + bottom: "fc6" + top: "fc6" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc7" + type: INNER_PRODUCT + bottom: "fc6" + top: "fc7" + # Note that blobs_lr can be set to 0 to disable any fine-tuning of this, and any other, layer + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu7" + type: RELU + bottom: "fc7" + top: "fc7" +} +layers { + name: "drop7" + type: DROPOUT + bottom: "fc7" + top: "fc7" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc8_flickr" + type: INNER_PRODUCT + bottom: "fc7" + top: "fc8_flickr" + # blobs_lr is set to higher than for other layers, because this layer is starting from random while the others are already trained + blobs_lr: 10 + blobs_lr: 20 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 20 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "prob" + type: SOFTMAX + bottom: "fc8_flickr" + top: "prob" +} diff --git a/caffe-crfrnn/models/finetune_flickr_style/readme.md b/caffe-crfrnn/models/finetune_flickr_style/readme.md new file mode 100644 index 00000000..aac7f7c9 --- /dev/null +++ b/caffe-crfrnn/models/finetune_flickr_style/readme.md @@ -0,0 +1,24 @@ +--- +name: Finetuning CaffeNet on Flickr Style +caffemodel: finetune_flickr_style.caffemodel +caffemodel_url: http://dl.caffe.berkeleyvision.org/finetune_flickr_style.caffemodel +license: non-commercial +sha1: b61b5cef7d771b53b0c488e78d35ccadc073e9cf +caffe_commit: 
737ea5e936821b5c69f9c3952d72693ae5843370 +gist_id: 034c6ac3865563b69e60 +--- + +This model is trained exactly as described in `docs/finetune_flickr_style/readme.md`, using all 80000 images. +The final performance: + + I1017 07:36:17.370688 31333 solver.cpp:228] Iteration 100000, loss = 0.757952 + I1017 07:36:17.370730 31333 solver.cpp:247] Iteration 100000, Testing net (#0) + I1017 07:36:34.248730 31333 solver.cpp:298] Test net output #0: accuracy = 0.3916 + +This model was trained by Sergey Karayev @sergeyk + +## License + +The Flickr Style dataset contains only URLs to images. +Some of the images may have copyright. +Training a category-recognition model for research/non-commercial use may constitute fair use of this data, but the result should not be used for commercial purposes. diff --git a/caffe-crfrnn/models/finetune_flickr_style/solver.prototxt b/caffe-crfrnn/models/finetune_flickr_style/solver.prototxt new file mode 100644 index 00000000..5e189bc9 --- /dev/null +++ b/caffe-crfrnn/models/finetune_flickr_style/solver.prototxt @@ -0,0 +1,17 @@ +net: "models/finetune_flickr_style/train_val.prototxt" +test_iter: 100 +test_interval: 1000 +# lr for fine-tuning should be lower than when starting from scratch +base_lr: 0.001 +lr_policy: "step" +gamma: 0.1 +# stepsize should also be lower, as we're closer to being done +stepsize: 20000 +display: 20 +max_iter: 100000 +momentum: 0.9 +weight_decay: 0.0005 +snapshot: 10000 +snapshot_prefix: "models/finetune_flickr_style/finetune_flickr_style" +# uncomment the following to default to CPU mode solving +# solver_mode: CPU diff --git a/caffe-crfrnn/models/finetune_flickr_style/train_val.prototxt b/caffe-crfrnn/models/finetune_flickr_style/train_val.prototxt new file mode 100644 index 00000000..7155c492 --- /dev/null +++ b/caffe-crfrnn/models/finetune_flickr_style/train_val.prototxt @@ -0,0 +1,349 @@ +name: "FlickrStyleCaffeNet" +layers { + name: "data" + type: IMAGE_DATA + top: "data" + top: "label" + image_data_param { + 
source: "data/flickr_style/train.txt" + batch_size: 50 + new_height: 256 + new_width: 256 + } + transform_param { + crop_size: 227 + mean_file: "data/ilsvrc12/imagenet_mean.binaryproto" + mirror: true + } + include: { phase: TRAIN } +} +layers { + name: "data" + type: IMAGE_DATA + top: "data" + top: "label" + image_data_param { + source: "data/flickr_style/test.txt" + batch_size: 50 + new_height: 256 + new_width: 256 + } + transform_param { + crop_size: 227 + mean_file: "data/ilsvrc12/imagenet_mean.binaryproto" + mirror: false + } + include: { phase: TEST } +} +layers { + name: "conv1" + type: CONVOLUTION + bottom: "data" + top: "conv1" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 96 + kernel_size: 11 + stride: 4 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "relu1" + type: RELU + bottom: "conv1" + top: "conv1" +} +layers { + name: "pool1" + type: POOLING + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm1" + type: LRN + bottom: "pool1" + top: "norm1" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv2" + type: CONVOLUTION + bottom: "norm1" + top: "conv2" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 2 + kernel_size: 5 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu2" + type: RELU + bottom: "conv2" + top: "conv2" +} +layers { + name: "pool2" + type: POOLING + bottom: "conv2" + top: "pool2" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm2" + type: LRN + bottom: "pool2" + top: "norm2" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv3" + type: CONVOLUTION + bottom: "norm2" + top: 
"conv3" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "relu3" + type: RELU + bottom: "conv3" + top: "conv3" +} +layers { + name: "conv4" + type: CONVOLUTION + bottom: "conv3" + top: "conv4" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu4" + type: RELU + bottom: "conv4" + top: "conv4" +} +layers { + name: "conv5" + type: CONVOLUTION + bottom: "conv4" + top: "conv5" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu5" + type: RELU + bottom: "conv5" + top: "conv5" +} +layers { + name: "pool5" + type: POOLING + bottom: "conv5" + top: "pool5" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "fc6" + type: INNER_PRODUCT + bottom: "pool5" + top: "fc6" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu6" + type: RELU + bottom: "fc6" + top: "fc6" +} +layers { + name: "drop6" + type: DROPOUT + bottom: "fc6" + top: "fc6" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc7" + type: INNER_PRODUCT + bottom: "fc6" + top: "fc7" + # Note that blobs_lr can be set to 0 to disable any fine-tuning of this, and any other, layer + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + 
inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu7" + type: RELU + bottom: "fc7" + top: "fc7" +} +layers { + name: "drop7" + type: DROPOUT + bottom: "fc7" + top: "fc7" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc8_flickr" + type: INNER_PRODUCT + bottom: "fc7" + top: "fc8_flickr" + # blobs_lr is set to higher than for other layers, because this layer is starting from random while the others are already trained + blobs_lr: 10 + blobs_lr: 20 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 20 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "loss" + type: SOFTMAX_LOSS + bottom: "fc8_flickr" + bottom: "label" +} +layers { + name: "accuracy" + type: ACCURACY + bottom: "fc8_flickr" + bottom: "label" + top: "accuracy" + include: { phase: TEST } +} diff --git a/caffe-crfrnn/python/CMakeLists.txt b/caffe-crfrnn/python/CMakeLists.txt new file mode 100644 index 00000000..0e2bc7e6 --- /dev/null +++ b/caffe-crfrnn/python/CMakeLists.txt @@ -0,0 +1,34 @@ +if(NOT HAVE_PYTHON) + message(STATUS "Python interface is disabled or not all required dependencies found. 
Building without it...") + return() +endif() + +include_directories(${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${Boost_INCLUDE_DIRS}) +file(GLOB_RECURSE python_srcs ${PROJECT_SOURCE_DIR}/python/*.cpp) + +add_library(pycaffe SHARED ${python_srcs}) +target_link_libraries(pycaffe ${Caffe_LINK} ${PYTHON_LIBRARIES} ${Boost_LIBRARIES}) +set_target_properties(pycaffe PROPERTIES PREFIX "" OUTPUT_NAME "_caffe") +caffe_default_properties(pycaffe) + +if(UNIX OR APPLE) + set(__linkname "${PROJECT_SOURCE_DIR}/python/caffe/_caffe.so") + add_custom_command(TARGET pycaffe POST_BUILD + COMMAND ln -sf $<TARGET_FILE:pycaffe> "${__linkname}" + COMMAND ${CMAKE_COMMAND} -E make_directory ${PROJECT_SOURCE_DIR}/python/caffe/proto + COMMAND touch ${PROJECT_SOURCE_DIR}/python/caffe/proto/__init__.py + COMMAND cp ${proto_gen_folder}/*.py ${PROJECT_SOURCE_DIR}/python/caffe/proto/ + COMMENT "Creating symlink ${__linkname} -> ${PROJECT_BINARY_DIR}/lib/_caffe${Caffe_POSTFIX}.so") +endif() + +# ---[ Install +file(GLOB files1 *.py requirements.txt) +install(FILES ${files1} DESTINATION python) + +file(GLOB files2 caffe/*.py) +install(FILES ${files2} DESTINATION python/caffe) +install(TARGETS pycaffe DESTINATION python/caffe) +install(DIRECTORY caffe/imagenet caffe/proto caffe/test DESTINATION python/caffe) + + + diff --git a/caffe-crfrnn/python/caffe/__init__.py b/caffe-crfrnn/python/caffe/__init__.py new file mode 100755 index 00000000..59828d9a --- /dev/null +++ b/caffe-crfrnn/python/caffe/__init__.py @@ -0,0 +1,5 @@ +from .pycaffe import Net, SGDSolver +from .classifier import Classifier +from .detector import Detector +from .segmenter import Segmenter +import io diff --git a/caffe-crfrnn/python/caffe/_caffe.cpp b/caffe-crfrnn/python/caffe/_caffe.cpp new file mode 100644 index 00000000..156b5187 --- /dev/null +++ b/caffe-crfrnn/python/caffe/_caffe.cpp @@ -0,0 +1,218 @@ +// pycaffe provides a wrapper of the caffe::Net class as well as some +// caffe::Caffe functions so that one could easily call it from Python. 
+// Note that for Python, we will simply use float as the data type. + +#include <Python.h> // NOLINT(build/include_alpha) + +#include <boost/make_shared.hpp> +#include <boost/python/suite/indexing/vector_indexing_suite.hpp> + +// these need to be included after boost on OS X +#include <string> // NOLINT(build/include_order) +#include <vector> // NOLINT(build/include_order) +#include <fstream> // NOLINT + +#include "_caffe.hpp" +#include "caffe/caffe.hpp" + +// Temporary solution for numpy < 1.7 versions: old macro, no promises. +// You're strongly advised to upgrade to >= 1.7. +#ifndef NPY_ARRAY_C_CONTIGUOUS +#define NPY_ARRAY_C_CONTIGUOUS NPY_C_CONTIGUOUS +#define PyArray_SetBaseObject(arr, x) (PyArray_BASE(arr) = (x)) +#endif + +namespace caffe { + +// for convenience, check that input files can be opened, and raise an +// exception that boost will send to Python if not (caffe could still crash +// later if the input files are disturbed before they are actually used, but +// this saves frustration in most cases) +static void CheckFile(const string& filename) { + std::ifstream f(filename.c_str()); + if (!f.good()) { + f.close(); + throw std::runtime_error("Could not open file " + filename); + } + f.close(); +} + +bp::object PyBlobWrap::get_data() { + npy_intp dims[] = {num(), channels(), height(), width()}; + + PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32, + blob_->mutable_cpu_data()); + PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_); + Py_INCREF(self_); + bp::handle<> h(obj); + + return bp::object(h); +} + +bp::object PyBlobWrap::get_diff() { + npy_intp dims[] = {num(), channels(), height(), width()}; + + PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32, + blob_->mutable_cpu_diff()); + PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), self_); + Py_INCREF(self_); + bp::handle<> h(obj); + + return bp::object(h); +} + +PyNet::PyNet(string param_file, string pretrained_param_file) { + Init(param_file); + CheckFile(pretrained_param_file); + net_->CopyTrainedLayersFrom(pretrained_param_file); +} + +void PyNet::Init(string param_file) { + 
CheckFile(param_file); + net_.reset(new Net<float>(param_file)); +} + +void PyNet::check_contiguous_array(PyArrayObject* arr, string name, + int channels, int height, int width) { + if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) { + throw std::runtime_error(name + " must be C contiguous"); + } + if (PyArray_NDIM(arr) != 4) { + throw std::runtime_error(name + " must be 4-d"); + } + if (PyArray_TYPE(arr) != NPY_FLOAT32) { + throw std::runtime_error(name + " must be float32"); + } + if (PyArray_DIMS(arr)[1] != channels) { + throw std::runtime_error(name + " has wrong number of channels"); + } + if (PyArray_DIMS(arr)[2] != height) { + throw std::runtime_error(name + " has wrong height"); + } + if (PyArray_DIMS(arr)[3] != width) { + throw std::runtime_error(name + " has wrong width"); + } +} + +void PyNet::set_input_arrays(bp::object data_obj, bp::object labels_obj) { + // check that this network has an input MemoryDataLayer + shared_ptr<MemoryDataLayer<float> > md_layer = + boost::dynamic_pointer_cast<MemoryDataLayer<float> >(net_->layers()[0]); + if (!md_layer) { + throw std::runtime_error("set_input_arrays may only be called if the" + " first layer is a MemoryDataLayer"); + } + + // check that we were passed appropriately-sized contiguous memory + PyArrayObject* data_arr = + reinterpret_cast<PyArrayObject*>(data_obj.ptr()); + PyArrayObject* labels_arr = + reinterpret_cast<PyArrayObject*>(labels_obj.ptr()); + check_contiguous_array(data_arr, "data array", md_layer->channels(), + md_layer->height(), md_layer->width()); + check_contiguous_array(labels_arr, "labels array", 1, 1, 1); + if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) { + throw std::runtime_error("data and labels must have the same first" + " dimension"); + } + if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) { + throw std::runtime_error("first dimensions of input arrays must be a" + " multiple of batch size"); + } + + // hold references + input_data_ = data_obj; + input_labels_ = labels_obj; + + md_layer->Reset(static_cast<float*>(PyArray_DATA(data_arr)), + 
static_cast<float*>(PyArray_DATA(labels_arr)), + PyArray_DIMS(data_arr)[0]); +} + +PySGDSolver::PySGDSolver(const string& param_file) { + // as in PyNet, (as a convenience, not a guarantee), create a Python + // exception if param_file can't be opened + CheckFile(param_file); + solver_.reset(new SGDSolver<float>(param_file)); + // we need to explicitly store the net wrapper, rather than constructing + // it on the fly, so that it can hold references to Python objects + net_.reset(new PyNet(solver_->net())); + for (int i = 0; i < solver_->test_nets().size(); ++i) { + test_nets_.push_back(boost::make_shared<PyNet>(solver_->test_nets()[i])); + } +} + +void PySGDSolver::SolveResume(const string& resume_file) { + CheckFile(resume_file); + return solver_->Solve(resume_file); +} + +BOOST_PYTHON_MODULE(_caffe) { + // below, we prepend an underscore to methods that will be replaced + // in Python + bp::class_<PyNet, shared_ptr<PyNet> >( + "Net", bp::init<string, string>()) + .def(bp::init<string>()) + .def("copy_from", &PyNet::CopyTrainedLayersFrom) + .def("share_with", &PyNet::ShareTrainedLayersWith) + .def("_forward", &PyNet::Forward) + .def("_backward", &PyNet::Backward) + .def("reshape", &PyNet::Reshape) + .def("set_mode_cpu", &PyNet::set_mode_cpu) + .def("set_mode_gpu", &PyNet::set_mode_gpu) + .def("set_phase_train", &PyNet::set_phase_train) + .def("set_phase_test", &PyNet::set_phase_test) + .def("set_device", &PyNet::set_device) + .add_property("_blobs", &PyNet::blobs) + .add_property("layers", &PyNet::layers) + .add_property("_blob_names", &PyNet::blob_names) + .add_property("_layer_names", &PyNet::layer_names) + .add_property("inputs", &PyNet::inputs) + .add_property("outputs", &PyNet::outputs) + .add_property("mean", &PyNet::mean_) + .add_property("input_scale", &PyNet::input_scale_) + .add_property("raw_scale", &PyNet::raw_scale_) + .add_property("channel_swap", &PyNet::channel_swap_) + .def("_set_input_arrays", &PyNet::set_input_arrays) + .def("save", &PyNet::save); + + bp::class_<PyBlob<float>, PyBlobWrap>( + "Blob", bp::no_init) + 
.add_property("num", &PyBlob<float>::num) + .add_property("channels", &PyBlob<float>::channels) + .add_property("height", &PyBlob<float>::height) + .add_property("width", &PyBlob<float>::width) + .add_property("count", &PyBlob<float>::count) + .def("reshape", &PyBlob<float>::Reshape) + .add_property("data", &PyBlobWrap::get_data) + .add_property("diff", &PyBlobWrap::get_diff); + + bp::class_<PyLayer>( + "Layer", bp::no_init) + .add_property("blobs", &PyLayer::blobs); + + bp::class_<PySGDSolver, boost::noncopyable>( + "SGDSolver", bp::init<string>()) + .add_property("net", &PySGDSolver::net) + .add_property("test_nets", &PySGDSolver::test_nets) + .add_property("iter", &PySGDSolver::iter) + .def("solve", &PySGDSolver::Solve) + .def("solve", &PySGDSolver::SolveResume) + .def("step", &PySGDSolver::Step); + + bp::class_<vector<shared_ptr<PyNet> > >("NetVec") + .def(bp::vector_indexing_suite<vector<shared_ptr<PyNet> >, true>()); + + bp::class_<vector<PyBlob<float> > >("BlobVec") + .def(bp::vector_indexing_suite<vector<PyBlob<float> >, true>()); + + bp::class_<vector<PyLayer> >("LayerVec") + .def(bp::vector_indexing_suite<vector<PyLayer>, true>()); + + bp::class_<vector<string> >("StringVec") + .def(bp::vector_indexing_suite<vector<string> >()); + + import_array(); +} + +} // namespace caffe diff --git a/caffe-crfrnn/python/caffe/_caffe.hpp b/caffe-crfrnn/python/caffe/_caffe.hpp new file mode 100644 index 00000000..a5cef74a --- /dev/null +++ b/caffe-crfrnn/python/caffe/_caffe.hpp @@ -0,0 +1,199 @@ +#ifndef PYTHON_CAFFE__CAFFE_HPP_ +#define PYTHON_CAFFE__CAFFE_HPP_ + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include <Python.h> // NOLINT(build/include_alpha) + +#include <boost/python.hpp> +#include <boost/shared_ptr.hpp> +#include <numpy/arrayobject.h> + +// these need to be included after boost on OS X +#include <string> // NOLINT(build/include_order) +#include <vector> // NOLINT(build/include_order) + +#include "caffe/caffe.hpp" + +namespace bp = boost::python; +using boost::shared_ptr; + +namespace caffe { + +// wrap shared_ptr<Blob> in a class that we construct in C++ and pass +// to Python +template <typename Dtype> +class PyBlob { + public: + explicit PyBlob(const shared_ptr<Blob<Dtype> > &blob) + : blob_(blob) {} + + int num() const { return blob_->num(); } + int channels() const { return blob_->channels(); } + int 
height() const { return blob_->height(); } + int width() const { return blob_->width(); } + int count() const { return blob_->count(); } + void Reshape(const int n, const int c, const int h, const int w) { + return blob_->Reshape(n, c, h, w); + } + + // this is here only to satisfy boost's vector_indexing_suite + bool operator == (const PyBlob &other) { + return this->blob_ == other.blob_; + } + + protected: + shared_ptr<Blob<Dtype> > blob_; +}; + +// We need another wrapper (used as boost::python's HeldType) that receives a +// self PyObject * which we can use as ndarray.base, so that data/diff memory +// is not freed while still being used in Python. +class PyBlobWrap : public PyBlob<float> { + public: + PyBlobWrap(PyObject *p, const PyBlob<float> &blob) + : PyBlob<float>(blob), self_(p) {} + + bp::object get_data(); + bp::object get_diff(); + + private: + PyObject *self_; +}; + +class PyLayer { + public: + explicit PyLayer(const shared_ptr<Layer<float> > &layer) + : layer_(layer) {} + + vector<PyBlob<float> > blobs() { + return vector<PyBlob<float> >(layer_->blobs().begin(), + layer_->blobs().end()); + } + + // this is here only to satisfy boost's vector_indexing_suite + bool operator == (const PyLayer &other) { + return this->layer_ == other.layer_; + } + + protected: + shared_ptr<Layer<float> > layer_; +}; + +class PyNet { + public: + // For cases where parameters will be determined later by the Python user, + // create a Net with unallocated parameters (which will not be zero-filled + // when accessed). + explicit PyNet(string param_file) { Init(param_file); } + PyNet(string param_file, string pretrained_param_file); + explicit PyNet(shared_ptr<Net<float> > net) + : net_(net) {} + virtual ~PyNet() {} + + void Init(string param_file); + + + // Generate Python exceptions for badly shaped or discontiguous arrays. 
+ inline void check_contiguous_array(PyArrayObject* arr, string name, + int channels, int height, int width); + + void CopyTrainedLayersFrom(const string filename) { + net_->CopyTrainedLayersFrom(filename); + } + void ShareTrainedLayersWith(PyNet* other) { + net_->ShareTrainedLayersWith(other->net_.get()); + } + void Forward(int start, int end) { net_->ForwardFromTo(start, end); } + void Backward(int start, int end) { net_->BackwardFromTo(start, end); } + void Reshape() { net_->Reshape(); } + + void set_input_arrays(bp::object data_obj, bp::object labels_obj); + + // Save the network weights to binary proto for net surgeries. + void save(string filename) { + NetParameter net_param; + net_->ToProto(&net_param, false); + WriteProtoToBinaryFile(net_param, filename.c_str()); + } + + // The caffe::Caffe utility functions. + void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); } + void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); } + void set_phase_train() { Caffe::set_phase(Caffe::TRAIN); } + void set_phase_test() { Caffe::set_phase(Caffe::TEST); } + void set_device(int device_id) { Caffe::SetDevice(device_id); } + + vector<PyBlob<float> > blobs() { + return vector<PyBlob<float> >(net_->blobs().begin(), net_->blobs().end()); + } + + vector<PyLayer> layers() { + return vector<PyLayer>(net_->layers().begin(), net_->layers().end()); + } + + vector<string> blob_names() { return net_->blob_names(); } + vector<string> layer_names() { return net_->layer_names(); } + + bp::list inputs() { + bp::list input_blob_names; + for (int i = 0; i < net_->input_blob_indices().size(); ++i) { + input_blob_names.append( + net_->blob_names()[net_->input_blob_indices()[i]]); + } + return input_blob_names; + } + + bp::list outputs() { + bp::list output_blob_names; + for (int i = 0; i < net_->output_blob_indices().size(); ++i) { + output_blob_names.append( + net_->blob_names()[net_->output_blob_indices()[i]]); + } + return output_blob_names; + } + + // Input preprocessing configuration attributes. These are public for + // direct access from Python. 
+ bp::dict mean_; + bp::dict input_scale_; + bp::dict raw_scale_; + bp::dict channel_swap_; + + // this is here only to satisfy boost's vector_indexing_suite + bool operator == (const PyNet &other) { + return this->net_ == other.net_; + } + + protected: + // The pointer to the internal caffe::Net instance. + shared_ptr<Net<float> > net_; + // if taking input from an ndarray, we need to hold references + bp::object input_data_; + bp::object input_labels_; +}; + +class PySGDSolver { + public: + explicit PySGDSolver(const string& param_file); + + shared_ptr<PyNet> net() { return net_; } + vector<shared_ptr<PyNet> > test_nets() { return test_nets_; } + int iter() { return solver_->iter(); } + void Solve() { return solver_->Solve(); } + void Step(int iters) { solver_->Step(iters); } + void SolveResume(const string& resume_file); + + protected: + shared_ptr<PyNet> net_; + vector<shared_ptr<PyNet> > test_nets_; + shared_ptr<SGDSolver<float> > solver_; +}; + +// Declare the module init function created by boost::python, so that we can +// use this module from C++ when embedding Python. +PyMODINIT_FUNC init_caffe(void); + +} // namespace caffe + +#endif diff --git a/caffe-crfrnn/python/caffe/classifier.py b/caffe-crfrnn/python/caffe/classifier.py new file mode 100644 index 00000000..fe471ca1 --- /dev/null +++ b/caffe-crfrnn/python/caffe/classifier.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +""" +Classifier is an image classifier specialization of Net. +""" + +import numpy as np + +import caffe + + +class Classifier(caffe.Net): + """ + Classifier extends Net for image class prediction + by scaling, center cropping, or oversampling. + """ + def __init__(self, model_file, pretrained_file, image_dims=None, + gpu=False, mean=None, input_scale=None, raw_scale=None, + channel_swap=None): + """ + Take + image_dims: dimensions to scale input for cropping/sampling. + Default is to scale to net input size for whole-image crop. + gpu, mean, input_scale, raw_scale, channel_swap: params for + preprocessing options. 
+ """ + caffe.Net.__init__(self, model_file, pretrained_file) + self.set_phase_test() + + if gpu: + self.set_mode_gpu() + else: + self.set_mode_cpu() + + if mean is not None: + self.set_mean(self.inputs[0], mean) + if input_scale is not None: + self.set_input_scale(self.inputs[0], input_scale) + if raw_scale is not None: + self.set_raw_scale(self.inputs[0], raw_scale) + if channel_swap is not None: + self.set_channel_swap(self.inputs[0], channel_swap) + + self.crop_dims = np.array(self.blobs[self.inputs[0]].data.shape[2:]) + if not image_dims: + image_dims = self.crop_dims + self.image_dims = image_dims + + + def predict(self, inputs, oversample=True): + """ + Predict classification probabilities of inputs. + + Take + inputs: iterable of (H x W x K) input ndarrays. + oversample: average predictions across center, corners, and mirrors + when True (default). Center-only prediction when False. + + Give + predictions: (N x C) ndarray of class probabilities + for N images and C classes. + """ + # Scale to standardize input dimensions. + input_ = np.zeros((len(inputs), + self.image_dims[0], self.image_dims[1], inputs[0].shape[2]), + dtype=np.float32) + for ix, in_ in enumerate(inputs): + input_[ix] = caffe.io.resize_image(in_, self.image_dims) + + if oversample: + # Generate center, corner, and mirrored crops. + input_ = caffe.io.oversample(input_, self.crop_dims) + else: + # Take center crop. + center = np.array(self.image_dims) / 2.0 + crop = np.tile(center, (1, 2))[0] + np.concatenate([ + -self.crop_dims / 2.0, + self.crop_dims / 2.0 + ]) + input_ = input_[:, crop[0]:crop[2], crop[1]:crop[3], :] + + # Classify + caffe_in = np.zeros(np.array(input_.shape)[[0,3,1,2]], + dtype=np.float32) + for ix, in_ in enumerate(input_): + caffe_in[ix] = self.preprocess(self.inputs[0], in_) + out = self.forward_all(**{self.inputs[0]: caffe_in}) + predictions = out[self.outputs[0]].squeeze(axis=(2,3)) + + # For oversampling, average predictions across crops. 
+ if oversample: + predictions = predictions.reshape((len(predictions) / 10, 10, -1)) + predictions = predictions.mean(1) + + return predictions diff --git a/caffe-crfrnn/python/caffe/detector.py b/caffe-crfrnn/python/caffe/detector.py new file mode 100644 index 00000000..f219b610 --- /dev/null +++ b/caffe-crfrnn/python/caffe/detector.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python +""" +Do windowed detection by classifying a number of images/crops at once, +optionally using the selective search window proposal method. + +This implementation follows ideas in + Ross Girshick, Jeff Donahue, Trevor Darrell, Jitendra Malik. + Rich feature hierarchies for accurate object detection and semantic + segmentation. + http://arxiv.org/abs/1311.2524 + +The selective_search_ijcv_with_python code required for the selective search +proposal mode is available at + https://github.com/sergeyk/selective_search_ijcv_with_python +""" +import numpy as np +import os + +import caffe + + +class Detector(caffe.Net): + """ + Detector extends Net for windowed detection by a list of crops or + selective search proposals. + """ + def __init__(self, model_file, pretrained_file, gpu=False, mean=None, + input_scale=None, raw_scale=None, channel_swap=None, + context_pad=None): + """ + Take + gpu, mean, input_scale, raw_scale, channel_swap: params for + preprocessing options. + context_pad: amount of surrounding context to take s.t. a `context_pad` + sized border of pixels in the network input image is context, as in + R-CNN feature extraction. 
+ """ + caffe.Net.__init__(self, model_file, pretrained_file) + self.set_phase_test() + + if gpu: + self.set_mode_gpu() + else: + self.set_mode_cpu() + + if mean is not None: + self.set_mean(self.inputs[0], mean) + if input_scale is not None: + self.set_input_scale(self.inputs[0], input_scale) + if raw_scale is not None: + self.set_raw_scale(self.inputs[0], raw_scale) + if channel_swap is not None: + self.set_channel_swap(self.inputs[0], channel_swap) + + self.configure_crop(context_pad) + + + def detect_windows(self, images_windows): + """ + Do windowed detection over given images and windows. Windows are + extracted then warped to the input dimensions of the net. + + Take + images_windows: (image filename, window list) iterable. + context_crop: size of context border to crop in pixels. + + Give + detections: list of {filename: image filename, window: crop coordinates, + prediction: prediction vector} dicts. + """ + # Extract windows. + window_inputs = [] + for image_fname, windows in images_windows: + image = caffe.io.load_image(image_fname).astype(np.float32) + for window in windows: + window_inputs.append(self.crop(image, window)) + + # Run through the net (warping windows to input dimensions). + caffe_in = np.zeros((len(window_inputs), window_inputs[0].shape[2]) + + self.blobs[self.inputs[0]].data.shape[2:], + dtype=np.float32) + for ix, window_in in enumerate(window_inputs): + caffe_in[ix] = self.preprocess(self.inputs[0], window_in) + out = self.forward_all(**{self.inputs[0]: caffe_in}) + predictions = out[self.outputs[0]].squeeze(axis=(2,3)) + + # Package predictions with images and windows. 
+ detections = [] + ix = 0 + for image_fname, windows in images_windows: + for window in windows: + detections.append({ + 'window': window, + 'prediction': predictions[ix], + 'filename': image_fname + }) + ix += 1 + return detections + + + def detect_selective_search(self, image_fnames): + """ + Do windowed detection over Selective Search proposals by extracting + the crop and warping to the input dimensions of the net. + + Take + image_fnames: list of image filenames. + + Give + detections: list of {filename: image filename, window: crop coordinates, + prediction: prediction vector} dicts. + """ + import selective_search_ijcv_with_python as selective_search + # Make absolute paths so MATLAB can find the files. + image_fnames = [os.path.abspath(f) for f in image_fnames] + windows_list = selective_search.get_windows( + image_fnames, + cmd='selective_search_rcnn' + ) + # Run windowed detection on the selective search list. + return self.detect_windows(zip(image_fnames, windows_list)) + + + def crop(self, im, window): + """ + Crop a window from the image for detection. Include surrounding context + according to the `context_pad` configuration. + + Take + im: H x W x K image ndarray to crop. + window: bounding box coordinates as ymin, xmin, ymax, xmax. + + Give + crop: cropped window. + """ + # Crop window from the image. + crop = im[window[0]:window[2], window[1]:window[3]] + + if self.context_pad: + box = window.copy() + crop_size = self.blobs[self.inputs[0]].width # assumes square + scale = crop_size / (1. * crop_size - self.context_pad * 2) + # Crop a box + surrounding context. + half_h = (box[2] - box[0] + 1) / 2. + half_w = (box[3] - box[1] + 1) / 2.
+ center = (box[0] + half_h, box[1] + half_w) + scaled_dims = scale * np.array((-half_h, -half_w, half_h, half_w)) + box = np.round(np.tile(center, 2) + scaled_dims) + full_h = box[2] - box[0] + 1 + full_w = box[3] - box[1] + 1 + scale_h = crop_size / full_h + scale_w = crop_size / full_w + pad_y = round(max(0, -box[0]) * scale_h) # amount out-of-bounds + pad_x = round(max(0, -box[1]) * scale_w) + + # Clip box to image dimensions. + im_h, im_w = im.shape[:2] + box = np.clip(box, 0., [im_h, im_w, im_h, im_w]) + clip_h = box[2] - box[0] + 1 + clip_w = box[3] - box[1] + 1 + assert(clip_h > 0 and clip_w > 0) + crop_h = round(clip_h * scale_h) + crop_w = round(clip_w * scale_w) + if pad_y + crop_h > crop_size: + crop_h = crop_size - pad_y + if pad_x + crop_w > crop_size: + crop_w = crop_size - pad_x + + # collect with context padding and place in input + # with mean padding + context_crop = im[box[0]:box[2], box[1]:box[3]] + context_crop = caffe.io.resize_image(context_crop, (crop_h, crop_w)) + crop = self.crop_mean.copy() + crop[pad_y:(pad_y + crop_h), pad_x:(pad_x + crop_w)] = context_crop + + return crop + + + def configure_crop(self, context_pad): + """ + Configure amount of context for cropping. + If context is included, make the special input mean for context padding. + + Take + context_pad: amount of context for cropping. + """ + self.context_pad = context_pad + if self.context_pad: + raw_scale = self.raw_scale.get(self.inputs[0]) + channel_order = self.channel_swap.get(self.inputs[0]) + # Padding context crops needs the mean in unprocessed input space. 
+ mean = self.mean.get(self.inputs[0]) + if mean is not None: + crop_mean = mean.copy().transpose((1,2,0)) + if channel_order is not None: + channel_order_inverse = [channel_order.index(i) + for i in range(crop_mean.shape[2])] + crop_mean = crop_mean[:,:, channel_order_inverse] + if raw_scale is not None: + crop_mean /= raw_scale + self.crop_mean = crop_mean + else: + in_shape = self.blobs[self.inputs[0]].data.shape # (N, K, H, W) + self.crop_mean = np.zeros(in_shape[2:] + (in_shape[1],), dtype=np.float32) # H x W x K, as crop() expects diff --git a/caffe-crfrnn/python/caffe/draw.py b/caffe-crfrnn/python/caffe/draw.py new file mode 100644 index 00000000..f8631cfa --- /dev/null +++ b/caffe-crfrnn/python/caffe/draw.py @@ -0,0 +1,76 @@ +""" +Caffe network visualization: draw the NetParameter protobuffer. + +NOTE: this requires pydot>=1.0.2, which is not included in requirements.txt +since it requires graphviz and other prerequisites outside the scope of +Caffe. +""" + +from caffe.proto import caffe_pb2 +from google.protobuf import text_format +import pydot + +# Internal layer and blob styles. +LAYER_STYLE = {'shape': 'record', 'fillcolor': '#6495ED', + 'style': 'filled'} +NEURON_LAYER_STYLE = {'shape': 'record', 'fillcolor': '#90EE90', + 'style': 'filled'} +BLOB_STYLE = {'shape': 'octagon', 'fillcolor': '#F0E68C', + 'style': 'filled'} +def get_enum_name_by_value(): + desc = caffe_pb2.LayerParameter.LayerType.DESCRIPTOR + d = {} + for k,v in desc.values_by_name.items(): + d[v.number] = k + return d + +def get_pydot_graph(caffe_net): + pydot_graph = pydot.Dot(caffe_net.name, graph_type='digraph', rankdir="BT") + pydot_nodes = {} + pydot_edges = [] + d = get_enum_name_by_value() + for layer in caffe_net.layers: + name = layer.name + layertype = d[layer.type] + if (len(layer.bottom) == 1 and len(layer.top) == 1 and + layer.bottom[0] == layer.top[0]): + # We have an in-place neuron layer.
+ pydot_nodes[name + '_' + layertype] = pydot.Node( + '%s (%s)' % (name, layertype), **NEURON_LAYER_STYLE) + else: + pydot_nodes[name + '_' + layertype] = pydot.Node( + '%s (%s)' % (name, layertype), **LAYER_STYLE) + for bottom_blob in layer.bottom: + pydot_nodes[bottom_blob + '_blob'] = pydot.Node( + '%s' % (bottom_blob), **BLOB_STYLE) + pydot_edges.append((bottom_blob + '_blob', name + '_' + layertype)) + for top_blob in layer.top: + pydot_nodes[top_blob + '_blob'] = pydot.Node( + '%s' % (top_blob)) + pydot_edges.append((name + '_' + layertype, top_blob + '_blob')) + # Now, add the nodes and edges to the graph. + for node in pydot_nodes.values(): + pydot_graph.add_node(node) + for edge in pydot_edges: + pydot_graph.add_edge( + pydot.Edge(pydot_nodes[edge[0]], pydot_nodes[edge[1]])) + return pydot_graph + +def draw_net(caffe_net, ext='png'): + """Draws a caffe net and returns the image string encoded using the given + extension. + + Input: + caffe_net: a caffe.proto.caffe_pb2.NetParameter protocol buffer. + ext: the image extension. Default 'png'. + """ + return get_pydot_graph(caffe_net).create(format=ext) + +def draw_net_to_file(caffe_net, filename): + """Draws a caffe net, and saves it to file using the format given as the + file extension. Use '.raw' to output raw text that you can manually feed + to graphviz to draw graphs. 
+ """ + ext = filename[filename.rfind('.')+1:] + with open(filename, 'wb') as fid: + fid.write(draw_net(caffe_net, ext)) diff --git a/caffe-crfrnn/python/caffe/imagenet/ilsvrc_2012_mean.npy b/caffe-crfrnn/python/caffe/imagenet/ilsvrc_2012_mean.npy new file mode 100644 index 00000000..666082c6 Binary files /dev/null and b/caffe-crfrnn/python/caffe/imagenet/ilsvrc_2012_mean.npy differ diff --git a/caffe-crfrnn/python/caffe/io.py b/caffe-crfrnn/python/caffe/io.py new file mode 100644 index 00000000..aabcfddb --- /dev/null +++ b/caffe-crfrnn/python/caffe/io.py @@ -0,0 +1,171 @@ +import numpy as np +import skimage.io +from scipy.ndimage import zoom +from skimage.transform import resize + +from caffe.proto import caffe_pb2 + + +def load_image(filename, color=True): + """ + Load an image converting from grayscale or alpha as needed. + + Take + filename: string + color: flag for color format. True (default) loads as RGB while False + loads as intensity (if image is already grayscale). + + Give + image: an image with type np.float32 in range [0, 1] + of size (H x W x 3) in RGB or + of size (H x W x 1) in grayscale. + """ + img = skimage.img_as_float(skimage.io.imread(filename)).astype(np.float32) + if img.ndim == 2: + img = img[:, :, np.newaxis] + if color: + img = np.tile(img, (1, 1, 3)) + elif img.shape[2] == 4: + img = img[:, :, :3] + return img + + +def resize_image(im, new_dims, interp_order=1): + """ + Resize an image array with interpolation. + + Take + im: (H x W x K) ndarray + new_dims: (height, width) tuple of new dimensions. + interp_order: interpolation order, default is linear. + + Give + im: resized ndarray with shape (new_dims[0], new_dims[1], K) + """ + if im.shape[-1] == 1 or im.shape[-1] == 3: + # skimage is fast but only understands {1,3} channel images in [0, 1]. 
+ im_min, im_max = im.min(), im.max() + im_std = (im - im_min) / (im_max - im_min) + resized_std = resize(im_std, new_dims, order=interp_order) + resized_im = resized_std * (im_max - im_min) + im_min + else: + # ndimage interpolates anything but more slowly. + scale = tuple(np.array(new_dims) / np.array(im.shape[:2])) + resized_im = zoom(im, scale + (1,), order=interp_order) + return resized_im.astype(np.float32) + + +def oversample(images, crop_dims): + """ + Crop images into the four corners, center, and their mirrored versions. + + Take + image: iterable of (H x W x K) ndarrays + crop_dims: (height, width) tuple for the crops. + + Give + crops: (10*N x H x W x K) ndarray of crops for number of inputs N. + """ + # Dimensions and center. + im_shape = np.array(images[0].shape) + crop_dims = np.array(crop_dims) + im_center = im_shape[:2] / 2.0 + + # Make crop coordinates + h_indices = (0, im_shape[0] - crop_dims[0]) + w_indices = (0, im_shape[1] - crop_dims[1]) + crops_ix = np.empty((5, 4), dtype=int) + curr = 0 + for i in h_indices: + for j in w_indices: + crops_ix[curr] = (i, j, i + crop_dims[0], j + crop_dims[1]) + curr += 1 + crops_ix[4] = np.tile(im_center, (1, 2)) + np.concatenate([ + -crop_dims / 2.0, + crop_dims / 2.0 + ]) + crops_ix = np.tile(crops_ix, (2, 1)) + + # Extract crops + crops = np.empty((10 * len(images), crop_dims[0], crop_dims[1], + im_shape[-1]), dtype=np.float32) + ix = 0 + for im in images: + for crop in crops_ix: + crops[ix] = im[crop[0]:crop[2], crop[1]:crop[3], :] + ix += 1 + crops[ix-5:ix] = crops[ix-5:ix, :, ::-1, :] # flip for mirrors + return crops + + +def blobproto_to_array(blob, return_diff=False): + """Convert a blob proto to an array. In default, we will just return the data, + unless return_diff is True, in which case we will return the diff. 
+ """ + if return_diff: + return np.array(blob.diff).reshape( + blob.num, blob.channels, blob.height, blob.width) + else: + return np.array(blob.data).reshape( + blob.num, blob.channels, blob.height, blob.width) + + +def array_to_blobproto(arr, diff=None): + """Converts a 4-dimensional array to blob proto. If diff is given, also + convert the diff. You need to make sure that arr and diff have the same + shape; this function does not check it. + """ + if arr.ndim != 4: + raise ValueError('Incorrect array shape.') + blob = caffe_pb2.BlobProto() + blob.num, blob.channels, blob.height, blob.width = arr.shape + blob.data.extend(arr.astype(float).flat) + if diff is not None: + blob.diff.extend(diff.astype(float).flat) + return blob + + +def arraylist_to_blobprotovecor_str(arraylist): + """Converts a list of arrays to a serialized blobprotovec, which can + then be passed to a network for processing. + """ + vec = caffe_pb2.BlobProtoVector() + vec.blobs.extend([array_to_blobproto(arr) for arr in arraylist]) + return vec.SerializeToString() + + +def blobprotovector_str_to_arraylist(str): + """Converts a serialized blobprotovec to a list of arrays. + """ + vec = caffe_pb2.BlobProtoVector() + vec.ParseFromString(str) + return [blobproto_to_array(blob) for blob in vec.blobs] + + +def array_to_datum(arr, label=0): + """Converts a 3-dimensional array to datum. If the array has dtype uint8, + the output data will be encoded as a string. Otherwise, the output data + will be stored in float format. + """ + if arr.ndim != 3: + raise ValueError('Incorrect array shape.') + datum = caffe_pb2.Datum() + datum.channels, datum.height, datum.width = arr.shape + if arr.dtype == np.uint8: + datum.data = arr.tostring() + else: + datum.float_data.extend(arr.flat) + datum.label = label + return datum + + +def datum_to_array(datum): + """Converts a datum to an array. Note that the label is not returned, + as one can easily get it by calling datum.label.
+ """ + if len(datum.data): + return np.fromstring(datum.data, dtype = np.uint8).reshape( + datum.channels, datum.height, datum.width) + else: + return np.array(datum.float_data).astype(float).reshape( + datum.channels, datum.height, datum.width) diff --git a/caffe-crfrnn/python/caffe/pycaffe.py b/caffe-crfrnn/python/caffe/pycaffe.py new file mode 100644 index 00000000..31dc1f9b --- /dev/null +++ b/caffe-crfrnn/python/caffe/pycaffe.py @@ -0,0 +1,394 @@ +""" +Wrap the internal caffe C++ module (_caffe.so) with a clean, Pythonic +interface. +""" + +from collections import OrderedDict +from itertools import izip_longest +import numpy as np + +from ._caffe import Net, SGDSolver +import caffe.io + +# We directly update methods from Net here (rather than using composition or +# inheritance) so that nets created by caffe (e.g., by SGDSolver) will +# automatically have the improved interface. + + +@property +def _Net_blobs(self): + """ + An OrderedDict (bottom to top, i.e., input to output) of network + blobs indexed by name + """ + return OrderedDict(zip(self._blob_names, self._blobs)) + + +@property +def _Net_params(self): + """ + An OrderedDict (bottom to top, i.e., input to output) of network + parameters indexed by name; each is a list of multiple blobs (e.g., + weights and biases) + """ + return OrderedDict([(name, lr.blobs) + for name, lr in zip(self._layer_names, self.layers) + if len(lr.blobs) > 0]) + +def _Net_forward(self, blobs=None, start=None, end=None, **kwargs): + """ + Forward pass: prepare inputs and run the net forward. + + Take + blobs: list of blobs to return in addition to output blobs. + kwargs: Keys are input blob names and values are blob ndarrays. + For formatting inputs for Caffe, see Net.preprocess(). + If None, input is taken from data layers. + start: optional name of layer at which to begin the forward pass + end: optional name of layer at which to finish the forward pass (inclusive) + + Give + outs: {blob name: blob ndarray} dict. 
+ """ + if blobs is None: + blobs = [] + + if start is not None: + start_ind = list(self._layer_names).index(start) + else: + start_ind = 0 + + if end is not None: + end_ind = list(self._layer_names).index(end) + outputs = set([end] + blobs) + else: + end_ind = len(self.layers) - 1 + outputs = set(self.outputs + blobs) + + if kwargs: + if set(kwargs.keys()) != set(self.inputs): + raise Exception('Input blob arguments do not match net inputs.') + # Set input according to defined shapes and make arrays single and + # C-contiguous as Caffe expects. + for in_, blob in kwargs.iteritems(): + if blob.ndim != 4: + raise Exception('{} blob is not 4-d'.format(in_)) + if blob.shape[0] != self.blobs[in_].num: + raise Exception('Input is not batch sized') + self.blobs[in_].data[...] = blob + + self._forward(start_ind, end_ind) + + # Unpack blobs to extract + return {out: self.blobs[out].data for out in outputs} + + +def _Net_backward(self, diffs=None, start=None, end=None, **kwargs): + """ + Backward pass: prepare diffs and run the net backward. + + Take + diffs: list of diffs to return in addition to bottom diffs. + kwargs: Keys are output blob names and values are diff ndarrays. + If None, top diffs are taken from forward loss. + start: optional name of layer at which to begin the backward pass + end: optional name of layer at which to finish the backward pass (inclusive) + + Give + outs: {blob name: diff ndarray} dict. + """ + if diffs is None: + diffs = [] + + if start is not None: + start_ind = list(self._layer_names).index(start) + else: + start_ind = len(self.layers) - 1 + + if end is not None: + end_ind = list(self._layer_names).index(end) + outputs = set([end] + diffs) + else: + end_ind = 0 + outputs = set(self.inputs + diffs) + + if kwargs: + if set(kwargs.keys()) != set(self.outputs): + raise Exception('Top diff arguments do not match net outputs.') + # Set top diffs according to defined shapes and make arrays single and + # C-contiguous as Caffe expects. 
+ for top, diff in kwargs.iteritems(): + if diff.ndim != 4: + raise Exception('{} diff is not 4-d'.format(top)) + if diff.shape[0] != self.blobs[top].num: + raise Exception('Diff is not batch sized') + self.blobs[top].diff[...] = diff + + self._backward(start_ind, end_ind) + + # Unpack diffs to extract + return {out: self.blobs[out].diff for out in outputs} + + +def _Net_forward_all(self, blobs=None, **kwargs): + """ + Run net forward in batches. + + Take + blobs: list of blobs to extract as in forward() + kwargs: Keys are input blob names and values are blob ndarrays. + Refer to forward(). + + Give + all_outs: {blob name: list of blobs} dict. + """ + # Collect outputs from batches + all_outs = {out: [] for out in set(self.outputs + (blobs or []))} + for batch in self._batch(kwargs): + outs = self.forward(blobs=blobs, **batch) + for out, out_blob in outs.iteritems(): + all_outs[out].extend(out_blob.copy()) + # Package in ndarray. + for out in all_outs: + all_outs[out] = np.asarray(all_outs[out]) + # Discard padding. + pad = len(all_outs.itervalues().next()) - len(kwargs.itervalues().next()) + if pad: + for out in all_outs: + all_outs[out] = all_outs[out][:-pad] + return all_outs + + +def _Net_forward_backward_all(self, blobs=None, diffs=None, **kwargs): + """ + Run net forward + backward in batches. + + Take + blobs: list of blobs to extract as in forward() + diffs: list of diffs to extract as in backward() + kwargs: Keys are input (for forward) and output (for backward) blob names + and values are ndarrays. Refer to forward() and backward(). + Prefilled variants are called for lack of input or output blobs. + + Give + all_blobs: {blob name: blob ndarray} dict. + all_diffs: {blob name: diff ndarray} dict. + """ + # Batch blobs and diffs. 
+ all_outs = {out: [] for out in set(self.outputs + (blobs or []))} + all_diffs = {diff: [] for diff in set(self.inputs + (diffs or []))} + forward_batches = self._batch({in_: kwargs[in_] + for in_ in self.inputs if in_ in kwargs}) + backward_batches = self._batch({out: kwargs[out] + for out in self.outputs if out in kwargs}) + # Collect outputs from batches (and heed lack of forward/backward batches). + for fb, bb in izip_longest(forward_batches, backward_batches, fillvalue={}): + batch_blobs = self.forward(blobs=blobs, **fb) + batch_diffs = self.backward(diffs=diffs, **bb) + for out, out_blobs in batch_blobs.iteritems(): + all_outs[out].extend(out_blobs) + for diff, out_diffs in batch_diffs.iteritems(): + all_diffs[diff].extend(out_diffs) + # Package in ndarray. + for out, diff in zip(all_outs, all_diffs): + all_outs[out] = np.asarray(all_outs[out]) + all_diffs[diff] = np.asarray(all_diffs[diff]) + # Discard padding at the end and package in ndarray. + pad = len(all_outs.itervalues().next()) - len(kwargs.itervalues().next()) + if pad: + for out, diff in zip(all_outs, all_diffs): + all_outs[out] = all_outs[out][:-pad] + all_diffs[diff] = all_diffs[diff][:-pad] + return all_outs, all_diffs + + +def _Net_set_mean(self, input_, mean, mode='elementwise'): + """ + Set the mean to subtract for data centering. + + Take + input_: which input to assign this mean. + mean: mean K x H x W ndarray (input dimensional or broadcastable) + mode: elementwise = use the whole mean (and check dimensions) + channel = channel constant (e.g. mean pixel instead of mean image) + """ + if input_ not in self.inputs: + raise Exception('Input not in {}'.format(self.inputs)) + in_shape = self.blobs[input_].data.shape + if mode == 'elementwise': + if mean.shape[1:] != in_shape[2:]: + # Resize mean (which requires H x W x K input). 
+ mean = caffe.io.resize_image(mean.transpose((1,2,0)), + in_shape[2:]).transpose((2,0,1)) + self.mean[input_] = mean + elif mode == 'channel': + self.mean[input_] = mean.mean(1).mean(1).reshape((in_shape[1], 1, 1)) + else: + raise Exception('Mode not in {}'.format(['elementwise', 'channel'])) + + +def _Net_set_input_scale(self, input_, scale): + """ + Set the scale of preprocessed inputs s.t. the blob = blob * scale. + N.B. input_scale is done AFTER mean subtraction and other preprocessing + while raw_scale is done BEFORE. + + Take + input_: which input to assign this scale factor + scale: scale coefficient + """ + if input_ not in self.inputs: + raise Exception('Input not in {}'.format(self.inputs)) + self.input_scale[input_] = scale + + +def _Net_set_raw_scale(self, input_, scale): + """ + Set the scale of raw features s.t. the input blob = input * scale. + While Python represents images in [0, 1], certain Caffe models + like CaffeNet and AlexNet represent images in [0, 255] so the raw_scale + of these models must be 255. + + Take + input_: which input to assign this scale factor + scale: scale coefficient + """ + if input_ not in self.inputs: + raise Exception('Input not in {}'.format(self.inputs)) + self.raw_scale[input_] = scale + + +def _Net_set_channel_swap(self, input_, order): + """ + Set the input channel order for e.g. RGB to BGR conversion + as needed for the reference ImageNet model. + + Take + input_: which input to assign this channel order + order: the order to take the channels. + (2,1,0) maps RGB to BGR for example. + """ + if input_ not in self.inputs: + raise Exception('Input not in {}'.format(self.inputs)) + self.channel_swap[input_] = order + + +def _Net_preprocess(self, input_name, input_): + """ + Format input for Caffe: + - convert to single + - resize to input dimensions (preserving number of channels) + - reorder channels (for instance color to BGR) + - scale raw input (e.g. 
from [0, 1] to [0, 255] for ImageNet models) + - transpose dimensions to K x H x W + - subtract mean + - scale feature + + Take + input_name: name of input blob to preprocess for + input_: (H' x W' x K) ndarray + + Give + caffe_inputs: (K x H x W) ndarray + """ + caffe_in = input_.astype(np.float32, copy=False) + mean = self.mean.get(input_name) + input_scale = self.input_scale.get(input_name) + raw_scale = self.raw_scale.get(input_name) + channel_order = self.channel_swap.get(input_name) + in_size = self.blobs[input_name].data.shape[2:] + if caffe_in.shape[:2] != in_size: + caffe_in = caffe.io.resize_image(caffe_in, in_size) + if channel_order is not None: + caffe_in = caffe_in[:, :, channel_order] + caffe_in = caffe_in.transpose((2, 0, 1)) + if raw_scale is not None: + caffe_in *= raw_scale + if mean is not None: + caffe_in -= mean + if input_scale is not None: + caffe_in *= input_scale + return caffe_in + + +def _Net_deprocess(self, input_name, input_): + """ + Invert Caffe formatting; see Net.preprocess(). + """ + decaf_in = input_.copy().squeeze() + mean = self.mean.get(input_name) + input_scale = self.input_scale.get(input_name) + raw_scale = self.raw_scale.get(input_name) + channel_order = self.channel_swap.get(input_name) + if input_scale is not None: + decaf_in /= input_scale + if mean is not None: + decaf_in += mean + if raw_scale is not None: + decaf_in /= raw_scale + decaf_in = decaf_in.transpose((1,2,0)) + if channel_order is not None: + channel_order_inverse = [channel_order.index(i) + for i in range(decaf_in.shape[2])] + decaf_in = decaf_in[:, :, channel_order_inverse] + return decaf_in + + +def _Net_set_input_arrays(self, data, labels): + """ + Set input arrays of the in-memory MemoryDataLayer. + (Note: this is only for networks declared with the memory data layer.) 
+ """ + if labels.ndim == 1: + labels = np.ascontiguousarray(labels[:, np.newaxis, np.newaxis, + np.newaxis]) + return self._set_input_arrays(data, labels) + + +def _Net_batch(self, blobs): + """ + Batch blob lists according to net's batch size. + + Take + blobs: Keys blob names and values are lists of blobs (of any length). + Naturally, all the lists should have the same length. + + Give (yield) + batch: {blob name: list of blobs} dict for a single batch. + """ + num = len(blobs.itervalues().next()) + batch_size = self.blobs.itervalues().next().num + remainder = num % batch_size + num_batches = num / batch_size + + # Yield full batches. + for b in range(num_batches): + i = b * batch_size + yield {name: blobs[name][i:i + batch_size] for name in blobs} + + # Yield last padded batch, if any. + if remainder > 0: + padded_batch = {} + for name in blobs: + padding = np.zeros((batch_size - remainder,) + + blobs[name].shape[1:]) + padded_batch[name] = np.concatenate([blobs[name][-remainder:], + padding]) + yield padded_batch + + +# Attach methods to Net. +Net.blobs = _Net_blobs +Net.params = _Net_params +Net.forward = _Net_forward +Net.backward = _Net_backward +Net.forward_all = _Net_forward_all +Net.forward_backward_all = _Net_forward_backward_all +Net.set_mean = _Net_set_mean +Net.set_input_scale = _Net_set_input_scale +Net.set_raw_scale = _Net_set_raw_scale +Net.set_channel_swap = _Net_set_channel_swap +Net.preprocess = _Net_preprocess +Net.deprocess = _Net_deprocess +Net.set_input_arrays = _Net_set_input_arrays +Net._batch = _Net_batch diff --git a/caffe-crfrnn/python/caffe/segmenter.py b/caffe-crfrnn/python/caffe/segmenter.py new file mode 100644 index 00000000..afbafe88 --- /dev/null +++ b/caffe-crfrnn/python/caffe/segmenter.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +""" +Segmenter is an image segmentation specialization of Net. 
+""" + +import numpy as np + +import caffe + + +class Segmenter(caffe.Net): + """ + Segmenter + """ + def __init__(self, model_file, pretrained_file, + gpu=False): + """ + """ + caffe.Net.__init__(self, model_file, pretrained_file) + self.set_phase_test() + + if gpu: + self.set_mode_gpu() + self.set_device(0) + else: + self.set_mode_cpu() + + + def predict(self, inputs): + """ + Assume that each input is a 500 x 500 image in BGR layout, padded + as necessary to make it 500 x 500. + """ + + input_ = np.zeros((len(inputs), + 500, 500, inputs[0].shape[2]), + dtype=np.float32) + for ix, in_ in enumerate(inputs): + input_[ix] = in_ + + # Segment + caffe_in = np.zeros(np.array(input_.shape)[[0,3,1,2]], + dtype=np.float32) + for ix, in_ in enumerate(input_): + caffe_in[ix] = in_.transpose((2, 0, 1)) + out = self.forward_all(**{self.inputs[0]: caffe_in}) + predictions = out[self.outputs[0]] + + return predictions[0].argmax(axis=0).astype(np.uint8) diff --git a/caffe-crfrnn/python/classify.py b/caffe-crfrnn/python/classify.py new file mode 100755 index 00000000..d435a572 --- /dev/null +++ b/caffe-crfrnn/python/classify.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +""" +classify.py is an out-of-the-box image classifier callable from the command line. + +By default it configures and runs the Caffe reference ImageNet model. +""" +import numpy as np +import os +import sys +import argparse +import glob +import time + +import caffe + + +def main(argv): + pycaffe_dir = os.path.dirname(__file__) + + parser = argparse.ArgumentParser() + # Required arguments: input and output files. + parser.add_argument( + "input_file", + help="Input image, directory, or npy." + ) + parser.add_argument( + "output_file", + help="Output npy filename." + ) + # Optional arguments. + parser.add_argument( + "--model_def", + default=os.path.join(pycaffe_dir, + "../models/bvlc_reference_caffenet/deploy.prototxt"), + help="Model definition file."
+ ) + parser.add_argument( + "--pretrained_model", + default=os.path.join(pycaffe_dir, + "../models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel"), + help="Trained model weights file." + ) + parser.add_argument( + "--gpu", + action='store_true', + help="Switch for gpu computation." + ) + parser.add_argument( + "--center_only", + action='store_true', + help="Switch for prediction from center crop alone instead of " + + "averaging predictions across crops (default)." + ) + parser.add_argument( + "--images_dim", + default='256,256', + help="Canonical 'height,width' dimensions of input images." + ) + parser.add_argument( + "--mean_file", + default=os.path.join(pycaffe_dir, + 'caffe/imagenet/ilsvrc_2012_mean.npy'), + help="Data set image mean of [Channels x Height x Width] dimensions " + + "(numpy array). Set to '' for no mean subtraction." + ) + parser.add_argument( + "--input_scale", + type=float, + help="Multiply input features by this scale to finish preprocessing." + ) + parser.add_argument( + "--raw_scale", + type=float, + default=255.0, + help="Multiply raw input by this scale before preprocessing." + ) + parser.add_argument( + "--channel_swap", + default='2,1,0', + help="Order to permute input channels. The default converts " + + "RGB -> BGR since BGR is the Caffe default by way of OpenCV." + ) + parser.add_argument( + "--ext", + default='jpg', + help="Image file extension to take as input when a directory " + + "is given as the input file." + ) + args = parser.parse_args() + + image_dims = [int(s) for s in args.images_dim.split(',')] + + mean, channel_swap = None, None + if args.mean_file: + mean = np.load(args.mean_file) + if args.channel_swap: + channel_swap = [int(s) for s in args.channel_swap.split(',')] + + # Make classifier. 
+ classifier = caffe.Classifier(args.model_def, args.pretrained_model, + image_dims=image_dims, gpu=args.gpu, mean=mean, + input_scale=args.input_scale, raw_scale=args.raw_scale, + channel_swap=channel_swap) + + if args.gpu: + print 'GPU mode' + + # Load numpy array (.npy), directory glob (*.jpg), or image file. + args.input_file = os.path.expanduser(args.input_file) + if args.input_file.endswith('npy'): + inputs = np.load(args.input_file) + elif os.path.isdir(args.input_file): + inputs = [caffe.io.load_image(im_f) + for im_f in glob.glob(args.input_file + '/*.' + args.ext)] + else: + inputs = [caffe.io.load_image(args.input_file)] + + print "Classifying %d inputs." % len(inputs) + + # Classify. + start = time.time() + predictions = classifier.predict(inputs, not args.center_only) + print "Done in %.2f s." % (time.time() - start) + + # Save + np.save(args.output_file, predictions) + + +if __name__ == '__main__': + main(sys.argv) diff --git a/caffe-crfrnn/python/detect.py b/caffe-crfrnn/python/detect.py new file mode 100755 index 00000000..b67b500a --- /dev/null +++ b/caffe-crfrnn/python/detect.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python +""" +detect.py is an out-of-the-box windowed detector +callable from the command line. + +By default it configures and runs the Caffe reference ImageNet model. +Note that this model was trained for image classification and not detection, +and finetuning for detection can be expected to improve results.
+ +The selective_search_ijcv_with_python code required for the selective search +proposal mode is available at + https://github.com/sergeyk/selective_search_ijcv_with_python + +TODO: +- batch up image filenames as well: don't want to load all of them into memory +- come up with a batching scheme that preserves order / keeps a unique ID +""" +import numpy as np +import pandas as pd +import os +import argparse +import time + +import caffe + +CROP_MODES = ['list', 'selective_search'] +COORD_COLS = ['ymin', 'xmin', 'ymax', 'xmax'] + + +def main(argv): + pycaffe_dir = os.path.dirname(__file__) + + parser = argparse.ArgumentParser() + # Required arguments: input and output. + parser.add_argument( + "input_file", + help="Input txt/csv filename. If .txt, must be list of filenames.\ + If .csv, must be comma-separated file with header\ + 'filename, xmin, ymin, xmax, ymax'" + ) + parser.add_argument( + "output_file", + help="Output h5/csv filename. Format depends on extension." + ) + # Optional arguments. + parser.add_argument( + "--model_def", + default=os.path.join(pycaffe_dir, + "../models/bvlc_reference_caffenet/deploy.prototxt"), + help="Model definition file." + ) + parser.add_argument( + "--pretrained_model", + default=os.path.join(pycaffe_dir, + "../models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel"), + help="Trained model weights file." + ) + parser.add_argument( + "--crop_mode", + default="selective_search", + choices=CROP_MODES, + help="How to generate windows for detection." + ) + parser.add_argument( + "--gpu", + action='store_true', + help="Switch for gpu computation." + ) + parser.add_argument( + "--mean_file", + default=os.path.join(pycaffe_dir, + 'caffe/imagenet/ilsvrc_2012_mean.npy'), + help="Data set image mean of K x H x W dimensions (numpy array). " + + "Set to '' for no mean subtraction." + ) + parser.add_argument( + "--input_scale", + type=float, + help="Multiply input features by this scale to finish preprocessing."
+ ) + parser.add_argument( + "--raw_scale", + type=float, + default=255.0, + help="Multiply raw input by this scale before preprocessing." + ) + parser.add_argument( + "--channel_swap", + default='2,1,0', + help="Order to permute input channels. The default converts " + + "RGB -> BGR since BGR is the Caffe default by way of OpenCV." + + ) + parser.add_argument( + "--context_pad", + type=int, + default='16', + help="Amount of surrounding context to collect in input window." + ) + args = parser.parse_args() + + mean, channel_swap = None, None + if args.mean_file: + mean = np.load(args.mean_file) + if args.channel_swap: + channel_swap = [int(s) for s in args.channel_swap.split(',')] + + # Make detector. + detector = caffe.Detector(args.model_def, args.pretrained_model, + gpu=args.gpu, mean=mean, + input_scale=args.input_scale, raw_scale=args.raw_scale, + channel_swap=channel_swap, + context_pad=args.context_pad) + + if args.gpu: + print 'GPU mode' + + # Load input. + t = time.time() + print('Loading input...') + if args.input_file.lower().endswith('txt'): + with open(args.input_file) as f: + inputs = [_.strip() for _ in f.readlines()] + elif args.input_file.lower().endswith('csv'): + inputs = pd.read_csv(args.input_file, sep=',', dtype={'filename': str}) + inputs.set_index('filename', inplace=True) + else: + raise Exception("Unknown input file type: not in txt or csv.") + + # Detect. + if args.crop_mode == 'list': + # Unpack sequence of (image filename, windows). + images_windows = [ + (ix, inputs.iloc[np.where(inputs.index == ix)][COORD_COLS].values) + for ix in inputs.index.unique() + ] + detections = detector.detect_windows(images_windows) + else: + detections = detector.detect_selective_search(inputs) + print("Processed {} windows in {:.3f} s.".format(len(detections), + time.time() - t)) + + # Collect into dataframe with labeled fields. 
+ df = pd.DataFrame(detections) + df.set_index('filename', inplace=True) + df[COORD_COLS] = pd.DataFrame( + data=np.vstack(df['window']), index=df.index, columns=COORD_COLS) + del(df['window']) + + # Save results. + t = time.time() + if args.output_file.lower().endswith('csv'): + # csv + # Enumerate the class probabilities. + class_cols = ['class{}'.format(x) for x in range(NUM_OUTPUT)] + df[class_cols] = pd.DataFrame( + data=np.vstack(df['feat']), index=df.index, columns=class_cols) + df.to_csv(args.output_file, cols=COORD_COLS + class_cols) + else: + # h5 + df.to_hdf(args.output_file, 'df', mode='w') + print("Saved to {} in {:.3f} s.".format(args.output_file, + time.time() - t)) + + +if __name__ == "__main__": + import sys + main(sys.argv) diff --git a/caffe-crfrnn/python/draw_net.py b/caffe-crfrnn/python/draw_net.py new file mode 100755 index 00000000..ba488294 --- /dev/null +++ b/caffe-crfrnn/python/draw_net.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +""" +Draw a graph of the net architecture. 
+""" +import os +from google.protobuf import text_format + +import caffe, caffe.draw +from caffe.proto import caffe_pb2 + + +def main(argv): + if len(argv) != 3: + print 'Usage: %s input_net_proto_file output_image_file' % \ + os.path.basename(sys.argv[0]) + else: + net = caffe_pb2.NetParameter() + text_format.Merge(open(sys.argv[1]).read(), net) + print 'Drawing net to %s' % sys.argv[2] + caffe.draw.draw_net_to_file(net, sys.argv[2]) + + +if __name__ == '__main__': + import sys + main(sys.argv) diff --git a/caffe-crfrnn/python/requirements.txt b/caffe-crfrnn/python/requirements.txt new file mode 100644 index 00000000..4c35dcb0 --- /dev/null +++ b/caffe-crfrnn/python/requirements.txt @@ -0,0 +1,15 @@ +Cython>=0.19.2 +numpy>=1.7.1 +scipy>=0.13.2 +scikit-image>=0.9.3 +scikit-learn>=0.14.1 +matplotlib>=1.3.1 +ipython>=1.1.0 +h5py>=2.2.0 +leveldb>=0.191 +networkx>=1.8.1 +nose>=1.3.0 +pandas>=0.12.0 +python-dateutil>=1.4,<2 +protobuf>=2.5.0 +python-gflags>=2.0 diff --git a/caffe-crfrnn/python/segdemo.py b/caffe-crfrnn/python/segdemo.py new file mode 100644 index 00000000..9a1a3844 --- /dev/null +++ b/caffe-crfrnn/python/segdemo.py @@ -0,0 +1,87 @@ +caffe_root = '/home/sadeep/Desktop/crf-rnn-web-demo/caffe-fcn-sadeep/' +import sys +sys.path.insert(0,caffe_root+'python') + +import os +import time +import cPickle +import datetime +import logging +import flask +import werkzeug +import optparse +import tornado.wsgi +import tornado.httpserver +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd +from PIL import Image +import cStringIO as StringIO +import exifutil + +import caffe + + + +MODEL_FILE = '/home/sadeep/Desktop/crf-rnn-web-demo/caffe-fcn-sadeep/models/crf_rnn/fcn-8s-pascal-deploy.prototxt' +PRETRAINED = '/home/sadeep/Desktop/crf-rnn-web-demo/caffe-fcn-sadeep/models/crf_rnn/fcn-8s-pascal.caffemodel' +IMAGE_FILE = '/home/sadeep/Desktop/crf-rnn-web-demo/caffe-fcn-sadeep/models/crf_rnn/2007_000033.jpg' + +pallete = [0,0,0, + 128,0,0, + 0,128,0, + 
128,128,0, + 0,0,128, + 128,0,128, + 0,128,128, + 128,128,128, + 64,0,0, + 192,0,0, + 64,128,0, + 192,128,0, + 64,0,128, + 192,0,128, + 64,128,128, + 192,128,128, + 0,64,0, + 128,64,0, + 0,192,0, + 128,192,0, + 0,64,128, + 128,64,128, + 0,192,128, + 128,192,128, + 64,64,0, + 192,64,0, + 64,192,0, + 192,192,0] + +net = caffe.Segmenter(MODEL_FILE, PRETRAINED, gpu=False) + +input_image = 255 * exifutil.open_oriented_im(IMAGE_FILE) + + +# Mean values in BGR format +mean_vec = np.array([103.939, 116.779, 123.68], dtype=np.float32) +reshaped_mean_vec = mean_vec.reshape(1,1,3); + +# Rearrange channels to form BGR +im = input_image[:,:,::-1] + +# Subtract mean +im = im - reshaped_mean_vec + +# Pad as necessary +cur_h, cur_w, cur_c = im.shape +pad_h = 500 - cur_h +pad_w = 500 - cur_w +im = np.pad(im, pad_width=((0, pad_h), (0, pad_w), (0, 0)), mode = 'constant', constant_values = 0) + +# Get predictions +segmentation = net.predict([im]) + +output_im = Image.fromarray(segmentation) + +output_im.putpalette(pallete); + +output_im.save('hahasadeep.png') \ No newline at end of file diff --git a/caffe-crfrnn/scripts/build_docs.sh b/caffe-crfrnn/scripts/build_docs.sh new file mode 100755 index 00000000..f8ace0ea --- /dev/null +++ b/caffe-crfrnn/scripts/build_docs.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Build documentation for display in web browser. + +PORT=${1:-4000} + +echo "usage: build.sh [port]" + +# Find the docs dir, no matter where the script is called +ROOT_DIR="$( cd "$(dirname "$0")"/.. ; pwd -P )" +cd $ROOT_DIR + +# Gather docs. +scripts/gather_examples.sh + +# Generate developer docs. +make docs + +# Display docs using web server. +cd docs +jekyll serve -w -s . -d _site --port=$PORT diff --git a/caffe-crfrnn/scripts/copy_notebook.py b/caffe-crfrnn/scripts/copy_notebook.py new file mode 100755 index 00000000..e4c6385b --- /dev/null +++ b/caffe-crfrnn/scripts/copy_notebook.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +""" +Takes as arguments: +1. 
the path to a JSON file (such as an IPython notebook). +2. the path to output file + +If 'metadata' dict in the JSON file contains 'include_in_docs': true, +then copies the file to output file, appending the 'metadata' property +as YAML front-matter, adding the field 'category' with value 'notebook'. +""" +import os +import sys +import json + +filename = sys.argv[1] +output_filename = sys.argv[2] +content = json.load(open(filename)) + +if 'include_in_docs' in content['metadata'] and content['metadata']['include_in_docs']: + yaml_frontmatter = ['---'] + for key, val in content['metadata'].iteritems(): + if key == 'example_name': + key = 'title' + if val == '': + val = os.path.basename(filename) + yaml_frontmatter.append('{}: {}'.format(key, val)) + yaml_frontmatter += ['category: notebook'] + yaml_frontmatter += ['original_path: ' + filename] + + with open(output_filename, 'w') as fo: + fo.write('\n'.join(yaml_frontmatter + ['---']) + '\n') + fo.write(open(filename).read()) diff --git a/caffe-crfrnn/scripts/cpp_lint.py b/caffe-crfrnn/scripts/cpp_lint.py new file mode 100755 index 00000000..1b7c6c05 --- /dev/null +++ b/caffe-crfrnn/scripts/cpp_lint.py @@ -0,0 +1,4868 @@ +#!/usr/bin/python +# +# Copyright (c) 2009 Google Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Does google-lint on c++ files. + +The goal of this script is to identify places in the code that *may* +be in non-compliance with google style. It does not attempt to fix +up these problems -- the point is to educate. It does also not +attempt to find all problems, or to ensure that everything it does +find is legitimately a problem. + +In particular, we can get very confused by /* and // inside strings! +We do a small hack, which is to ignore //'s with "'s after them on the +same line, but it is far from perfect (in either direction). +""" + +import codecs +import copy +import getopt +import math # for log +import os +import re +import sre_compile +import string +import sys +import unicodedata + + +_USAGE = """ +Syntax: cpp_lint.py [--verbose=#] [--output=vs7] [--filter=-x,+y,...] + [--counting=total|toplevel|detailed] [--root=subdir] + [--linelength=digits] + [file] ... + + The style guidelines this tries to follow are those in + http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml + + Every problem is given a confidence score from 1-5, with 5 meaning we are + certain of the problem, and 1 meaning it could be a legitimate construct. 
+ This will miss some errors, and is not a substitute for a code review. + + To suppress false-positive errors of a certain category, add a + 'NOLINT(category)' comment to the line. NOLINT or NOLINT(*) + suppresses errors of all categories on that line. + + The files passed in will be linted; at least one file must be provided. + Default linted extensions are .cc, .cpp, .cu, .cuh and .h. Change the + extensions with the --extensions flag. + + Flags: + + output=vs7 + By default, the output is formatted to ease emacs parsing. Visual Studio + compatible output (vs7) may also be used. Other formats are unsupported. + + verbose=# + Specify a number 0-5 to restrict errors to certain verbosity levels. + + filter=-x,+y,... + Specify a comma-separated list of category-filters to apply: only + error messages whose category names pass the filters will be printed. + (Category names are printed with the message and look like + "[whitespace/indent]".) Filters are evaluated left to right. + "-FOO" and "FOO" means "do not print categories that start with FOO". + "+FOO" means "do print categories that start with FOO". + + Examples: --filter=-whitespace,+whitespace/braces + --filter=whitespace,runtime/printf,+runtime/printf_format + --filter=-,+build/include_what_you_use + + To see a list of all the categories used in cpplint, pass no arg: + --filter= + + counting=total|toplevel|detailed + The total number of errors found is always printed. If + 'toplevel' is provided, then the count of errors in each of + the top-level categories like 'build' and 'whitespace' will + also be printed. If 'detailed' is provided, then a count + is provided for each category like 'build/class'. + + root=subdir + The root directory used for deriving header guard CPP variable. + By default, the header guard CPP variable is calculated as the relative + path to the directory that contains .git, .hg, or .svn. When this flag + is specified, the relative path is calculated from the specified + directory. 
If the specified directory does not exist, this flag is + ignored. + + Examples: + Assuming that src/.git exists, the header guard CPP variables for + src/chrome/browser/ui/browser.h are: + + No flag => CHROME_BROWSER_UI_BROWSER_H_ + --root=chrome => BROWSER_UI_BROWSER_H_ + --root=chrome/browser => UI_BROWSER_H_ + + linelength=digits + This is the allowed line length for the project. The default value is + 80 characters. + + Examples: + --linelength=120 + + extensions=extension,extension,... + The allowed file extensions that cpplint will check + + Examples: + --extensions=hpp,cpp +""" + +# We categorize each error message we print. Here are the categories. +# We want an explicit list so we can list them all in cpplint --filter=. +# If you add a new error message with a new category, add it to the list +# here! cpplint_unittest.py should tell you if you forget to do this. +_ERROR_CATEGORIES = [ + 'build/class', + 'build/deprecated', + 'build/endif_comment', + 'build/explicit_make_pair', + 'build/forward_decl', + 'build/header_guard', + 'build/include', + 'build/include_alpha', + 'build/include_dir', + 'build/include_order', + 'build/include_what_you_use', + 'build/namespaces', + 'build/printf_format', + 'build/storage_class', + 'caffe/alt_fn', + 'caffe/data_layer_setup', + 'caffe/random_fn', + 'legal/copyright', + 'readability/alt_tokens', + 'readability/braces', + 'readability/casting', + 'readability/check', + 'readability/constructors', + 'readability/fn_size', + 'readability/function', + 'readability/multiline_comment', + 'readability/multiline_string', + 'readability/namespace', + 'readability/nolint', + 'readability/nul', + 'readability/streams', + 'readability/todo', + 'readability/utf8', + 'runtime/arrays', + 'runtime/casting', + 'runtime/explicit', + 'runtime/int', + 'runtime/init', + 'runtime/invalid_increment', + 'runtime/member_string_references', + 'runtime/memset', + 'runtime/operator', + 'runtime/printf', + 'runtime/printf_format', + 
'runtime/references', + 'runtime/string', + 'runtime/threadsafe_fn', + 'runtime/vlog', + 'whitespace/blank_line', + 'whitespace/braces', + 'whitespace/comma', + 'whitespace/comments', + 'whitespace/empty_conditional_body', + 'whitespace/empty_loop_body', + 'whitespace/end_of_line', + 'whitespace/ending_newline', + 'whitespace/forcolon', + 'whitespace/indent', + 'whitespace/line_length', + 'whitespace/newline', + 'whitespace/operators', + 'whitespace/parens', + 'whitespace/semicolon', + 'whitespace/tab', + 'whitespace/todo' + ] + +# The default state of the category filter. This is overridden by the --filter= +# flag. By default all errors are on, so only add here categories that should be +# off by default (i.e., categories that must be enabled by the --filter= flags). +# All entries here should start with a '-' or '+', as in the --filter= flag. +_DEFAULT_FILTERS = [ + '-build/include_dir', + '-readability/todo', + ] + +# We used to check for high-bit characters, but after much discussion we +# decided those were OK, as long as they were in UTF-8 and didn't represent +# hard-coded international strings, which belong in a separate i18n file. 
+ + +# C++ headers +_CPP_HEADERS = frozenset([ + # Legacy + 'algobase.h', + 'algo.h', + 'alloc.h', + 'builtinbuf.h', + 'bvector.h', + 'complex.h', + 'defalloc.h', + 'deque.h', + 'editbuf.h', + 'fstream.h', + 'function.h', + 'hash_map', + 'hash_map.h', + 'hash_set', + 'hash_set.h', + 'hashtable.h', + 'heap.h', + 'indstream.h', + 'iomanip.h', + 'iostream.h', + 'istream.h', + 'iterator.h', + 'list.h', + 'map.h', + 'multimap.h', + 'multiset.h', + 'ostream.h', + 'pair.h', + 'parsestream.h', + 'pfstream.h', + 'procbuf.h', + 'pthread_alloc', + 'pthread_alloc.h', + 'rope', + 'rope.h', + 'ropeimpl.h', + 'set.h', + 'slist', + 'slist.h', + 'stack.h', + 'stdiostream.h', + 'stl_alloc.h', + 'stl_relops.h', + 'streambuf.h', + 'stream.h', + 'strfile.h', + 'strstream.h', + 'tempbuf.h', + 'tree.h', + 'type_traits.h', + 'vector.h', + # 17.6.1.2 C++ library headers + 'algorithm', + 'array', + 'atomic', + 'bitset', + 'chrono', + 'codecvt', + 'complex', + 'condition_variable', + 'deque', + 'exception', + 'forward_list', + 'fstream', + 'functional', + 'future', + 'initializer_list', + 'iomanip', + 'ios', + 'iosfwd', + 'iostream', + 'istream', + 'iterator', + 'limits', + 'list', + 'locale', + 'map', + 'memory', + 'mutex', + 'new', + 'numeric', + 'ostream', + 'queue', + 'random', + 'ratio', + 'regex', + 'set', + 'sstream', + 'stack', + 'stdexcept', + 'streambuf', + 'string', + 'strstream', + 'system_error', + 'thread', + 'tuple', + 'typeindex', + 'typeinfo', + 'type_traits', + 'unordered_map', + 'unordered_set', + 'utility', + 'valarray', + 'vector', + # 17.6.1.2 C++ headers for C library facilities + 'cassert', + 'ccomplex', + 'cctype', + 'cerrno', + 'cfenv', + 'cfloat', + 'cinttypes', + 'ciso646', + 'climits', + 'clocale', + 'cmath', + 'csetjmp', + 'csignal', + 'cstdalign', + 'cstdarg', + 'cstdbool', + 'cstddef', + 'cstdint', + 'cstdio', + 'cstdlib', + 'cstring', + 'ctgmath', + 'ctime', + 'cuchar', + 'cwchar', + 'cwctype', + ]) + +# Assertion macros. 
These are defined in base/logging.h and +# testing/base/gunit.h. Note that the _M versions need to come first +# for substring matching to work. +_CHECK_MACROS = [ + 'DCHECK', 'CHECK', + 'EXPECT_TRUE_M', 'EXPECT_TRUE', + 'ASSERT_TRUE_M', 'ASSERT_TRUE', + 'EXPECT_FALSE_M', 'EXPECT_FALSE', + 'ASSERT_FALSE_M', 'ASSERT_FALSE', + ] + +# Replacement macros for CHECK/DCHECK/EXPECT_TRUE/EXPECT_FALSE +_CHECK_REPLACEMENT = dict([(m, {}) for m in _CHECK_MACROS]) + +for op, replacement in [('==', 'EQ'), ('!=', 'NE'), + ('>=', 'GE'), ('>', 'GT'), + ('<=', 'LE'), ('<', 'LT')]: + _CHECK_REPLACEMENT['DCHECK'][op] = 'DCHECK_%s' % replacement + _CHECK_REPLACEMENT['CHECK'][op] = 'CHECK_%s' % replacement + _CHECK_REPLACEMENT['EXPECT_TRUE'][op] = 'EXPECT_%s' % replacement + _CHECK_REPLACEMENT['ASSERT_TRUE'][op] = 'ASSERT_%s' % replacement + _CHECK_REPLACEMENT['EXPECT_TRUE_M'][op] = 'EXPECT_%s_M' % replacement + _CHECK_REPLACEMENT['ASSERT_TRUE_M'][op] = 'ASSERT_%s_M' % replacement + +for op, inv_replacement in [('==', 'NE'), ('!=', 'EQ'), + ('>=', 'LT'), ('>', 'LE'), + ('<=', 'GT'), ('<', 'GE')]: + _CHECK_REPLACEMENT['EXPECT_FALSE'][op] = 'EXPECT_%s' % inv_replacement + _CHECK_REPLACEMENT['ASSERT_FALSE'][op] = 'ASSERT_%s' % inv_replacement + _CHECK_REPLACEMENT['EXPECT_FALSE_M'][op] = 'EXPECT_%s_M' % inv_replacement + _CHECK_REPLACEMENT['ASSERT_FALSE_M'][op] = 'ASSERT_%s_M' % inv_replacement + +# Alternative tokens and their replacements. For full list, see section 2.5 +# Alternative tokens [lex.digraph] in the C++ standard. +# +# Digraphs (such as '%:') are not included here since it's a mess to +# match those on a word boundary. +_ALT_TOKEN_REPLACEMENT = { + 'and': '&&', + 'bitor': '|', + 'or': '||', + 'xor': '^', + 'compl': '~', + 'bitand': '&', + 'and_eq': '&=', + 'or_eq': '|=', + 'xor_eq': '^=', + 'not': '!', + 'not_eq': '!=' + } + +# Compile regular expression that matches all the above keywords. 
The "[ =()]" +# bit is meant to avoid matching these keywords outside of boolean expressions. +# +# False positives include C-style multi-line comments and multi-line strings +# but those have always been troublesome for cpplint. +_ALT_TOKEN_REPLACEMENT_PATTERN = re.compile( + r'[ =()](' + ('|'.join(_ALT_TOKEN_REPLACEMENT.keys())) + r')(?=[ (]|$)') + + +# These constants define types of headers for use with +# _IncludeState.CheckNextIncludeOrder(). +_C_SYS_HEADER = 1 +_CPP_SYS_HEADER = 2 +_LIKELY_MY_HEADER = 3 +_POSSIBLE_MY_HEADER = 4 +_OTHER_HEADER = 5 + +# These constants define the current inline assembly state +_NO_ASM = 0 # Outside of inline assembly block +_INSIDE_ASM = 1 # Inside inline assembly block +_END_ASM = 2 # Last line of inline assembly block +_BLOCK_ASM = 3 # The whole block is an inline assembly block + +# Match start of assembly blocks +_MATCH_ASM = re.compile(r'^\s*(?:asm|_asm|__asm|__asm__)' + r'(?:\s+(volatile|__volatile__))?' + r'\s*[{(]') + + +_regexp_compile_cache = {} + +# Finds occurrences of NOLINT[_NEXT_LINE] or NOLINT[_NEXT_LINE](...). +_RE_SUPPRESSION = re.compile(r'\bNOLINT(_NEXT_LINE)?\b(\([^)]*\))?') + +# {str, set(int)}: a map from error categories to sets of linenumbers +# on which those errors are expected and should be suppressed. +_error_suppressions = {} + +# Finds Copyright. +_RE_COPYRIGHT = re.compile(r'Copyright') + +# The root directory used for deriving header guard CPP variable. +# This is set by --root flag. +_root = None + +# The allowed line length of files. +# This is set by --linelength flag. +_line_length = 80 + +# The allowed extensions for file names +# This is set by --extensions flag. +_valid_extensions = set(['cc', 'h', 'cpp', 'hpp', 'cu', 'cuh']) + +def ParseNolintSuppressions(filename, raw_line, linenum, error): + """Updates the global list of error-suppressions. + + Parses any NOLINT comments on the current line, updating the global + error_suppressions store. 
Reports an error if the NOLINT comment + was malformed. + + Args: + filename: str, the name of the input file. + raw_line: str, the line of input text, with comments. + linenum: int, the number of the current line. + error: function, an error handler. + """ + # FIXME(adonovan): "NOLINT(" is misparsed as NOLINT(*). + matched = _RE_SUPPRESSION.search(raw_line) + if matched: + if matched.group(1) == '_NEXT_LINE': + linenum += 1 + category = matched.group(2) + if category in (None, '(*)'): # => "suppress all" + _error_suppressions.setdefault(None, set()).add(linenum) + else: + if category.startswith('(') and category.endswith(')'): + category = category[1:-1] + if category in _ERROR_CATEGORIES: + _error_suppressions.setdefault(category, set()).add(linenum) + else: + error(filename, linenum, 'readability/nolint', 5, + 'Unknown NOLINT error category: %s' % category) + + +def ResetNolintSuppressions(): + "Resets the set of NOLINT suppressions to empty." + _error_suppressions.clear() + + +def IsErrorSuppressedByNolint(category, linenum): + """Returns true if the specified error category is suppressed on this line. + + Consults the global error_suppressions map populated by + ParseNolintSuppressions/ResetNolintSuppressions. + + Args: + category: str, the category of the error. + linenum: int, the current line number. + Returns: + bool, True iff the error should be suppressed due to a NOLINT comment. + """ + return (linenum in _error_suppressions.get(category, set()) or + linenum in _error_suppressions.get(None, set())) + +def Match(pattern, s): + """Matches the string with the pattern, caching the compiled regexp.""" + # The regexp compilation caching is inlined in both Match and Search for + # performance reasons; factoring it out into a separate function turns out + # to be noticeably expensive. 
+ if pattern not in _regexp_compile_cache: + _regexp_compile_cache[pattern] = sre_compile.compile(pattern) + return _regexp_compile_cache[pattern].match(s) + + +def ReplaceAll(pattern, rep, s): + """Replaces instances of pattern in a string with a replacement. + + The compiled regex is kept in a cache shared by Match and Search. + + Args: + pattern: regex pattern + rep: replacement text + s: search string + + Returns: + string with replacements made (or original string if no replacements) + """ + if pattern not in _regexp_compile_cache: + _regexp_compile_cache[pattern] = sre_compile.compile(pattern) + return _regexp_compile_cache[pattern].sub(rep, s) + + +def Search(pattern, s): + """Searches the string for the pattern, caching the compiled regexp.""" + if pattern not in _regexp_compile_cache: + _regexp_compile_cache[pattern] = sre_compile.compile(pattern) + return _regexp_compile_cache[pattern].search(s) + + +class _IncludeState(dict): + """Tracks line numbers for includes, and the order in which includes appear. + + As a dict, an _IncludeState object serves as a mapping between include + filename and line number on which that file was included. + + Call CheckNextIncludeOrder() once for each header in the file, passing + in the type constants defined above. Calls in an illegal order will + raise an _IncludeError with an appropriate error message. + + """ + # self._section will move monotonically through this set. If it ever + # needs to move backwards, CheckNextIncludeOrder will raise an error. + _INITIAL_SECTION = 0 + _MY_H_SECTION = 1 + _C_SECTION = 2 + _CPP_SECTION = 3 + _OTHER_H_SECTION = 4 + + _TYPE_NAMES = { + _C_SYS_HEADER: 'C system header', + _CPP_SYS_HEADER: 'C++ system header', + _LIKELY_MY_HEADER: 'header this file implements', + _POSSIBLE_MY_HEADER: 'header this file may implement', + _OTHER_HEADER: 'other header', + } + _SECTION_NAMES = { + _INITIAL_SECTION: "... nothing. 
(This can't be an error.)", + _MY_H_SECTION: 'a header this file implements', + _C_SECTION: 'C system header', + _CPP_SECTION: 'C++ system header', + _OTHER_H_SECTION: 'other header', + } + + def __init__(self): + dict.__init__(self) + self.ResetSection() + + def ResetSection(self): + # The name of the current section. + self._section = self._INITIAL_SECTION + # The path of last found header. + self._last_header = '' + + def SetLastHeader(self, header_path): + self._last_header = header_path + + def CanonicalizeAlphabeticalOrder(self, header_path): + """Returns a path canonicalized for alphabetical comparison. + + - replaces "-" with "_" so they both cmp the same. + - removes '-inl' since we don't require them to be after the main header. + - lowercase everything, just in case. + + Args: + header_path: Path to be canonicalized. + + Returns: + Canonicalized path. + """ + return header_path.replace('-inl.h', '.h').replace('-', '_').lower() + + def IsInAlphabeticalOrder(self, clean_lines, linenum, header_path): + """Check if a header is in alphabetical order with the previous header. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + header_path: Canonicalized header to be checked. + + Returns: + Returns true if the header is in alphabetical order. + """ + # If previous section is different from current section, _last_header will + # be reset to empty string, so it's always less than current header. + # + # If previous line was a blank line, assume that the headers are + # intentionally sorted the way they are. + if (self._last_header > header_path and + not Match(r'^\s*$', clean_lines.elided[linenum - 1])): + return False + return True + + def CheckNextIncludeOrder(self, header_type): + """Returns a non-empty error message if the next header is out of order. + + This function also updates the internal state to be ready to check + the next include. 
+ + Args: + header_type: One of the _XXX_HEADER constants defined above. + + Returns: + The empty string if the header is in the right order, or an + error message describing what's wrong. + + """ + error_message = ('Found %s after %s' % + (self._TYPE_NAMES[header_type], + self._SECTION_NAMES[self._section])) + + last_section = self._section + + if header_type == _C_SYS_HEADER: + if self._section <= self._C_SECTION: + self._section = self._C_SECTION + else: + self._last_header = '' + return error_message + elif header_type == _CPP_SYS_HEADER: + if self._section <= self._CPP_SECTION: + self._section = self._CPP_SECTION + else: + self._last_header = '' + return error_message + elif header_type == _LIKELY_MY_HEADER: + if self._section <= self._MY_H_SECTION: + self._section = self._MY_H_SECTION + else: + self._section = self._OTHER_H_SECTION + elif header_type == _POSSIBLE_MY_HEADER: + if self._section <= self._MY_H_SECTION: + self._section = self._MY_H_SECTION + else: + # This will always be the fallback because we're not sure + # enough that the header is associated with this file. + self._section = self._OTHER_H_SECTION + else: + assert header_type == _OTHER_HEADER + self._section = self._OTHER_H_SECTION + + if last_section != self._section: + self._last_header = '' + + return '' + + +class _CppLintState(object): + """Maintains module-wide state.""" + + def __init__(self): + self.verbose_level = 1 # global setting. + self.error_count = 0 # global count of reported errors + # filters to apply when emitting error messages + self.filters = _DEFAULT_FILTERS[:] + self.counting = 'total' # In what way are we counting errors? 
+ self.errors_by_category = {} # string to int dict storing error counts + + # output format: + # "emacs" - format that emacs can parse (default) + # "vs7" - format that Microsoft Visual Studio 7 can parse + self.output_format = 'emacs' + + def SetOutputFormat(self, output_format): + """Sets the output format for errors.""" + self.output_format = output_format + + def SetVerboseLevel(self, level): + """Sets the module's verbosity, and returns the previous setting.""" + last_verbose_level = self.verbose_level + self.verbose_level = level + return last_verbose_level + + def SetCountingStyle(self, counting_style): + """Sets the module's counting options.""" + self.counting = counting_style + + def SetFilters(self, filters): + """Sets the error-message filters. + + These filters are applied when deciding whether to emit a given + error message. + + Args: + filters: A string of comma-separated filters (eg "+whitespace/indent"). + Each filter should start with + or -; else we die. + + Raises: + ValueError: The comma-separated filters did not all start with '+' or '-'. + E.g. "-,+whitespace,-whitespace/indent,whitespace/badfilter" + """ + # Default filters always have less priority than the flag ones. 
+ self.filters = _DEFAULT_FILTERS[:] + for filt in filters.split(','): + clean_filt = filt.strip() + if clean_filt: + self.filters.append(clean_filt) + for filt in self.filters: + if not (filt.startswith('+') or filt.startswith('-')): + raise ValueError('Every filter in --filters must start with + or -' + ' (%s does not)' % filt) + + def ResetErrorCounts(self): + """Sets the module's error statistic back to zero.""" + self.error_count = 0 + self.errors_by_category = {} + + def IncrementErrorCount(self, category): + """Bumps the module's error statistic.""" + self.error_count += 1 + if self.counting in ('toplevel', 'detailed'): + if self.counting != 'detailed': + category = category.split('/')[0] + if category not in self.errors_by_category: + self.errors_by_category[category] = 0 + self.errors_by_category[category] += 1 + + def PrintErrorCounts(self): + """Print a summary of errors by category, and the total.""" + for category, count in self.errors_by_category.iteritems(): + sys.stderr.write('Category \'%s\' errors found: %d\n' % + (category, count)) + sys.stderr.write('Total errors found: %d\n' % self.error_count) + +_cpplint_state = _CppLintState() + + +def _OutputFormat(): + """Gets the module's output format.""" + return _cpplint_state.output_format + + +def _SetOutputFormat(output_format): + """Sets the module's output format.""" + _cpplint_state.SetOutputFormat(output_format) + + +def _VerboseLevel(): + """Returns the module's verbosity setting.""" + return _cpplint_state.verbose_level + + +def _SetVerboseLevel(level): + """Sets the module's verbosity, and returns the previous setting.""" + return _cpplint_state.SetVerboseLevel(level) + + +def _SetCountingStyle(level): + """Sets the module's counting options.""" + _cpplint_state.SetCountingStyle(level) + + +def _Filters(): + """Returns the module's list of output filters, as a list.""" + return _cpplint_state.filters + + +def _SetFilters(filters): + """Sets the module's error-message filters. 
+ + These filters are applied when deciding whether to emit a given + error message. + + Args: + filters: A string of comma-separated filters (eg "+whitespace/indent"). + Each filter should start with + or -; else we die. + """ + _cpplint_state.SetFilters(filters) + + +class _FunctionState(object): + """Tracks current function name and the number of lines in its body.""" + + _NORMAL_TRIGGER = 250 # for --v=0, 500 for --v=1, etc. + _TEST_TRIGGER = 400 # 60% more than _NORMAL_TRIGGER. + + def __init__(self): + self.in_a_function = False + self.lines_in_function = 0 + self.current_function = '' + + def Begin(self, function_name): + """Start analyzing function body. + + Args: + function_name: The name of the function being tracked. + """ + self.in_a_function = True + self.lines_in_function = 0 + self.current_function = function_name + + def Count(self): + """Count line in current function body.""" + if self.in_a_function: + self.lines_in_function += 1 + + def Check(self, error, filename, linenum): + """Report if too many lines in function body. + + Args: + error: The function to call with any errors found. + filename: The name of the current file. + linenum: The number of the line to check. + """ + if Match(r'T(EST|est)', self.current_function): + base_trigger = self._TEST_TRIGGER + else: + base_trigger = self._NORMAL_TRIGGER + trigger = base_trigger * 2**_VerboseLevel() + + if self.lines_in_function > trigger: + error_level = int(math.log(self.lines_in_function / base_trigger, 2)) + # For base_trigger 250: 250 => 0, 500 => 1, 1000 => 2, 2000 => 3, ... + if error_level > 5: + error_level = 5 + error(filename, linenum, 'readability/fn_size', error_level, + 'Small and focused functions are preferred:' + ' %s has %d non-comment lines' + ' (error triggered by exceeding %d lines).' 
% ( + self.current_function, self.lines_in_function, trigger)) + + def End(self): + """Stop analyzing function body.""" + self.in_a_function = False + + +class _IncludeError(Exception): + """Indicates a problem with the include order in a file.""" + pass + + +class FileInfo: + """Provides utility functions for filenames. + + FileInfo provides easy access to the components of a file's path + relative to the project root. + """ + + def __init__(self, filename): + self._filename = filename + + def FullName(self): + """Make Windows paths like Unix.""" + return os.path.abspath(self._filename).replace('\\', '/') + + def RepositoryName(self): + """FullName after removing the local path to the repository. + + If we have a real absolute path name here we can try to do something smart: + detecting the root of the checkout and truncating /path/to/checkout from + the name so that we get header guards that don't include things like + "C:\Documents and Settings\..." or "/home/username/..." in them and thus + people on different computers who have checked the source out to different + locations won't see bogus errors. + """ + fullname = self.FullName() + + if os.path.exists(fullname): + project_dir = os.path.dirname(fullname) + + if os.path.exists(os.path.join(project_dir, ".svn")): + # If there's a .svn file in the current directory, we recursively look + # up the directory tree for the top of the SVN checkout + root_dir = project_dir + one_up_dir = os.path.dirname(root_dir) + while os.path.exists(os.path.join(one_up_dir, ".svn")): + root_dir = os.path.dirname(root_dir) + one_up_dir = os.path.dirname(one_up_dir) + + prefix = os.path.commonprefix([root_dir, project_dir]) + return fullname[len(prefix) + 1:] + + # Not SVN <= 1.6? Try to find a git, hg, or svn top level directory by + # searching up from the current path. 
+ root_dir = os.path.dirname(fullname) + while (root_dir != os.path.dirname(root_dir) and + not os.path.exists(os.path.join(root_dir, ".git")) and + not os.path.exists(os.path.join(root_dir, ".hg")) and + not os.path.exists(os.path.join(root_dir, ".svn"))): + root_dir = os.path.dirname(root_dir) + + if (os.path.exists(os.path.join(root_dir, ".git")) or + os.path.exists(os.path.join(root_dir, ".hg")) or + os.path.exists(os.path.join(root_dir, ".svn"))): + prefix = os.path.commonprefix([root_dir, project_dir]) + return fullname[len(prefix) + 1:] + + # Don't know what to do; header guard warnings may be wrong... + return fullname + + def Split(self): + """Splits the file into the directory, basename, and extension. + + For 'chrome/browser/browser.cc', Split() would + return ('chrome/browser', 'browser', '.cc') + + Returns: + A tuple of (directory, basename, extension). + """ + + googlename = self.RepositoryName() + project, rest = os.path.split(googlename) + return (project,) + os.path.splitext(rest) + + def BaseName(self): + """File base name - text after the final slash, before the final period.""" + return self.Split()[1] + + def Extension(self): + """File extension - the final period and the text following it.""" + return self.Split()[2] + + def NoExtension(self): + """File path (directory and base name) with the extension removed.""" + return '/'.join(self.Split()[0:2]) + + def IsSource(self): + """File has a source file extension.""" + return self.Extension()[1:] in ('c', 'cc', 'cpp', 'cxx') + + +def _ShouldPrintError(category, confidence, linenum): + """If confidence >= verbose, category passes filter and is not suppressed.""" + + # There are three ways we might decide not to print an error message: + # a "NOLINT(category)" comment appears in the source, + # the verbosity level isn't high enough, or the filters filter it out. 
+ if IsErrorSuppressedByNolint(category, linenum): + return False + if confidence < _cpplint_state.verbose_level: + return False + + is_filtered = False + for one_filter in _Filters(): + if one_filter.startswith('-'): + if category.startswith(one_filter[1:]): + is_filtered = True + elif one_filter.startswith('+'): + if category.startswith(one_filter[1:]): + is_filtered = False + else: + assert False # should have been checked for in SetFilters. + if is_filtered: + return False + + return True + + +def Error(filename, linenum, category, confidence, message): + """Logs the fact we've found a lint error. + + We log where the error was found, and also our confidence in the error, + that is, how certain we are this is a legitimate style regression, and + not a misidentification or a use that's sometimes justified. + + False positives can be suppressed by the use of + "NOLINT(category)" comments on the offending line. These are + parsed into _error_suppressions. + + Args: + filename: The name of the file containing the error. + linenum: The number of the line containing the error. + category: A string used to describe the "category" this bug + falls under: "whitespace", say, or "runtime". Categories + may have a hierarchy separated by slashes: "whitespace/indent". + confidence: A number from 1-5 representing a confidence score for + the error, with 5 meaning that we are certain of the problem, + and 1 meaning that it could be a legitimate construct. + message: The error message. 
+ """ + if _ShouldPrintError(category, confidence, linenum): + _cpplint_state.IncrementErrorCount(category) + if _cpplint_state.output_format == 'vs7': + sys.stderr.write('%s(%s): %s [%s] [%d]\n' % ( + filename, linenum, message, category, confidence)) + elif _cpplint_state.output_format == 'eclipse': + sys.stderr.write('%s:%s: warning: %s [%s] [%d]\n' % ( + filename, linenum, message, category, confidence)) + else: + sys.stderr.write('%s:%s: %s [%s] [%d]\n' % ( + filename, linenum, message, category, confidence)) + + +# Matches standard C++ escape sequences per 2.13.2.3 of the C++ standard. +_RE_PATTERN_CLEANSE_LINE_ESCAPES = re.compile( + r'\\([abfnrtv?"\\\']|\d+|x[0-9a-fA-F]+)') +# Matches strings. Escape codes should already be removed by ESCAPES. +_RE_PATTERN_CLEANSE_LINE_DOUBLE_QUOTES = re.compile(r'"[^"]*"') +# Matches characters. Escape codes should already be removed by ESCAPES. +_RE_PATTERN_CLEANSE_LINE_SINGLE_QUOTES = re.compile(r"'.'") +# Matches multi-line C++ comments. +# This RE is a little bit more complicated than one might expect, because we +# have to take care of space removals tools so we can handle comments inside +# statements better. +# The current rule is: We only clear spaces from both sides when we're at the +# end of the line. Otherwise, we try to remove spaces from the right side, +# if this doesn't work we try on left side but only if there's a non-character +# on the right. +_RE_PATTERN_CLEANSE_LINE_C_COMMENTS = re.compile( + r"""(\s*/\*.*\*/\s*$| + /\*.*\*/\s+| + \s+/\*.*\*/(?=\W)| + /\*.*\*/)""", re.VERBOSE) + + +def IsCppString(line): + """Does line terminate so, that the next symbol is in string constant. + + This function does not consider single-line nor multi-line comments. + + Args: + line: is a partial line of code starting from the 0..n. + + Returns: + True, if next character appended to 'line' is inside a + string constant. 
+ """ + + line = line.replace(r'\\', 'XX') # after this, \\" does not match to \" + return ((line.count('"') - line.count(r'\"') - line.count("'\"'")) & 1) == 1 + + +def CleanseRawStrings(raw_lines): + """Removes C++11 raw strings from lines. + + Before: + static const char kData[] = R"( + multi-line string + )"; + + After: + static const char kData[] = "" + (replaced by blank line) + ""; + + Args: + raw_lines: list of raw lines. + + Returns: + list of lines with C++11 raw strings replaced by empty strings. + """ + + delimiter = None + lines_without_raw_strings = [] + for line in raw_lines: + if delimiter: + # Inside a raw string, look for the end + end = line.find(delimiter) + if end >= 0: + # Found the end of the string, match leading space for this + # line and resume copying the original lines, and also insert + # a "" on the last line. + leading_space = Match(r'^(\s*)\S', line) + line = leading_space.group(1) + '""' + line[end + len(delimiter):] + delimiter = None + else: + # Haven't found the end yet, append a blank line. + line = '' + + else: + # Look for beginning of a raw string. + # See 2.14.15 [lex.string] for syntax. + matched = Match(r'^(.*)\b(?:R|u8R|uR|UR|LR)"([^\s\\()]*)\((.*)$', line) + if matched: + delimiter = ')' + matched.group(2) + '"' + + end = matched.group(3).find(delimiter) + if end >= 0: + # Raw string ended on same line + line = (matched.group(1) + '""' + + matched.group(3)[end + len(delimiter):]) + delimiter = None + else: + # Start of a multi-line raw string + line = matched.group(1) + '""' + + lines_without_raw_strings.append(line) + + # TODO(unknown): if delimiter is not None here, we might want to + # emit a warning for unterminated string. 
+ return lines_without_raw_strings + + +def FindNextMultiLineCommentStart(lines, lineix): + """Find the beginning marker for a multiline comment.""" + while lineix < len(lines): + if lines[lineix].strip().startswith('/*'): + # Only return this marker if the comment goes beyond this line + if lines[lineix].strip().find('*/', 2) < 0: + return lineix + lineix += 1 + return len(lines) + + +def FindNextMultiLineCommentEnd(lines, lineix): + """We are inside a comment, find the end marker.""" + while lineix < len(lines): + if lines[lineix].strip().endswith('*/'): + return lineix + lineix += 1 + return len(lines) + + +def RemoveMultiLineCommentsFromRange(lines, begin, end): + """Clears a range of lines for multi-line comments.""" + # Having // dummy comments makes the lines non-empty, so we will not get + # unnecessary blank line warnings later in the code. + for i in range(begin, end): + lines[i] = '// dummy' + + +def RemoveMultiLineComments(filename, lines, error): + """Removes multiline (c-style) comments from lines.""" + lineix = 0 + while lineix < len(lines): + lineix_begin = FindNextMultiLineCommentStart(lines, lineix) + if lineix_begin >= len(lines): + return + lineix_end = FindNextMultiLineCommentEnd(lines, lineix_begin) + if lineix_end >= len(lines): + error(filename, lineix_begin + 1, 'readability/multiline_comment', 5, + 'Could not find end of multi-line comment') + return + RemoveMultiLineCommentsFromRange(lines, lineix_begin, lineix_end + 1) + lineix = lineix_end + 1 + + +def CleanseComments(line): + """Removes //-comments and single-line C-style /* */ comments. + + Args: + line: A line of C++ source. + + Returns: + The line with single-line comments removed. + """ + commentpos = line.find('//') + if commentpos != -1 and not IsCppString(line[:commentpos]): + line = line[:commentpos].rstrip() + # get rid of /* ... 
*/ + return _RE_PATTERN_CLEANSE_LINE_C_COMMENTS.sub('', line) + + +class CleansedLines(object): + """Holds 3 copies of all lines with different preprocessing applied to them. + + 1) elided member contains lines without strings and comments, + 2) lines member contains lines without comments, and + 3) raw_lines member contains all the lines without processing. + All these three members are lists, and of the same length. + """ + + def __init__(self, lines): + self.elided = [] + self.lines = [] + self.raw_lines = lines + self.num_lines = len(lines) + self.lines_without_raw_strings = CleanseRawStrings(lines) + for linenum in range(len(self.lines_without_raw_strings)): + self.lines.append(CleanseComments( + self.lines_without_raw_strings[linenum])) + elided = self._CollapseStrings(self.lines_without_raw_strings[linenum]) + self.elided.append(CleanseComments(elided)) + + def NumLines(self): + """Returns the number of lines represented.""" + return self.num_lines + + @staticmethod + def _CollapseStrings(elided): + """Collapses strings and chars on a line to simple "" or '' blocks. + + We nix strings first so we're not fooled by text like '"http://"' + + Args: + elided: The line being processed. + + Returns: + The line with collapsed strings. + """ + if not _RE_PATTERN_INCLUDE.match(elided): + # Remove escaped characters first to make quote/single quote collapsing + # basic. Things that look like escaped characters shouldn't occur + # outside of strings and chars. + elided = _RE_PATTERN_CLEANSE_LINE_ESCAPES.sub('', elided) + elided = _RE_PATTERN_CLEANSE_LINE_SINGLE_QUOTES.sub("''", elided) + elided = _RE_PATTERN_CLEANSE_LINE_DOUBLE_QUOTES.sub('""', elided) + return elided + + +def FindEndOfExpressionInLine(line, startpos, depth, startchar, endchar): + """Find the position just after the matching endchar. + + Args: + line: a CleansedLines line. + startpos: start searching at this position. + depth: nesting level at startpos. + startchar: expression opening character. 
+ endchar: expression closing character. + + Returns: + On finding matching endchar: (index just after matching endchar, 0) + Otherwise: (-1, new depth at end of this line) + """ + for i in xrange(startpos, len(line)): + if line[i] == startchar: + depth += 1 + elif line[i] == endchar: + depth -= 1 + if depth == 0: + return (i + 1, 0) + return (-1, depth) + + +def CloseExpression(clean_lines, linenum, pos): + """If input points to ( or { or [ or <, finds the position that closes it. + + If lines[linenum][pos] points to a '(' or '{' or '[' or '<', finds the + linenum/pos that correspond to the closing of the expression. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + pos: A position on the line. + + Returns: + A tuple (line, linenum, pos) pointer *past* the closing brace, or + (line, len(lines), -1) if we never find a close. Note we ignore + strings and comments when matching; and the line we return is the + 'cleansed' line at linenum. + """ + + line = clean_lines.elided[linenum] + startchar = line[pos] + if startchar not in '({[<': + return (line, clean_lines.NumLines(), -1) + if startchar == '(': endchar = ')' + if startchar == '[': endchar = ']' + if startchar == '{': endchar = '}' + if startchar == '<': endchar = '>' + + # Check first line + (end_pos, num_open) = FindEndOfExpressionInLine( + line, pos, 0, startchar, endchar) + if end_pos > -1: + return (line, linenum, end_pos) + + # Continue scanning forward + while linenum < clean_lines.NumLines() - 1: + linenum += 1 + line = clean_lines.elided[linenum] + (end_pos, num_open) = FindEndOfExpressionInLine( + line, 0, num_open, startchar, endchar) + if end_pos > -1: + return (line, linenum, end_pos) + + # Did not find endchar before end of file, give up + return (line, clean_lines.NumLines(), -1) + + +def FindStartOfExpressionInLine(line, endpos, depth, startchar, endchar): + """Find position at the matching startchar. 
+ + This is almost the reverse of FindEndOfExpressionInLine, but note + that the input position and returned position differs by 1. + + Args: + line: a CleansedLines line. + endpos: start searching at this position. + depth: nesting level at endpos. + startchar: expression opening character. + endchar: expression closing character. + + Returns: + On finding matching startchar: (index at matching startchar, 0) + Otherwise: (-1, new depth at beginning of this line) + """ + for i in xrange(endpos, -1, -1): + if line[i] == endchar: + depth += 1 + elif line[i] == startchar: + depth -= 1 + if depth == 0: + return (i, 0) + return (-1, depth) + + +def ReverseCloseExpression(clean_lines, linenum, pos): + """If input points to ) or } or ] or >, finds the position that opens it. + + If lines[linenum][pos] points to a ')' or '}' or ']' or '>', finds the + linenum/pos that correspond to the opening of the expression. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + pos: A position on the line. + + Returns: + A tuple (line, linenum, pos) pointer *at* the opening brace, or + (line, 0, -1) if we never find the matching opening brace. Note + we ignore strings and comments when matching; and the line we + return is the 'cleansed' line at linenum. 
+ """ + line = clean_lines.elided[linenum] + endchar = line[pos] + if endchar not in ')}]>': + return (line, 0, -1) + if endchar == ')': startchar = '(' + if endchar == ']': startchar = '[' + if endchar == '}': startchar = '{' + if endchar == '>': startchar = '<' + + # Check last line + (start_pos, num_open) = FindStartOfExpressionInLine( + line, pos, 0, startchar, endchar) + if start_pos > -1: + return (line, linenum, start_pos) + + # Continue scanning backward + while linenum > 0: + linenum -= 1 + line = clean_lines.elided[linenum] + (start_pos, num_open) = FindStartOfExpressionInLine( + line, len(line) - 1, num_open, startchar, endchar) + if start_pos > -1: + return (line, linenum, start_pos) + + # Did not find startchar before beginning of file, give up + return (line, 0, -1) + + +def CheckForCopyright(filename, lines, error): + """Logs an error if a Copyright message appears at the top of the file.""" + + # We'll check up to line 10. Don't forget there's a + # dummy line at the front. + for line in xrange(1, min(len(lines), 11)): + if _RE_COPYRIGHT.search(lines[line]): + error(filename, 0, 'legal/copyright', 5, + 'Copyright message found. ' + 'You should not include a copyright line.') + + +def GetHeaderGuardCPPVariable(filename): + """Returns the CPP variable that should be used as a header guard. + + Args: + filename: The name of a C++ header file. + + Returns: + The CPP variable that should be used as a header guard in the + named file. + + """ + + # Restores original filename in case that cpplint is invoked from Emacs's + # flymake. 
+ filename = re.sub(r'_flymake\.h$', '.h', filename) + filename = re.sub(r'/\.flymake/([^/]*)$', r'/\1', filename) + + fileinfo = FileInfo(filename) + file_path_from_root = fileinfo.RepositoryName() + if _root: + file_path_from_root = re.sub('^' + _root + os.sep, '', file_path_from_root) + return re.sub(r'[-./\s]', '_', file_path_from_root).upper() + '_' + + +def CheckForHeaderGuard(filename, lines, error): + """Checks that the file contains a header guard. + + Logs an error if no #ifndef header guard is present. For other + headers, checks that the full pathname is used. + + Args: + filename: The name of the C++ header file. + lines: An array of strings, each representing a line of the file. + error: The function to call with any errors found. + """ + + cppvar = GetHeaderGuardCPPVariable(filename) + + ifndef = None + ifndef_linenum = 0 + define = None + endif = None + endif_linenum = 0 + for linenum, line in enumerate(lines): + linesplit = line.split() + if len(linesplit) >= 2: + # find the first occurrence of #ifndef and #define, save arg + if not ifndef and linesplit[0] == '#ifndef': + # set ifndef to the header guard presented on the #ifndef line. + ifndef = linesplit[1] + ifndef_linenum = linenum + if not define and linesplit[0] == '#define': + define = linesplit[1] + # find the last occurrence of #endif, save entire line + if line.startswith('#endif'): + endif = line + endif_linenum = linenum + + if not ifndef: + error(filename, 0, 'build/header_guard', 5, + 'No #ifndef header guard found, suggested CPP variable is: %s' % + cppvar) + return + + if not define: + error(filename, 0, 'build/header_guard', 5, + 'No #define header guard found, suggested CPP variable is: %s' % + cppvar) + return + + # The guard should be PATH_FILE_H_, but we also allow PATH_FILE_H__ + # for backward compatibility. 
+ if ifndef != cppvar: + error_level = 0 + if ifndef != cppvar + '_': + error_level = 5 + + ParseNolintSuppressions(filename, lines[ifndef_linenum], ifndef_linenum, + error) + error(filename, ifndef_linenum, 'build/header_guard', error_level, + '#ifndef header guard has wrong style, please use: %s' % cppvar) + + if define != ifndef: + error(filename, 0, 'build/header_guard', 5, + '#ifndef and #define don\'t match, suggested CPP variable is: %s' % + cppvar) + return + + if endif != ('#endif // %s' % cppvar): + error_level = 0 + if endif != ('#endif // %s' % (cppvar + '_')): + error_level = 5 + + ParseNolintSuppressions(filename, lines[endif_linenum], endif_linenum, + error) + error(filename, endif_linenum, 'build/header_guard', error_level, + '#endif line should be "#endif // %s"' % cppvar) + + +def CheckForBadCharacters(filename, lines, error): + """Logs an error for each line containing bad characters. + + Two kinds of bad characters: + + 1. Unicode replacement characters: These indicate that either the file + contained invalid UTF-8 (likely) or Unicode replacement characters (which + it shouldn't). Note that it's possible for this to throw off line + numbering if the invalid UTF-8 occurred adjacent to a newline. + + 2. NUL bytes. These are problematic for some tools. + + Args: + filename: The name of the current file. + lines: An array of strings, each representing a line of the file. + error: The function to call with any errors found. + """ + for linenum, line in enumerate(lines): + if u'\ufffd' in line: + error(filename, linenum, 'readability/utf8', 5, + 'Line contains invalid UTF-8 (or Unicode replacement character).') + if '\0' in line: + error(filename, linenum, 'readability/nul', 5, 'Line contains NUL byte.') + + +def CheckForNewlineAtEOF(filename, lines, error): + """Logs an error if there is no newline char at the end of the file. + + Args: + filename: The name of the current file. + lines: An array of strings, each representing a line of the file. 
+ error: The function to call with any errors found. + """ + + # The array lines() was created by adding two newlines to the + # original file (go figure), then splitting on \n. + # To verify that the file ends in \n, we just have to make sure the + # last-but-two element of lines() exists and is empty. + if len(lines) < 3 or lines[-2]: + error(filename, len(lines) - 2, 'whitespace/ending_newline', 5, + 'Could not find a newline character at the end of the file.') + + +def CheckForMultilineCommentsAndStrings(filename, clean_lines, linenum, error): + """Logs an error if we see /* ... */ or "..." that extend past one line. + + /* ... */ comments are legit inside macros, for one line. + Otherwise, we prefer // comments, so it's ok to warn about the + other. Likewise, it's ok for strings to extend across multiple + lines, as long as a line continuation character (backslash) + terminates each line. Although not currently prohibited by the C++ + style guide, it's ugly and unnecessary. We don't do well with either + in this lint program, so we warn about both. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Remove all \\ (escaped backslashes) from the line. They are OK, and the + # second (escaped) slash may trigger later \" detection erroneously. + line = line.replace('\\\\', '') + + if line.count('/*') > line.count('*/'): + error(filename, linenum, 'readability/multiline_comment', 5, + 'Complex multi-line /*...*/-style comment found. ' + 'Lint may give bogus warnings. ' + 'Consider replacing these with //-style comments, ' + 'with #if 0...#endif, ' + 'or with more clearly structured multi-line comments.') + + if (line.count('"') - line.count('\\"')) % 2: + error(filename, linenum, 'readability/multiline_string', 5, + 'Multi-line string ("...") found. 
This lint script doesn\'t ' + 'do well with such strings, and may give bogus warnings. ' + 'Use C++11 raw strings or concatenation instead.') + + +caffe_alt_function_list = ( + ('memset', ['caffe_set', 'caffe_memset']), + ('cudaMemset', ['caffe_gpu_set', 'caffe_gpu_memset']), + ('memcpy', ['caffe_copy', 'caffe_memcpy']), + ('cudaMemcpy', ['caffe_copy', 'caffe_gpu_memcpy']), + ) + + +def CheckCaffeAlternatives(filename, clean_lines, linenum, error): + """Checks for C(++) functions for which a Caffe substitute should be used. + + For certain native C functions (memset, memcpy), there is a Caffe alternative + which should be used instead. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + for function, alts in caffe_alt_function_list: + ix = line.find(function + '(') + if ix >= 0 and (ix == 0 or (not line[ix - 1].isalnum() and + line[ix - 1] not in ('_', '.', '>'))): + disp_alts = ['%s(...)' % alt for alt in alts] + error(filename, linenum, 'caffe/alt_fn', 2, + 'Use Caffe function %s instead of %s(...).' % + (' or '.join(disp_alts), function)) + + +def CheckCaffeDataLayerSetUp(filename, clean_lines, linenum, error): + """Except the base classes, Caffe DataLayer should define DataLayerSetUp + instead of LayerSetUp. + + The base DataLayers define common SetUp steps, the subclasses should + not override them. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. 
+ """ + line = clean_lines.elided[linenum] + ix = line.find('DataLayer::LayerSetUp') + if ix >= 0 and ( + line.find('void DataLayer::LayerSetUp') != -1 or + line.find('void ImageDataLayer::LayerSetUp') != -1 or + line.find('void MemoryDataLayer::LayerSetUp') != -1 or + line.find('void WindowDataLayer::LayerSetUp') != -1): + error(filename, linenum, 'caffe/data_layer_setup', 2, + 'Except the base classes, Caffe DataLayer should define' + + ' DataLayerSetUp instead of LayerSetUp. The base DataLayers' + + ' define common SetUp steps, the subclasses should' + + ' not override them.') + ix = line.find('DataLayer::DataLayerSetUp') + if ix >= 0 and ( + line.find('void Base') == -1 and + line.find('void DataLayer::DataLayerSetUp') == -1 and + line.find('void ImageDataLayer::DataLayerSetUp') == -1 and + line.find('void MemoryDataLayer::DataLayerSetUp') == -1 and + line.find('void WindowDataLayer::DataLayerSetUp') == -1): + error(filename, linenum, 'caffe/data_layer_setup', 2, + 'Except the base classes, Caffe DataLayer should define' + + ' DataLayerSetUp instead of LayerSetUp. The base DataLayers' + + ' define common SetUp steps, the subclasses should' + + ' not override them.') + + +c_random_function_list = ( + 'rand(', + 'rand_r(', + 'random(', + ) + +def CheckCaffeRandom(filename, clean_lines, linenum, error): + """Checks for calls to C random functions (rand, rand_r, random, ...). + + Caffe code should (almost) always use the caffe_rng_* functions rather + than these, as the internal state of these C functions is independent of the + native Caffe RNG system which should produce deterministic results for a + fixed Caffe seed set using Caffe::set_random_seed(...). + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. 
+ """ + line = clean_lines.elided[linenum] + for function in c_random_function_list: + ix = line.find(function) + # Comparisons made explicit for clarity -- pylint: disable=g-explicit-bool-comparison + if ix >= 0 and (ix == 0 or (not line[ix - 1].isalnum() and + line[ix - 1] not in ('_', '.', '>'))): + error(filename, linenum, 'caffe/random_fn', 2, + 'Use caffe_rng_rand() (or other caffe_rng_* function) instead of ' + + function + + ') to ensure results are deterministic for a fixed Caffe seed.') + + +threading_list = ( + ('asctime(', 'asctime_r('), + ('ctime(', 'ctime_r('), + ('getgrgid(', 'getgrgid_r('), + ('getgrnam(', 'getgrnam_r('), + ('getlogin(', 'getlogin_r('), + ('getpwnam(', 'getpwnam_r('), + ('getpwuid(', 'getpwuid_r('), + ('gmtime(', 'gmtime_r('), + ('localtime(', 'localtime_r('), + ('strtok(', 'strtok_r('), + ('ttyname(', 'ttyname_r('), + ) + + +def CheckPosixThreading(filename, clean_lines, linenum, error): + """Checks for calls to thread-unsafe functions. + + Much code has been originally written without consideration of + multi-threading. Also, engineers are relying on their old experience; + they have learned posix before threading extensions were added. These + tests guide the engineers to use thread-safe functions (when using + posix directly). + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + for single_thread_function, multithread_safe_function in threading_list: + ix = line.find(single_thread_function) + # Comparisons made explicit for clarity -- pylint: disable=g-explicit-bool-comparison + if ix >= 0 and (ix == 0 or (not line[ix - 1].isalnum() and + line[ix - 1] not in ('_', '.', '>'))): + error(filename, linenum, 'runtime/threadsafe_fn', 2, + 'Consider using ' + multithread_safe_function + + '...) 
instead of ' + single_thread_function + + '...) for improved thread safety.') + + +def CheckVlogArguments(filename, clean_lines, linenum, error): + """Checks that VLOG() is only used for defining a logging level. + + For example, VLOG(2) is correct. VLOG(INFO), VLOG(WARNING), VLOG(ERROR), and + VLOG(FATAL) are not. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + if Search(r'\bVLOG\((INFO|ERROR|WARNING|DFATAL|FATAL)\)', line): + error(filename, linenum, 'runtime/vlog', 5, + 'VLOG() should be used with numeric verbosity level. ' + 'Use LOG() if you want symbolic severity levels.') + + +# Matches invalid increment: *count++, which moves pointer instead of +# incrementing a value. +_RE_PATTERN_INVALID_INCREMENT = re.compile( + r'^\s*\*\w+(\+\+|--);') + + +def CheckInvalidIncrement(filename, clean_lines, linenum, error): + """Checks for invalid increment *count++. + + For example following function: + void increment_counter(int* count) { + *count++; + } + is invalid, because it effectively does count++, moving pointer, and should + be replaced with ++*count, (*count)++ or *count += 1. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. 
+ """ + line = clean_lines.elided[linenum] + if _RE_PATTERN_INVALID_INCREMENT.match(line): + error(filename, linenum, 'runtime/invalid_increment', 5, + 'Changing pointer instead of value (or unused value of operator*).') + + +class _BlockInfo(object): + """Stores information about a generic block of code.""" + + def __init__(self, seen_open_brace): + self.seen_open_brace = seen_open_brace + self.open_parentheses = 0 + self.inline_asm = _NO_ASM + + def CheckBegin(self, filename, clean_lines, linenum, error): + """Run checks that applies to text up to the opening brace. + + This is mostly for checking the text after the class identifier + and the "{", usually where the base class is specified. For other + blocks, there isn't much to check, so we always pass. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + pass + + def CheckEnd(self, filename, clean_lines, linenum, error): + """Run checks that applies to text after the closing brace. + + This is mostly used for checking end of namespace comments. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + pass + + +class _ClassInfo(_BlockInfo): + """Stores information about a class.""" + + def __init__(self, name, class_or_struct, clean_lines, linenum): + _BlockInfo.__init__(self, False) + self.name = name + self.starting_linenum = linenum + self.is_derived = False + if class_or_struct == 'struct': + self.access = 'public' + self.is_struct = True + else: + self.access = 'private' + self.is_struct = False + + # Remember initial indentation level for this class. Using raw_lines here + # instead of elided to account for leading comments. 
+ initial_indent = Match(r'^( *)\S', clean_lines.raw_lines[linenum]) + if initial_indent: + self.class_indent = len(initial_indent.group(1)) + else: + self.class_indent = 0 + + # Try to find the end of the class. This will be confused by things like: + # class A { + # } *x = { ... + # + # But it's still good enough for CheckSectionSpacing. + self.last_line = 0 + depth = 0 + for i in range(linenum, clean_lines.NumLines()): + line = clean_lines.elided[i] + depth += line.count('{') - line.count('}') + if not depth: + self.last_line = i + break + + def CheckBegin(self, filename, clean_lines, linenum, error): + # Look for a bare ':' + if Search('(^|[^:]):($|[^:])', clean_lines.elided[linenum]): + self.is_derived = True + + def CheckEnd(self, filename, clean_lines, linenum, error): + # Check that closing brace is aligned with beginning of the class. + # Only do this if the closing brace is indented by only whitespaces. + # This means we will not check single-line class definitions. + indent = Match(r'^( *)\}', clean_lines.elided[linenum]) + if indent and len(indent.group(1)) != self.class_indent: + if self.is_struct: + parent = 'struct ' + self.name + else: + parent = 'class ' + self.name + error(filename, linenum, 'whitespace/indent', 3, + 'Closing brace should be aligned with beginning of %s' % parent) + + +class _NamespaceInfo(_BlockInfo): + """Stores information about a namespace.""" + + def __init__(self, name, linenum): + _BlockInfo.__init__(self, False) + self.name = name or '' + self.starting_linenum = linenum + + def CheckEnd(self, filename, clean_lines, linenum, error): + """Check end of namespace comments.""" + line = clean_lines.raw_lines[linenum] + + # Check how many lines is enclosed in this namespace. Don't issue + # warning for missing namespace comments if there aren't enough + # lines. However, do apply checks if there is already an end of + # namespace comment and it's incorrect. 
+ # + # TODO(unknown): We always want to check end of namespace comments + # if a namespace is large, but sometimes we also want to apply the + # check if a short namespace contained nontrivial things (something + # other than forward declarations). There is currently no logic for + # deciding what these nontrivial things are, so this check is + # triggered by namespace size only, which works most of the time. + if (linenum - self.starting_linenum < 10 + and not Match(r'};*\s*(//|/\*).*\bnamespace\b', line)): + return + + # Look for matching comment at end of namespace. + # + # Note that we accept C style "/* */" comments for terminating + # namespaces, so that code that terminates namespaces inside + # preprocessor macros can be cpplint clean. + # + # We also accept stuff like "// end of namespace <name>." with the + # period at the end. + # + # Besides these, we don't accept anything else, otherwise we might + # get false negatives when an existing comment is a substring of the + # expected namespace. 
+ if self.name: + # Named namespace + if not Match((r'};*\s*(//|/\*).*\bnamespace\s+' + re.escape(self.name) + + r'[\*/\.\\\s]*$'), + line): + error(filename, linenum, 'readability/namespace', 5, + 'Namespace should be terminated with "// namespace %s"' % + self.name) + else: + # Anonymous namespace + if not Match(r'};*\s*(//|/\*).*\bnamespace[\*/\.\\\s]*$', line): + error(filename, linenum, 'readability/namespace', 5, + 'Namespace should be terminated with "// namespace"') + + +class _PreprocessorInfo(object): + """Stores checkpoints of nesting stacks when #if/#else is seen.""" + + def __init__(self, stack_before_if): + # The entire nesting stack before #if + self.stack_before_if = stack_before_if + + # The entire nesting stack up to #else + self.stack_before_else = [] + + # Whether we have already seen #else or #elif + self.seen_else = False + + +class _NestingState(object): + """Holds states related to parsing braces.""" + + def __init__(self): + # Stack for tracking all braces. An object is pushed whenever we + # see a "{", and popped when we see a "}". Only 3 types of + # objects are possible: + # - _ClassInfo: a class or struct. + # - _NamespaceInfo: a namespace. + # - _BlockInfo: some other type of block. + self.stack = [] + + # Stack of _PreprocessorInfo objects. + self.pp_stack = [] + + def SeenOpenBrace(self): + """Check if we have seen the opening brace for the innermost block. + + Returns: + True if we have seen the opening brace, False if the innermost + block is still expecting an opening brace. + """ + return (not self.stack) or self.stack[-1].seen_open_brace + + def InNamespaceBody(self): + """Check if we are currently one level inside a namespace body. + + Returns: + True if top of the stack is a namespace block, False otherwise. + """ + return self.stack and isinstance(self.stack[-1], _NamespaceInfo) + + def UpdatePreprocessor(self, line): + """Update preprocessor stack. 
+ + We need to handle preprocessors due to classes like this: + #ifdef SWIG + struct ResultDetailsPageElementExtensionPoint { + #else + struct ResultDetailsPageElementExtensionPoint : public Extension { + #endif + + We make the following assumptions (good enough for most files): + - Preprocessor condition evaluates to true from #if up to first + #else/#elif/#endif. + + - Preprocessor condition evaluates to false from #else/#elif up + to #endif. We still perform lint checks on these lines, but + these do not affect nesting stack. + + Args: + line: current line to check. + """ + if Match(r'^\s*#\s*(if|ifdef|ifndef)\b', line): + # Beginning of #if block, save the nesting stack here. The saved + # stack will allow us to restore the parsing state in the #else case. + self.pp_stack.append(_PreprocessorInfo(copy.deepcopy(self.stack))) + elif Match(r'^\s*#\s*(else|elif)\b', line): + # Beginning of #else block + if self.pp_stack: + if not self.pp_stack[-1].seen_else: + # This is the first #else or #elif block. Remember the + # whole nesting stack up to this point. This is what we + # keep after the #endif. + self.pp_stack[-1].seen_else = True + self.pp_stack[-1].stack_before_else = copy.deepcopy(self.stack) + + # Restore the stack to how it was before the #if + self.stack = copy.deepcopy(self.pp_stack[-1].stack_before_if) + else: + # TODO(unknown): unexpected #else, issue warning? + pass + elif Match(r'^\s*#\s*endif\b', line): + # End of #if or #else blocks. + if self.pp_stack: + # If we saw an #else, we will need to restore the nesting + # stack to its former state before the #else, otherwise we + # will just continue from where we left off. + if self.pp_stack[-1].seen_else: + # Here we can just use a shallow copy since we are the last + # reference to it. + self.stack = self.pp_stack[-1].stack_before_else + # Drop the corresponding #if + self.pp_stack.pop() + else: + # TODO(unknown): unexpected #endif, issue warning? 
+ pass + + def Update(self, filename, clean_lines, linenum, error): + """Update nesting state with current line. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Update pp_stack first + self.UpdatePreprocessor(line) + + # Count parentheses. This is to avoid adding struct arguments to + # the nesting stack. + if self.stack: + inner_block = self.stack[-1] + depth_change = line.count('(') - line.count(')') + inner_block.open_parentheses += depth_change + + # Also check if we are starting or ending an inline assembly block. + if inner_block.inline_asm in (_NO_ASM, _END_ASM): + if (depth_change != 0 and + inner_block.open_parentheses == 1 and + _MATCH_ASM.match(line)): + # Enter assembly block + inner_block.inline_asm = _INSIDE_ASM + else: + # Not entering assembly block. If previous line was _END_ASM, + # we will now shift to _NO_ASM state. + inner_block.inline_asm = _NO_ASM + elif (inner_block.inline_asm == _INSIDE_ASM and + inner_block.open_parentheses == 0): + # Exit assembly block + inner_block.inline_asm = _END_ASM + + # Consume namespace declaration at the beginning of the line. Do + # this in a loop so that we catch same line declarations like this: + # namespace proto2 { namespace bridge { class MessageSet; } } + while True: + # Match start of namespace. The "\b\s*" below catches namespace + # declarations even if it weren't followed by a whitespace, this + # is so that we don't confuse our namespace checker. The + # missing spaces will be flagged by CheckSpacing. 
+ namespace_decl_match = Match(r'^\s*namespace\b\s*([:\w]+)?(.*)$', line) + if not namespace_decl_match: + break + + new_namespace = _NamespaceInfo(namespace_decl_match.group(1), linenum) + self.stack.append(new_namespace) + + line = namespace_decl_match.group(2) + if line.find('{') != -1: + new_namespace.seen_open_brace = True + line = line[line.find('{') + 1:] + + # Look for a class declaration in whatever is left of the line + # after parsing namespaces. The regexp accounts for decorated classes + # such as in: + # class LOCKABLE API Object { + # }; + # + # Templates with class arguments may confuse the parser, for example: + # template <class T, class Comparator = less<T>, + # class Vector = vector<T> > + # class HeapQueue { + # + # Because this parser has no nesting state about templates, by the + # time it saw "class Comparator", it may think that it's a new class. + # Nested templates have a similar problem: + # template < + # typename ExportedType, + # typename TupleType, + # template <typename, typename> class ImplTemplate> + # + # To avoid these cases, we ignore classes that are followed by '=' or '>' + class_decl_match = Match( + r'\s*(template\s*<[\w\s<>,:]*>\s*)?' + r'(class|struct)\s+([A-Z_]+\s+)*(\w+(?:::\w+)*)' + r'(([^=>]|<[^<>]*>|<[^<>]*<[^<>]*>\s*>)*)$', line) + if (class_decl_match and + (not self.stack or self.stack[-1].open_parentheses == 0)): + self.stack.append(_ClassInfo( + class_decl_match.group(4), class_decl_match.group(2), + clean_lines, linenum)) + line = class_decl_match.group(5) + + # If we have not yet seen the opening brace for the innermost block, + # run checks here. + if not self.SeenOpenBrace(): + self.stack[-1].CheckBegin(filename, clean_lines, linenum, error) + + # Update access control if we are inside a class/struct + if self.stack and isinstance(self.stack[-1], _ClassInfo): + classinfo = self.stack[-1] + access_match = Match( + r'^(.*)\b(public|private|protected|signals)(\s+(?:slots\s*)?)?' 
+ r':(?:[^:]|$)', + line) + if access_match: + classinfo.access = access_match.group(2) + + # Check that access keywords are indented +1 space. Skip this + # check if the keywords are not preceded by whitespace. + indent = access_match.group(1) + if (len(indent) != classinfo.class_indent + 1 and + Match(r'^\s*$', indent)): + if classinfo.is_struct: + parent = 'struct ' + classinfo.name + else: + parent = 'class ' + classinfo.name + slots = '' + if access_match.group(3): + slots = access_match.group(3) + error(filename, linenum, 'whitespace/indent', 3, + '%s%s: should be indented +1 space inside %s' % ( + access_match.group(2), slots, parent)) + + # Consume braces or semicolons from what's left of the line + while True: + # Match first brace, semicolon, or closing parenthesis. + matched = Match(r'^[^{;)}]*([{;)}])(.*)$', line) + if not matched: + break + + token = matched.group(1) + if token == '{': + # If namespace or class hasn't seen an opening brace yet, mark + # namespace/class head as complete. Push a new block onto the + # stack otherwise. + if not self.SeenOpenBrace(): + self.stack[-1].seen_open_brace = True + else: + self.stack.append(_BlockInfo(True)) + if _MATCH_ASM.match(line): + self.stack[-1].inline_asm = _BLOCK_ASM + elif token == ';' or token == ')': + # If we haven't seen an opening brace yet, but we already saw + # a semicolon, this is probably a forward declaration. Pop + # the stack for these. + # + # Similarly, if we haven't seen an opening brace yet, but we + # already saw a closing parenthesis, then these are probably + # function arguments with extra "class" or "struct" keywords. + # Also pop the stack for these. + if not self.SeenOpenBrace(): + self.stack.pop() + else: # token == '}' + # Perform end of block checks and pop the stack. + if self.stack: + self.stack[-1].CheckEnd(filename, clean_lines, linenum, error) + self.stack.pop() + line = matched.group(2) + + def InnermostClass(self): + """Get class info on the top of the stack. 
+ + Returns: + A _ClassInfo object if we are inside a class, or None otherwise. + """ + for i in range(len(self.stack), 0, -1): + classinfo = self.stack[i - 1] + if isinstance(classinfo, _ClassInfo): + return classinfo + return None + + def CheckCompletedBlocks(self, filename, error): + """Checks that all classes and namespaces have been completely parsed. + + Call this when all lines in a file have been processed. + Args: + filename: The name of the current file. + error: The function to call with any errors found. + """ + # Note: This test can result in false positives if #ifdef constructs + # get in the way of brace matching. See the testBuildClass test in + # cpplint_unittest.py for an example of this. + for obj in self.stack: + if isinstance(obj, _ClassInfo): + error(filename, obj.starting_linenum, 'build/class', 5, + 'Failed to find complete declaration of class %s' % + obj.name) + elif isinstance(obj, _NamespaceInfo): + error(filename, obj.starting_linenum, 'build/namespaces', 5, + 'Failed to find complete declaration of namespace %s' % + obj.name) + + +def CheckForNonStandardConstructs(filename, clean_lines, linenum, + nesting_state, error): + r"""Logs an error if we see certain non-ANSI constructs ignored by gcc-2. + + Complain about several constructs which gcc-2 accepts, but which are + not standard C++. Warning about these in lint is one way to ease the + transition to new compilers. + - put storage class first (e.g. "static const" instead of "const static"). + - "%lld" instead of %qd" in printf-type functions. + - "%1$d" is non-standard in printf-type functions. + - "\%" is an undefined character escape sequence. + - text after #endif is not allowed. + - invalid inner-style forward declaration. + - >? and ?= and )\?=?\s*(\w+|[+-]?\d+)(\.\d*)?', + line): + error(filename, linenum, 'build/deprecated', 3, + '>? and ))?' 
+ # r'\s*const\s*' + type_name + '\s*&\s*\w+\s*;' + error(filename, linenum, 'runtime/member_string_references', 2, + 'const string& members are dangerous. It is much better to use ' + 'alternatives, such as pointers or simple constants.') + + # Everything else in this function operates on class declarations. + # Return early if the top of the nesting stack is not a class, or if + # the class head is not completed yet. + classinfo = nesting_state.InnermostClass() + if not classinfo or not classinfo.seen_open_brace: + return + + # The class may have been declared with namespace or classname qualifiers. + # The constructor and destructor will not have those qualifiers. + base_classname = classinfo.name.split('::')[-1] + + # Look for single-argument constructors that aren't marked explicit. + # Technically a valid construct, but against style. + args = Match(r'\s+(?:inline\s+)?%s\s*\(([^,()]+)\)' + % re.escape(base_classname), + line) + if (args and + args.group(1) != 'void' and + not Match(r'(const\s+)?%s(\s+const)?\s*(?:<\w+>\s*)?&' + % re.escape(base_classname), args.group(1).strip())): + error(filename, linenum, 'runtime/explicit', 5, + 'Single-argument constructors should be marked explicit.') + + +def CheckSpacingForFunctionCall(filename, line, linenum, error): + """Checks for the correctness of various spacing around function calls. + + Args: + filename: The name of the current file. + line: The text of the line to check. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + # Since function calls often occur inside if/for/while/switch + # expressions - which have their own, more liberal conventions - we + # first see if we should be looking inside such an expression for a + # function call, to which we can apply more strict standards. 
+ fncall = line # if there's no control flow construct, look at whole line + for pattern in (r'\bif\s*\((.*)\)\s*{', + r'\bfor\s*\((.*)\)\s*{', + r'\bwhile\s*\((.*)\)\s*[{;]', + r'\bswitch\s*\((.*)\)\s*{'): + match = Search(pattern, line) + if match: + fncall = match.group(1) # look inside the parens for function calls + break + + # Except in if/for/while/switch, there should never be space + # immediately inside parens (eg "f( 3, 4 )"). We make an exception + # for nested parens ( (a+b) + c ). Likewise, there should never be + # a space before a ( when it's a function argument. I assume it's a + # function argument when the char before the whitespace is legal in + # a function name (alnum + _) and we're not starting a macro. Also ignore + # pointers and references to arrays and functions coz they're too tricky: + # we use a very simple way to recognize these: + # " (something)(maybe-something)" or + # " (something)(maybe-something," or + # " (something)[something]" + # Note that we assume the contents of [] to be short enough that + # they'll never need to wrap. + if ( # Ignore control structures. + not Search(r'\b(if|for|while|switch|return|new|delete|catch|sizeof)\b', + fncall) and + # Ignore pointers/references to functions. + not Search(r' \([^)]+\)\([^)]*(\)|,$)', fncall) and + # Ignore pointers/references to arrays. 
+ not Search(r' \([^)]+\)\[[^\]]+\]', fncall)): + if Search(r'\w\s*\(\s(?!\s*\\$)', fncall): # a ( used for a fn call + error(filename, linenum, 'whitespace/parens', 4, + 'Extra space after ( in function call') + elif Search(r'\(\s+(?!(\s*\\)|\()', fncall): + error(filename, linenum, 'whitespace/parens', 2, + 'Extra space after (') + if (Search(r'\w\s+\(', fncall) and + not Search(r'#\s*define|typedef', fncall) and + not Search(r'\w\s+\((\w+::)*\*\w+\)\(', fncall)): + error(filename, linenum, 'whitespace/parens', 4, + 'Extra space before ( in function call') + # If the ) is followed only by a newline or a { + newline, assume it's + # part of a control statement (if/while/etc), and don't complain + if Search(r'[^)]\s+\)\s*[^{\s]', fncall): + # If the closing parenthesis is preceded by only whitespaces, + # try to give a more descriptive error message. + if Search(r'^\s+\)', fncall): + error(filename, linenum, 'whitespace/parens', 2, + 'Closing ) should be moved to the previous line') + else: + error(filename, linenum, 'whitespace/parens', 2, + 'Extra space before )') + + +def IsBlankLine(line): + """Returns true if the given line is blank. + + We consider a line to be blank if the line is empty or consists of + only white spaces. + + Args: + line: A line of a string. + + Returns: + True, if the given line is blank. + """ + return not line or line.isspace() + + +def CheckForFunctionLengths(filename, clean_lines, linenum, + function_state, error): + """Reports for long function bodies. + + For an overview why this is done, see: + http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Write_Short_Functions + + Uses a simplistic algorithm assuming other style guidelines + (especially spacing) are followed. + Only checks unindented functions, so class members are unchecked. + Trivial bodies are unchecked, so constructors with huge initializer lists + may be missed. 
+ Blank/comment lines are not counted so as to avoid encouraging the removal + of vertical space and comments just to get through a lint check. + NOLINT *on the last line of a function* disables this check. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + function_state: Current function name and lines in body so far. + error: The function to call with any errors found. + """ + lines = clean_lines.lines + line = lines[linenum] + raw = clean_lines.raw_lines + raw_line = raw[linenum] + joined_line = '' + + starting_func = False + regexp = r'(\w(\w|::|\*|\&|\s)*)\(' # decls * & space::name( ... + match_result = Match(regexp, line) + if match_result: + # If the name is all caps and underscores, figure it's a macro and + # ignore it, unless it's TEST or TEST_F. + function_name = match_result.group(1).split()[-1] + if function_name == 'TEST' or function_name == 'TEST_F' or ( + not Match(r'[A-Z_]+$', function_name)): + starting_func = True + + if starting_func: + body_found = False + for start_linenum in xrange(linenum, clean_lines.NumLines()): + start_line = lines[start_linenum] + joined_line += ' ' + start_line.lstrip() + if Search(r'(;|})', start_line): # Declarations and trivial functions + body_found = True + break # ... ignore + elif Search(r'{', start_line): + body_found = True + function = Search(r'((\w|:)*)\(', line).group(1) + if Match(r'TEST', function): # Handle TEST... macros + parameter_regexp = Search(r'(\(.*\))', joined_line) + if parameter_regexp: # Ignore bad syntax + function += parameter_regexp.group(1) + else: + function += '()' + function_state.Begin(function) + break + if not body_found: + # No body for the function (or evidence of a non-function) was found. 
+ error(filename, linenum, 'readability/fn_size', 5, + 'Lint failed to find start of function body.') + elif Match(r'^\}\s*$', line): # function end + function_state.Check(error, filename, linenum) + function_state.End() + elif not Match(r'^\s*$', line): + function_state.Count() # Count non-blank/non-comment lines. + + +_RE_PATTERN_TODO = re.compile(r'^//(\s*)TODO(\(.+?\))?:?(\s|$)?') + + +def CheckComment(comment, filename, linenum, error): + """Checks for common mistakes in TODO comments. + + Args: + comment: The text of the comment from the line in question. + filename: The name of the current file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + match = _RE_PATTERN_TODO.match(comment) + if match: + # One whitespace is correct; zero whitespace is handled elsewhere. + leading_whitespace = match.group(1) + if len(leading_whitespace) > 1: + error(filename, linenum, 'whitespace/todo', 2, + 'Too many spaces before TODO') + + username = match.group(2) + if not username: + error(filename, linenum, 'readability/todo', 2, + 'Missing username in TODO; it should look like ' + '"// TODO(my_username): Stuff."') + + middle_whitespace = match.group(3) + # Comparisons made explicit for correctness -- pylint: disable=g-explicit-bool-comparison + if middle_whitespace != ' ' and middle_whitespace != '': + error(filename, linenum, 'whitespace/todo', 2, + 'TODO(my_username) should be followed by a space') + +def CheckAccess(filename, clean_lines, linenum, nesting_state, error): + """Checks for improper use of DISALLOW* macros. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + nesting_state: A _NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: The function to call with any errors found. 
+ """ + line = clean_lines.elided[linenum] # get rid of comments and strings + + matched = Match((r'\s*(DISALLOW_COPY_AND_ASSIGN|' + r'DISALLOW_EVIL_CONSTRUCTORS|' + r'DISALLOW_IMPLICIT_CONSTRUCTORS)'), line) + if not matched: + return + if nesting_state.stack and isinstance(nesting_state.stack[-1], _ClassInfo): + if nesting_state.stack[-1].access != 'private': + error(filename, linenum, 'readability/constructors', 3, + '%s must be in the private: section' % matched.group(1)) + + else: + # Found DISALLOW* macro outside a class declaration, or perhaps it + # was used inside a function when it should have been part of the + # class declaration. We could issue a warning here, but it + # probably resulted in a compiler error already. + pass + + +def FindNextMatchingAngleBracket(clean_lines, linenum, init_suffix): + """Find the corresponding > to close a template. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: Current line number. + init_suffix: Remainder of the current line after the initial <. + + Returns: + True if a matching bracket exists. + """ + line = init_suffix + nesting_stack = ['<'] + while True: + # Find the next operator that can tell us whether < is used as an + # opening bracket or as a less-than operator. We only want to + # warn on the latter case. + # + # We could also check all other operators and terminate the search + # early, e.g. if we got something like this "a(),;\[\]]*([<>(),;\[\]])(.*)$', line) + if match: + # Found an operator, update nesting stack + operator = match.group(1) + line = match.group(2) + + if nesting_stack[-1] == '<': + # Expecting closing angle bracket + if operator in ('<', '(', '['): + nesting_stack.append(operator) + elif operator == '>': + nesting_stack.pop() + if not nesting_stack: + # Found matching angle bracket + return True + elif operator == ',': + # Got a comma after a bracket, this is most likely a template + # argument. 
We have not seen a closing angle bracket yet, but + # it's probably a few lines later if we look for it, so just + # return early here. + return True + else: + # Got some other operator. + return False + + else: + # Expecting closing parenthesis or closing bracket + if operator in ('<', '(', '['): + nesting_stack.append(operator) + elif operator in (')', ']'): + # We don't bother checking for matching () or []. If we got + # something like (] or [), it would have been a syntax error. + nesting_stack.pop() + + else: + # Scan the next line + linenum += 1 + if linenum >= len(clean_lines.elided): + break + line = clean_lines.elided[linenum] + + # Exhausted all remaining lines and still no matching angle bracket. + # Most likely the input was incomplete, otherwise we should have + # seen a semicolon and returned early. + return True + + +def FindPreviousMatchingAngleBracket(clean_lines, linenum, init_prefix): + """Find the corresponding < that started a template. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: Current line number. + init_prefix: Part of the current line before the initial >. + + Returns: + True if a matching bracket exists. + """ + line = init_prefix + nesting_stack = ['>'] + while True: + # Find the previous operator + match = Search(r'^(.*)([<>(),;\[\]])[^<>(),;\[\]]*$', line) + if match: + # Found an operator, update nesting stack + operator = match.group(2) + line = match.group(1) + + if nesting_stack[-1] == '>': + # Expecting opening angle bracket + if operator in ('>', ')', ']'): + nesting_stack.append(operator) + elif operator == '<': + nesting_stack.pop() + if not nesting_stack: + # Found matching angle bracket + return True + elif operator == ',': + # Got a comma before a bracket, this is most likely a + # template argument. The opening angle bracket is probably + # there if we look for it, so just return early here. + return True + else: + # Got some other operator. 
+ return False + + else: + # Expecting opening parenthesis or opening bracket + if operator in ('>', ')', ']'): + nesting_stack.append(operator) + elif operator in ('(', '['): + nesting_stack.pop() + + else: + # Scan the previous line + linenum -= 1 + if linenum < 0: + break + line = clean_lines.elided[linenum] + + # Exhausted all earlier lines and still no matching angle bracket. + return False + + +def CheckSpacing(filename, clean_lines, linenum, nesting_state, error): + """Checks for the correctness of various spacing issues in the code. + + Things we check for: spaces around operators, spaces after + if/for/while/switch, no spaces around parens in function calls, two + spaces between code and comment, don't start a block with a blank + line, don't end a function with a blank line, don't add a blank line + after public/protected/private, don't have too many blank lines in a row. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + nesting_state: A _NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: The function to call with any errors found. + """ + + # Don't use "elided" lines here, otherwise we can't check commented lines. + # Don't want to use "raw" either, because we don't want to check inside C++11 + # raw strings, + raw = clean_lines.lines_without_raw_strings + line = raw[linenum] + + # Before nixing comments, check if the line is blank for no good + # reason. This includes the first line after a block is opened, and + # blank lines at the end of a function (ie, right before a line like '}' + # + # Skip all the blank line checks if we are immediately inside a + # namespace body. In other words, don't issue blank line warnings + # for this block: + # namespace { + # + # } + # + # A warning about missing end of namespace comments will be issued instead. 
+ if IsBlankLine(line) and not nesting_state.InNamespaceBody(): + elided = clean_lines.elided + prev_line = elided[linenum - 1] + prevbrace = prev_line.rfind('{') + # TODO(unknown): Don't complain if line before blank line, and line after, + # both start with alnums and are indented the same amount. + # This ignores whitespace at the start of a namespace block + # because those are not usually indented. + if prevbrace != -1 and prev_line[prevbrace:].find('}') == -1: + # OK, we have a blank line at the start of a code block. Before we + # complain, we check if it is an exception to the rule: The previous + # non-empty line has the parameters of a function header that are indented + # 4 spaces (because they did not fit in a 80 column line when placed on + # the same line as the function name). We also check for the case where + # the previous line is indented 6 spaces, which may happen when the + # initializers of a constructor do not fit into a 80 column line. + exception = False + if Match(r' {6}\w', prev_line): # Initializer list? + # We are looking for the opening column of initializer list, which + # should be indented 4 spaces to cause 6 space indentation afterwards. + search_position = linenum-2 + while (search_position >= 0 + and Match(r' {6}\w', elided[search_position])): + search_position -= 1 + exception = (search_position >= 0 + and elided[search_position][:5] == ' :') + else: + # Search for the function arguments or an initializer list. We use a + # simple heuristic here: If the line is indented 4 spaces; and we have a + # closing paren, without the opening paren, followed by an opening brace + # or colon (for initializer lists) we assume that it is the last line of + # a function header. If we have a colon indented 4 spaces, it is an + # initializer list. 
+ exception = (Match(r' {4}\w[^\(]*\)\s*(const\s*)?(\{\s*$|:)', + prev_line) + or Match(r' {4}:', prev_line)) + + if not exception: + error(filename, linenum, 'whitespace/blank_line', 2, + 'Redundant blank line at the start of a code block ' + 'should be deleted.') + # Ignore blank lines at the end of a block in a long if-else + # chain, like this: + # if (condition1) { + # // Something followed by a blank line + # + # } else if (condition2) { + # // Something else + # } + if linenum + 1 < clean_lines.NumLines(): + next_line = raw[linenum + 1] + if (next_line + and Match(r'\s*}', next_line) + and next_line.find('} else ') == -1): + error(filename, linenum, 'whitespace/blank_line', 3, + 'Redundant blank line at the end of a code block ' + 'should be deleted.') + + matched = Match(r'\s*(public|protected|private):', prev_line) + if matched: + error(filename, linenum, 'whitespace/blank_line', 3, + 'Do not leave a blank line after "%s:"' % matched.group(1)) + + # Next, we complain if there's a comment too near the text + commentpos = line.find('//') + if commentpos != -1: + # Check if the // may be in quotes. If so, ignore it + # Comparisons made explicit for clarity -- pylint: disable=g-explicit-bool-comparison + if (line.count('"', 0, commentpos) - + line.count('\\"', 0, commentpos)) % 2 == 0: # not in quotes + # Allow one space for new scopes, two spaces otherwise: + if (not Match(r'^\s*{ //', line) and + ((commentpos >= 1 and + line[commentpos-1] not in string.whitespace) or + (commentpos >= 2 and + line[commentpos-2] not in string.whitespace))): + error(filename, linenum, 'whitespace/comments', 2, + 'At least two spaces is best between code and comments') + # There should always be a space between the // and the comment + commentend = commentpos + 2 + if commentend < len(line) and not line[commentend] == ' ': + # but some lines are exceptions -- e.g. 
if they're big + # comment delimiters like: + # //---------------------------------------------------------- + # or are an empty C++ style Doxygen comment, like: + # /// + # or C++ style Doxygen comments placed after the variable: + # ///< Header comment + # //!< Header comment + # or they begin with multiple slashes followed by a space: + # //////// Header comment + match = (Search(r'[=/-]{4,}\s*$', line[commentend:]) or + Search(r'^/$', line[commentend:]) or + Search(r'^!< ', line[commentend:]) or + Search(r'^/< ', line[commentend:]) or + Search(r'^/+ ', line[commentend:])) + if not match: + error(filename, linenum, 'whitespace/comments', 4, + 'Should have a space between // and comment') + CheckComment(line[commentpos:], filename, linenum, error) + + line = clean_lines.elided[linenum] # get rid of comments and strings + + # Don't try to do spacing checks for operator methods + line = re.sub(r'operator(==|!=|<|<<|<=|>=|>>|>)\(', 'operator\(', line) + + # We allow no-spaces around = within an if: "if ( (a=Foo()) == 0 )". + # Otherwise not. Note we only check for non-spaces on *both* sides; + # sometimes people put non-spaces on one side when aligning ='s among + # many lines (not that this is behavior that I approve of...) + if Search(r'[\w.]=[\w.]', line) and not Search(r'\b(if|while) ', line): + error(filename, linenum, 'whitespace/operators', 4, + 'Missing spaces around =') + + # It's ok not to have spaces around binary operators like + - * /, but if + # there's too little whitespace, we get concerned. It's hard to tell, + # though, so we punt on this one for now. TODO. + + # You should always have whitespace around binary operators. + # + # Check <= and >= first to avoid false positives with < and >, then + # check non-include lines for spacing around < and >. 
+ match = Search(r'[^<>=!\s](==|!=|<=|>=)[^<>=!\s]', line) + if match: + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around %s' % match.group(1)) + # We allow no-spaces around << when used like this: 10<<20, but + # not otherwise (particularly, not when used as streams) + # Also ignore using ns::operator<<; + match = Search(r'(operator|\S)(?:L|UL|ULL|l|ul|ull)?<<(\S)', line) + if (match and + not (match.group(1).isdigit() and match.group(2).isdigit()) and + not (match.group(1) == 'operator' and match.group(2) == ';')): + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around <<') + elif not Match(r'#.*include', line): + # Avoid false positives on -> + reduced_line = line.replace('->', '') + + # Look for < that is not surrounded by spaces. This is only + # triggered if both sides are missing spaces, even though + # technically we should flag if at least one side is missing a + # space. This is done to avoid some false positives with shifts. + match = Search(r'[^\s<]<([^\s=<].*)', reduced_line) + if (match and + not FindNextMatchingAngleBracket(clean_lines, linenum, match.group(1))): + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around <') + + # Look for > that is not surrounded by spaces. Similar to the + # above, we only trigger if both sides are missing spaces to avoid + # false positives with shifts. + match = Search(r'^(.*[^\s>])>[^\s=>]', reduced_line) + if (match and + not FindPreviousMatchingAngleBracket(clean_lines, linenum, + match.group(1))): + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around >') + + # We allow no-spaces around >> for almost anything. This is because + # C++11 allows ">>" to close nested templates, which accounts for + # most cases when ">>" is not followed by a space. 
+ # + # We still warn on ">>" followed by alpha character, because that is + # likely due to ">>" being used for right shifts, e.g.: + # value >> alpha + # + # When ">>" is used to close templates, the alphanumeric letter that + # follows would be part of an identifier, and there should still be + # a space separating the template type and the identifier. + # type> alpha + match = Search(r'>>[a-zA-Z_]', line) + if match: + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around >>') + + # There shouldn't be space around unary operators + match = Search(r'(!\s|~\s|[\s]--[\s;]|[\s]\+\+[\s;])', line) + if match: + error(filename, linenum, 'whitespace/operators', 4, + 'Extra space for operator %s' % match.group(1)) + + # A pet peeve of mine: no spaces after an if, while, switch, or for + match = Search(r' (if\(|for\(|while\(|switch\()', line) + if match: + error(filename, linenum, 'whitespace/parens', 5, + 'Missing space before ( in %s' % match.group(1)) + + # For if/for/while/switch, the left and right parens should be + # consistent about how many spaces are inside the parens, and + # there should either be zero or one spaces inside the parens. + # We don't want: "if ( foo)" or "if ( foo )". + # Exception: "for ( ; foo; bar)" and "for (foo; bar; )" are allowed. 
+ match = Search(r'\b(if|for|while|switch)\s*' + r'\(([ ]*)(.).*[^ ]+([ ]*)\)\s*{\s*$', + line) + if match: + if len(match.group(2)) != len(match.group(4)): + if not (match.group(3) == ';' and + len(match.group(2)) == 1 + len(match.group(4)) or + not match.group(2) and Search(r'\bfor\s*\(.*; \)', line)): + error(filename, linenum, 'whitespace/parens', 5, + 'Mismatching spaces inside () in %s' % match.group(1)) + if len(match.group(2)) not in [0, 1]: + error(filename, linenum, 'whitespace/parens', 5, + 'Should have zero or one spaces inside ( and ) in %s' % + match.group(1)) + + # You should always have a space after a comma (either as fn arg or operator) + # + # This does not apply when the non-space character following the + # comma is another comma, since the only time when that happens is + # for empty macro arguments. + # + # We run this check in two passes: first pass on elided lines to + # verify that lines contain missing whitespaces, second pass on raw + # lines to confirm that those missing whitespaces are not due to + # elided comments. + if Search(r',[^,\s]', line) and Search(r',[^,\s]', raw[linenum]): + error(filename, linenum, 'whitespace/comma', 3, + 'Missing space after ,') + + # You should always have a space after a semicolon + # except for a few corner cases + # TODO(unknown): clarify if 'if (1) { return 1;}' requires one more + # space after ; + if Search(r';[^\s};\\)/]', line): + error(filename, linenum, 'whitespace/semicolon', 3, + 'Missing space after ;') + + # Next we will look for issues with function calls. + CheckSpacingForFunctionCall(filename, line, linenum, error) + + # Except after an opening paren, or after another opening brace (in case of + # an initializer list, for instance), you should have spaces before your + # braces. And since you should never have braces at the beginning of a line, + # this is an easy test. + match = Match(r'^(.*[^ ({]){', line) + if match: + # Try a bit harder to check for brace initialization. 
This + # happens in one of the following forms: + # Constructor() : initializer_list_{} { ... } + # Constructor{}.MemberFunction() + # Type variable{}; + # FunctionCall(type{}, ...); + # LastArgument(..., type{}); + # LOG(INFO) << type{} << " ..."; + # map_of_type[{...}] = ...; + # + # We check for the character following the closing brace, and + # silence the warning if it's one of those listed above, i.e. + # "{.;,)<]". + # + # To account for nested initializer list, we allow any number of + # closing braces up to "{;,)<". We can't simply silence the + # warning on first sight of closing brace, because that would + # cause false negatives for things that are not initializer lists. + # Silence this: But not this: + # Outer{ if (...) { + # Inner{...} if (...){ // Missing space before { + # }; } + # + # There is a false negative with this approach if people inserted + # spurious semicolons, e.g. "if (cond){};", but we will catch the + # spurious semicolon with a separate check. + (endline, endlinenum, endpos) = CloseExpression( + clean_lines, linenum, len(match.group(1))) + trailing_text = '' + if endpos > -1: + trailing_text = endline[endpos:] + for offset in xrange(endlinenum + 1, + min(endlinenum + 3, clean_lines.NumLines() - 1)): + trailing_text += clean_lines.elided[offset] + if not Match(r'^[\s}]*[{.;,)<\]]', trailing_text): + error(filename, linenum, 'whitespace/braces', 5, + 'Missing space before {') + + # Make sure '} else {' has spaces. + if Search(r'}else', line): + error(filename, linenum, 'whitespace/braces', 5, + 'Missing space before else') + + # You shouldn't have spaces before your brackets, except maybe after + # 'delete []' or 'new char * []'. + if Search(r'\w\s+\[', line) and not Search(r'delete\s+\[', line): + error(filename, linenum, 'whitespace/braces', 5, + 'Extra space before [') + + # You shouldn't have a space before a semicolon at the end of the line. 
+ # There's a special case for "for" since the style guide allows space before + # the semicolon there. + if Search(r':\s*;\s*$', line): + error(filename, linenum, 'whitespace/semicolon', 5, + 'Semicolon defining empty statement. Use {} instead.') + elif Search(r'^\s*;\s*$', line): + error(filename, linenum, 'whitespace/semicolon', 5, + 'Line contains only semicolon. If this should be an empty statement, ' + 'use {} instead.') + elif (Search(r'\s+;\s*$', line) and + not Search(r'\bfor\b', line)): + error(filename, linenum, 'whitespace/semicolon', 5, + 'Extra space before last semicolon. If this should be an empty ' + 'statement, use {} instead.') + + # In range-based for, we wanted spaces before and after the colon, but + # not around "::" tokens that might appear. + if (Search('for *\(.*[^:]:[^: ]', line) or + Search('for *\(.*[^: ]:[^:]', line)): + error(filename, linenum, 'whitespace/forcolon', 2, + 'Missing space around colon in range-based for loop') + + +def CheckSectionSpacing(filename, clean_lines, class_info, linenum, error): + """Checks for additional blank line issues related to sections. + + Currently the only thing checked here is blank line before protected/private. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + class_info: A _ClassInfo object. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + # Skip checks if the class is small, where small means 25 lines or less. + # 25 lines seems like a good cutoff since that's the usual height of + # terminals, and any class that can't fit in one screen can't really + # be considered "small". + # + # Also skip checks if we are on the first line. This accounts for + # classes that look like + # class Foo { public: ... }; + # + # If we didn't find the end of the class, last_line would be zero, + # and the check will be skipped by the first condition. 
+ if (class_info.last_line - class_info.starting_linenum <= 24 or + linenum <= class_info.starting_linenum): + return + + matched = Match(r'\s*(public|protected|private):', clean_lines.lines[linenum]) + if matched: + # Issue warning if the line before public/protected/private was + # not a blank line, but don't do this if the previous line contains + # "class" or "struct". This can happen two ways: + # - We are at the beginning of the class. + # - We are forward-declaring an inner class that is semantically + # private, but needed to be public for implementation reasons. + # Also ignores cases where the previous line ends with a backslash as can be + # common when defining classes in C macros. + prev_line = clean_lines.lines[linenum - 1] + if (not IsBlankLine(prev_line) and + not Search(r'\b(class|struct)\b', prev_line) and + not Search(r'\\$', prev_line)): + # Try a bit harder to find the beginning of the class. This is to + # account for multi-line base-specifier lists, e.g.: + # class Derived + # : public Base { + end_class_head = class_info.starting_linenum + for i in range(class_info.starting_linenum, linenum): + if Search(r'\{\s*$', clean_lines.lines[i]): + end_class_head = i + break + if end_class_head < linenum - 1: + error(filename, linenum, 'whitespace/blank_line', 3, + '"%s:" should be preceded by a blank line' % matched.group(1)) + + +def GetPreviousNonBlankLine(clean_lines, linenum): + """Return the most recent non-blank line and its line number. + + Args: + clean_lines: A CleansedLines instance containing the file contents. + linenum: The number of the line to check. + + Returns: + A tuple with two elements. The first element is the contents of the last + non-blank line before the current line, or the empty string if this is the + first non-blank line. The second is the line number of that line, or -1 + if this is the first non-blank line. 
+ """ + + prevlinenum = linenum - 1 + while prevlinenum >= 0: + prevline = clean_lines.elided[prevlinenum] + if not IsBlankLine(prevline): # if not a blank line... + return (prevline, prevlinenum) + prevlinenum -= 1 + return ('', -1) + + +def CheckBraces(filename, clean_lines, linenum, error): + """Looks for misplaced braces (e.g. at the end of line). + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + line = clean_lines.elided[linenum] # get rid of comments and strings + + if Match(r'\s*{\s*$', line): + # We allow an open brace to start a line in the case where someone is using + # braces in a block to explicitly create a new scope, which is commonly used + # to control the lifetime of stack-allocated variables. Braces are also + # used for brace initializers inside function calls. We don't detect this + # perfectly: we just don't complain if the last non-whitespace character on + # the previous non-blank line is ',', ';', ':', '(', '{', or '}', or if the + # previous line starts a preprocessor block. + prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0] + if (not Search(r'[,;:}{(]\s*$', prevline) and + not Match(r'\s*#', prevline)): + error(filename, linenum, 'whitespace/braces', 4, + '{ should almost always be at the end of the previous line') + + # An else clause should be on the same line as the preceding closing brace. + if Match(r'\s*else\s*', line): + prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0] + if Match(r'\s*}\s*$', prevline): + error(filename, linenum, 'whitespace/newline', 4, + 'An else should appear on the same line as the preceding }') + + # If braces come on one side of an else, they should be on both. + # However, we have to worry about "else if" that spans multiple lines! 
+ if Search(r'}\s*else[^{]*$', line) or Match(r'[^}]*else\s*{', line): + if Search(r'}\s*else if([^{]*)$', line): # could be multi-line if + # find the ( after the if + pos = line.find('else if') + pos = line.find('(', pos) + if pos > 0: + (endline, _, endpos) = CloseExpression(clean_lines, linenum, pos) + if endline[endpos:].find('{') == -1: # must be brace after if + error(filename, linenum, 'readability/braces', 5, + 'If an else has a brace on one side, it should have it on both') + else: # common case: else not followed by a multi-line if + error(filename, linenum, 'readability/braces', 5, + 'If an else has a brace on one side, it should have it on both') + + # Likewise, an else should never have the else clause on the same line + if Search(r'\belse [^\s{]', line) and not Search(r'\belse if\b', line): + error(filename, linenum, 'whitespace/newline', 4, + 'Else clause should never be on same line as else (use 2 lines)') + + # In the same way, a do/while should never be on one line + if Match(r'\s*do [^\s{]', line): + error(filename, linenum, 'whitespace/newline', 4, + 'do/while clauses should not be on a single line') + + # Block bodies should not be followed by a semicolon. Due to C++11 + # brace initialization, there are more places where semicolons are + # required than not, so we use a whitelist approach to check these + # rather than a blacklist. These are the places where "};" should + # be replaced by just "}": + # 1. Some flavor of block following closing parenthesis: + # for (;;) {}; + # while (...) {}; + # switch (...) {}; + # Function(...) {}; + # if (...) {}; + # if (...) else if (...) {}; + # + # 2. else block: + # if (...) else {}; + # + # 3. const member function: + # Function(...) const {}; + # + # 4. Block following some statement: + # x = 42; + # {}; + # + # 5. Block at the beginning of a function: + # Function(...) 
{ + # {}; + # } + # + # Note that naively checking for the preceding "{" will also match + # braces inside multi-dimensional arrays, but this is fine since + # that expression will not contain semicolons. + # + # 6. Block following another block: + # while (true) {} + # {}; + # + # 7. End of namespaces: + # namespace {}; + # + # These semicolons seem far more common than other kinds of + # redundant semicolons, possibly due to people converting classes + # to namespaces. For now we do not warn for this case. + # + # Try matching case 1 first. + match = Match(r'^(.*\)\s*)\{', line) + if match: + # Matched closing parenthesis (case 1). Check the token before the + # matching opening parenthesis, and don't warn if it looks like a + # macro. This avoids these false positives: + # - macro that defines a base class + # - multi-line macro that defines a base class + # - macro that defines the whole class-head + # + # But we still issue warnings for macros that we know are safe to + # warn, specifically: + # - TEST, TEST_F, TEST_P, MATCHER, MATCHER_P + # - TYPED_TEST + # - INTERFACE_DEF + # - EXCLUSIVE_LOCKS_REQUIRED, SHARED_LOCKS_REQUIRED, LOCKS_EXCLUDED: + # + # We implement a whitelist of safe macros instead of a blacklist of + # unsafe macros, even though the latter appears less frequently in + # google code and would have been easier to implement. This is because + # the downside for getting the whitelist wrong means some extra + # semicolons, while the downside for getting the blacklist wrong + # would result in compile errors. + # + # In addition to macros, we also don't want to warn on compound + # literals. 
+ closing_brace_pos = match.group(1).rfind(')') + opening_parenthesis = ReverseCloseExpression( + clean_lines, linenum, closing_brace_pos) + if opening_parenthesis[2] > -1: + line_prefix = opening_parenthesis[0][0:opening_parenthesis[2]] + macro = Search(r'\b([A-Z_]+)\s*$', line_prefix) + if ((macro and + macro.group(1) not in ( + 'TEST', 'TEST_F', 'MATCHER', 'MATCHER_P', 'TYPED_TEST', + 'EXCLUSIVE_LOCKS_REQUIRED', 'SHARED_LOCKS_REQUIRED', + 'LOCKS_EXCLUDED', 'INTERFACE_DEF')) or + Search(r'\s+=\s*$', line_prefix)): + match = None + + else: + # Try matching cases 2-3. + match = Match(r'^(.*(?:else|\)\s*const)\s*)\{', line) + if not match: + # Try matching cases 4-6. These are always matched on separate lines. + # + # Note that we can't simply concatenate the previous line to the + # current line and do a single match, otherwise we may output + # duplicate warnings for the blank line case: + # if (cond) { + # // blank line + # } + prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0] + if prevline and Search(r'[;{}]\s*$', prevline): + match = Match(r'^(\s*)\{', line) + + # Check matching closing brace + if match: + (endline, endlinenum, endpos) = CloseExpression( + clean_lines, linenum, len(match.group(1))) + if endpos > -1 and Match(r'^\s*;', endline[endpos:]): + # Current {} pair is eligible for semicolon check, and we have found + # the redundant semicolon, output warning here. + # + # Note: because we are scanning forward for opening braces, and + # outputting warnings for the matching closing brace, if there are + # nested blocks with trailing semicolons, we will get the error + # messages in reversed order. + error(filename, endlinenum, 'readability/braces', 4, + "You don't need a ; after a }") + + +def CheckEmptyBlockBody(filename, clean_lines, linenum, error): + """Look for empty loop/conditional body with only a single semicolon. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. 
+ linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + # Search for loop keywords at the beginning of the line. Because only + # whitespaces are allowed before the keywords, this will also ignore most + # do-while-loops, since those lines should start with closing brace. + # + # We also check "if" blocks here, since an empty conditional block + # is likely an error. + line = clean_lines.elided[linenum] + matched = Match(r'\s*(for|while|if)\s*\(', line) + if matched: + # Find the end of the conditional expression + (end_line, end_linenum, end_pos) = CloseExpression( + clean_lines, linenum, line.find('(')) + + # Output warning if what follows the condition expression is a semicolon. + # No warning for all other cases, including whitespace or newline, since we + # have a separate check for semicolons preceded by whitespace. + if end_pos >= 0 and Match(r';', end_line[end_pos:]): + if matched.group(1) == 'if': + error(filename, end_linenum, 'whitespace/empty_conditional_body', 5, + 'Empty conditional bodies should use {}') + else: + error(filename, end_linenum, 'whitespace/empty_loop_body', 5, + 'Empty loop bodies should use {} or continue') + + +def CheckCheck(filename, clean_lines, linenum, error): + """Checks the use of CHECK and EXPECT macros. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + # Decide the set of replacement macros that should be suggested + lines = clean_lines.elided + check_macro = None + start_pos = -1 + for macro in _CHECK_MACROS: + i = lines[linenum].find(macro) + if i >= 0: + check_macro = macro + + # Find opening parenthesis. Do a regular expression match here + # to make sure that we are matching the expected CHECK macro, as + # opposed to some other macro that happens to contain the CHECK + # substring. 
+ matched = Match(r'^(.*\b' + check_macro + r'\s*)\(', lines[linenum]) + if not matched: + continue + start_pos = len(matched.group(1)) + break + if not check_macro or start_pos < 0: + # Don't waste time here if line doesn't contain 'CHECK' or 'EXPECT' + return + + # Find end of the boolean expression by matching parentheses + (last_line, end_line, end_pos) = CloseExpression( + clean_lines, linenum, start_pos) + if end_pos < 0: + return + if linenum == end_line: + expression = lines[linenum][start_pos + 1:end_pos - 1] + else: + expression = lines[linenum][start_pos + 1:] + for i in xrange(linenum + 1, end_line): + expression += lines[i] + expression += last_line[0:end_pos - 1] + + # Parse expression so that we can take parentheses into account. + # This avoids false positives for inputs like "CHECK((a < 4) == b)", + # which is not replaceable by CHECK_LE. + lhs = '' + rhs = '' + operator = None + while expression: + matched = Match(r'^\s*(<<|<<=|>>|>>=|->\*|->|&&|\|\||' + r'==|!=|>=|>|<=|<|\()(.*)$', expression) + if matched: + token = matched.group(1) + if token == '(': + # Parenthesized operand + expression = matched.group(2) + (end, _) = FindEndOfExpressionInLine(expression, 0, 1, '(', ')') + if end < 0: + return # Unmatched parenthesis + lhs += '(' + expression[0:end] + expression = expression[end:] + elif token in ('&&', '||'): + # Logical and/or operators. This means the expression + # contains more than one term, for example: + # CHECK(42 < a && a < b); + # + # These are not replaceable with CHECK_LE, so bail out early. + return + elif token in ('<<', '<<=', '>>', '>>=', '->*', '->'): + # Non-relational operator + lhs += token + expression = matched.group(2) + else: + # Relational operator + operator = token + rhs = matched.group(2) + break + else: + # Unparenthesized operand. Instead of appending to lhs one character + # at a time, we do another regular expression match to consume several + # characters at once if possible. 
Trivial benchmark shows that this + # is more efficient when the operands are longer than a single + # character, which is generally the case. + matched = Match(r'^([^-=!<>()&|]+)(.*)$', expression) + if not matched: + matched = Match(r'^(\s*\S)(.*)$', expression) + if not matched: + break + lhs += matched.group(1) + expression = matched.group(2) + + # Only apply checks if we got all parts of the boolean expression + if not (lhs and operator and rhs): + return + + # Check that rhs does not contain logical operators. We already know + # that lhs is fine since the loop above parses out && and ||. + if rhs.find('&&') > -1 or rhs.find('||') > -1: + return + + # At least one of the operands must be a constant literal. This is + # to avoid suggesting replacements for unprintable things like + # CHECK(variable != iterator) + # + # The following pattern matches decimal, hex integers, strings, and + # characters (in that order). + lhs = lhs.strip() + rhs = rhs.strip() + match_constant = r'^([-+]?(\d+|0[xX][0-9a-fA-F]+)[lLuU]{0,3}|".*"|\'.*\')$' + if Match(match_constant, lhs) or Match(match_constant, rhs): + # Note: since we know both lhs and rhs, we can provide a more + # descriptive error message like: + # Consider using CHECK_EQ(x, 42) instead of CHECK(x == 42) + # Instead of: + # Consider using CHECK_EQ instead of CHECK(a == b) + # + # We are still keeping the less descriptive message because if lhs + # or rhs gets long, the error message might become unreadable. + error(filename, linenum, 'readability/check', 2, + 'Consider using %s instead of %s(a %s b)' % ( + _CHECK_REPLACEMENT[check_macro][operator], + check_macro, operator)) + + +def CheckAltTokens(filename, clean_lines, linenum, error): + """Check alternative keywords being used in boolean expressions. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. 
+ """ + line = clean_lines.elided[linenum] + + # Avoid preprocessor lines + if Match(r'^\s*#', line): + return + + # Last ditch effort to avoid multi-line comments. This will not help + # if the comment started before the current line or ended after the + # current line, but it catches most of the false positives. At least, + # it provides a way to workaround this warning for people who use + # multi-line comments in preprocessor macros. + # + # TODO(unknown): remove this once cpplint has better support for + # multi-line comments. + if line.find('/*') >= 0 or line.find('*/') >= 0: + return + + for match in _ALT_TOKEN_REPLACEMENT_PATTERN.finditer(line): + error(filename, linenum, 'readability/alt_tokens', 2, + 'Use operator %s instead of %s' % ( + _ALT_TOKEN_REPLACEMENT[match.group(1)], match.group(1))) + + +def GetLineWidth(line): + """Determines the width of the line in column positions. + + Args: + line: A string, which may be a Unicode string. + + Returns: + The width of the line in column positions, accounting for Unicode + combining characters and wide characters. + """ + if isinstance(line, unicode): + width = 0 + for uc in unicodedata.normalize('NFC', line): + if unicodedata.east_asian_width(uc) in ('W', 'F'): + width += 2 + elif not unicodedata.combining(uc): + width += 1 + return width + else: + return len(line) + + +def CheckStyle(filename, clean_lines, linenum, file_extension, nesting_state, + error): + """Checks rules from the 'C++ style rules' section of cppguide.html. + + Most of these rules are hard to test (naming, comment style), but we + do what we can. In particular we check for 2-space indents, line lengths, + tab usage, spaces inside code, etc. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + file_extension: The extension (without the dot) of the filename. 
+ nesting_state: A _NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: The function to call with any errors found. + """ + + # Don't use "elided" lines here, otherwise we can't check commented lines. + # Don't want to use "raw" either, because we don't want to check inside C++11 + # raw strings, + raw_lines = clean_lines.lines_without_raw_strings + line = raw_lines[linenum] + + if line.find('\t') != -1: + error(filename, linenum, 'whitespace/tab', 1, + 'Tab found; better to use spaces') + + # One or three blank spaces at the beginning of the line is weird; it's + # hard to reconcile that with 2-space indents. + # NOTE: here are the conditions rob pike used for his tests. Mine aren't + # as sophisticated, but it may be worth becoming so: RLENGTH==initial_spaces + # if(RLENGTH > 20) complain = 0; + # if(match($0, " +(error|private|public|protected):")) complain = 0; + # if(match(prev, "&& *$")) complain = 0; + # if(match(prev, "\\|\\| *$")) complain = 0; + # if(match(prev, "[\",=><] *$")) complain = 0; + # if(match($0, " <<")) complain = 0; + # if(match(prev, " +for \\(")) complain = 0; + # if(prevodd && match(prevprev, " +for \\(")) complain = 0; + initial_spaces = 0 + cleansed_line = clean_lines.elided[linenum] + while initial_spaces < len(line) and line[initial_spaces] == ' ': + initial_spaces += 1 + if line and line[-1].isspace(): + error(filename, linenum, 'whitespace/end_of_line', 4, + 'Line ends in whitespace. Consider deleting these extra spaces.') + # There are certain situations we allow one space, notably for section labels + elif ((initial_spaces == 1 or initial_spaces == 3) and + not Match(r'\s*\w+\s*:\s*$', cleansed_line)): + error(filename, linenum, 'whitespace/indent', 3, + 'Weird number of spaces at line-start. ' + 'Are you using a 2-space indent?') + + # Check if the line is a header guard. 
+ is_header_guard = False + if file_extension == 'h': + cppvar = GetHeaderGuardCPPVariable(filename) + if (line.startswith('#ifndef %s' % cppvar) or + line.startswith('#define %s' % cppvar) or + line.startswith('#endif // %s' % cppvar)): + is_header_guard = True + # #include lines and header guards can be long, since there's no clean way to + # split them. + # + # URLs can be long too. It's possible to split these, but it makes them + # harder to cut&paste. + # + # The "$Id:...$" comment may also get very long without it being the + # developer's fault. + if (not line.startswith('#include') and not is_header_guard and + not Match(r'^\s*//.*http(s?)://\S*$', line) and + not Match(r'^// \$Id:.*#[0-9]+ \$$', line)): + line_width = GetLineWidth(line) + extended_length = int((_line_length * 1.25)) + if line_width > extended_length: + error(filename, linenum, 'whitespace/line_length', 4, + 'Lines should very rarely be longer than %i characters' % + extended_length) + elif line_width > _line_length: + error(filename, linenum, 'whitespace/line_length', 2, + 'Lines should be <= %i characters long' % _line_length) + + if (cleansed_line.count(';') > 1 and + # for loops are allowed two ;'s (and may run over two lines). 
+ cleansed_line.find('for') == -1 and + (GetPreviousNonBlankLine(clean_lines, linenum)[0].find('for') == -1 or + GetPreviousNonBlankLine(clean_lines, linenum)[0].find(';') != -1) and + # It's ok to have many commands in a switch case that fits in 1 line + not ((cleansed_line.find('case ') != -1 or + cleansed_line.find('default:') != -1) and + cleansed_line.find('break;') != -1)): + error(filename, linenum, 'whitespace/newline', 0, + 'More than one command on the same line') + + # Some more style checks + CheckBraces(filename, clean_lines, linenum, error) + CheckEmptyBlockBody(filename, clean_lines, linenum, error) + CheckAccess(filename, clean_lines, linenum, nesting_state, error) + CheckSpacing(filename, clean_lines, linenum, nesting_state, error) + CheckCheck(filename, clean_lines, linenum, error) + CheckAltTokens(filename, clean_lines, linenum, error) + classinfo = nesting_state.InnermostClass() + if classinfo: + CheckSectionSpacing(filename, clean_lines, classinfo, linenum, error) + + +_RE_PATTERN_INCLUDE_NEW_STYLE = re.compile(r'#include +"[^/]+\.h"') +_RE_PATTERN_INCLUDE = re.compile(r'^\s*#\s*include\s*([<"])([^>"]*)[>"].*$') +# Matches the first component of a filename delimited by -s and _s. That is: +# _RE_FIRST_COMPONENT.match('foo').group(0) == 'foo' +# _RE_FIRST_COMPONENT.match('foo.cc').group(0) == 'foo' +# _RE_FIRST_COMPONENT.match('foo-bar_baz.cc').group(0) == 'foo' +# _RE_FIRST_COMPONENT.match('foo_bar-baz.cc').group(0) == 'foo' +_RE_FIRST_COMPONENT = re.compile(r'^[^-_.]+') + + +def _DropCommonSuffixes(filename): + """Drops common suffixes like _test.cc or -inl.h from filename. + + For example: + >>> _DropCommonSuffixes('foo/foo-inl.h') + 'foo/foo' + >>> _DropCommonSuffixes('foo/bar/foo.cc') + 'foo/bar/foo' + >>> _DropCommonSuffixes('foo/foo_internal.h') + 'foo/foo' + >>> _DropCommonSuffixes('foo/foo_unusualinternal.h') + 'foo/foo_unusualinternal' + + Args: + filename: The input filename. + + Returns: + The filename with the common suffix removed. 
+ """ + for suffix in ('test.cc', 'regtest.cc', 'unittest.cc', + 'inl.h', 'impl.h', 'internal.h'): + if (filename.endswith(suffix) and len(filename) > len(suffix) and + filename[-len(suffix) - 1] in ('-', '_')): + return filename[:-len(suffix) - 1] + return os.path.splitext(filename)[0] + + +def _IsTestFilename(filename): + """Determines if the given filename has a suffix that identifies it as a test. + + Args: + filename: The input filename. + + Returns: + True if 'filename' looks like a test, False otherwise. + """ + if (filename.endswith('_test.cc') or + filename.endswith('_unittest.cc') or + filename.endswith('_regtest.cc')): + return True + else: + return False + + +def _ClassifyInclude(fileinfo, include, is_system): + """Figures out what kind of header 'include' is. + + Args: + fileinfo: The current file cpplint is running over. A FileInfo instance. + include: The path to a #included file. + is_system: True if the #include used <> rather than "". + + Returns: + One of the _XXX_HEADER constants. + + For example: + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'stdio.h', True) + _C_SYS_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'string', True) + _CPP_SYS_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/foo.h', False) + _LIKELY_MY_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo_unknown_extension.cc'), + ... 'bar/foo_other_ext.h', False) + _POSSIBLE_MY_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/bar.h', False) + _OTHER_HEADER + """ + # This is a list of all standard c++ header files, except + # those already checked for above. + is_cpp_h = include in _CPP_HEADERS + + if is_system: + if is_cpp_h: + return _CPP_SYS_HEADER + else: + return _C_SYS_HEADER + + # If the target file and the include we're checking share a + # basename when we drop common extensions, and the include + # lives in . , then it's likely to be owned by the target file. 
+ target_dir, target_base = ( + os.path.split(_DropCommonSuffixes(fileinfo.RepositoryName()))) + include_dir, include_base = os.path.split(_DropCommonSuffixes(include)) + if target_base == include_base and ( + include_dir == target_dir or + include_dir == os.path.normpath(target_dir + '/../public')): + return _LIKELY_MY_HEADER + + # If the target and include share some initial basename + # component, it's possible the target is implementing the + # include, so it's allowed to be first, but we'll never + # complain if it's not there. + target_first_component = _RE_FIRST_COMPONENT.match(target_base) + include_first_component = _RE_FIRST_COMPONENT.match(include_base) + if (target_first_component and include_first_component and + target_first_component.group(0) == + include_first_component.group(0)): + return _POSSIBLE_MY_HEADER + + return _OTHER_HEADER + + + +def CheckIncludeLine(filename, clean_lines, linenum, include_state, error): + """Check rules that are applicable to #include lines. + + Strings on #include lines are NOT removed from elided line, to make + certain tasks easier. However, to prevent false positives, checks + applicable to #include lines in CheckLanguage must be put here. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + include_state: An _IncludeState instance in which the headers are inserted. + error: The function to call with any errors found. + """ + fileinfo = FileInfo(filename) + + line = clean_lines.lines[linenum] + + # "include" should use the new style "foo/bar.h" instead of just "bar.h" + if _RE_PATTERN_INCLUDE_NEW_STYLE.search(line): + error(filename, linenum, 'build/include_dir', 4, + 'Include the directory when naming .h files') + + # we shouldn't include a file more than once. actually, there are a + # handful of instances where doing so is okay, but in general it's + # not. 
+ match = _RE_PATTERN_INCLUDE.search(line) + if match: + include = match.group(2) + is_system = (match.group(1) == '<') + if include in include_state: + error(filename, linenum, 'build/include', 4, + '"%s" already included at %s:%s' % + (include, filename, include_state[include])) + else: + include_state[include] = linenum + + # We want to ensure that headers appear in the right order: + # 1) for foo.cc, foo.h (preferred location) + # 2) c system files + # 3) cpp system files + # 4) for foo.cc, foo.h (deprecated location) + # 5) other google headers + # + # We classify each include statement as one of those 5 types + # using a number of techniques. The include_state object keeps + # track of the highest type seen, and complains if we see a + # lower type after that. + error_message = include_state.CheckNextIncludeOrder( + _ClassifyInclude(fileinfo, include, is_system)) + if error_message: + error(filename, linenum, 'build/include_order', 4, + '%s. Should be: %s.h, c system, c++ system, other.' % + (error_message, fileinfo.BaseName())) + canonical_include = include_state.CanonicalizeAlphabeticalOrder(include) + if not include_state.IsInAlphabeticalOrder( + clean_lines, linenum, canonical_include): + error(filename, linenum, 'build/include_alpha', 4, + 'Include "%s" not in alphabetical order' % include) + include_state.SetLastHeader(canonical_include) + + # Look for any of the stream classes that are part of standard C++. + match = _RE_PATTERN_INCLUDE.match(line) + if match: + include = match.group(2) + if Match(r'(f|ind|io|i|o|parse|pf|stdio|str|)?stream$', include): + # Many unit tests use cout, so we exempt them. + if not _IsTestFilename(filename): + error(filename, linenum, 'readability/streams', 3, + 'Streams are highly discouraged.') + + +def _GetTextInside(text, start_pattern): + r"""Retrieves all the text between matching open and close parentheses. 
+ + Given a string of lines and a regular expression string, retrieve all the text + following the expression and between opening punctuation symbols like + (, [, or {, and the matching close-punctuation symbol. This properly handles + nested occurrences of the punctuations, so for the text like + printf(a(), b(c())); + a call to _GetTextInside(text, r'printf\(') will return 'a(), b(c())'. + start_pattern must match a string having an open punctuation symbol at the end. + + Args: + text: The lines to extract text from. Its comments and strings must be elided. + It can be single line and can span multiple lines. + start_pattern: The regexp string indicating where to start extracting + the text. + Returns: + The extracted text. + None if either the opening string or ending punctuation could not be found. + """ + # TODO(sugawarayu): Audit cpplint.py to see what places could be profitably + # rewritten to use _GetTextInside (and use inferior regexp matching today). + + # Give opening punctuations to get the matching close-punctuations. + matching_punctuation = {'(': ')', '{': '}', '[': ']'} + closing_punctuation = set(matching_punctuation.itervalues()) + + # Find the position to start extracting text. + match = re.search(start_pattern, text, re.M) + if not match: # start_pattern not found in text. + return None + start_position = match.end(0) + + assert start_position > 0, ( + 'start_pattern must end with an opening punctuation.') + assert text[start_position - 1] in matching_punctuation, ( + 'start_pattern must end with an opening punctuation.') + # Stack of closing punctuations we expect to have in text after position. + punctuation_stack = [matching_punctuation[text[start_position - 1]]] + position = start_position + while punctuation_stack and position < len(text): + if text[position] == punctuation_stack[-1]: + punctuation_stack.pop() + elif text[position] in closing_punctuation: + # A closing punctuation without matching opening punctuations.
+ return None + elif text[position] in matching_punctuation: + punctuation_stack.append(matching_punctuation[text[position]]) + position += 1 + if punctuation_stack: + # Opening punctuations left without matching close-punctuations. + return None + # punctuations match. + return text[start_position:position - 1] + + +# Patterns for matching call-by-reference parameters. +# +# Supports nested templates up to 2 levels deep using this messy pattern: +# < (?: < (?: < [^<>]* +# > +# | [^<>] )* +# > +# | [^<>] )* +# > +_RE_PATTERN_IDENT = r'[_a-zA-Z]\w*' # =~ [[:alpha:]][[:alnum:]]* +_RE_PATTERN_TYPE = ( + r'(?:const\s+)?(?:typename\s+|class\s+|struct\s+|union\s+|enum\s+)?' + r'(?:\w|' + r'\s*<(?:<(?:<[^<>]*>|[^<>])*>|[^<>])*>|' + r'::)+') +# A call-by-reference parameter ends with '& identifier'. +_RE_PATTERN_REF_PARAM = re.compile( + r'(' + _RE_PATTERN_TYPE + r'(?:\s*(?:\bconst\b|[*]))*\s*' + r'&\s*' + _RE_PATTERN_IDENT + r')\s*(?:=[^,()]+)?[,)]') +# A call-by-const-reference parameter either ends with 'const& identifier' +# or looks like 'const type& identifier' when 'type' is atomic. +_RE_PATTERN_CONST_REF_PARAM = ( + r'(?:.*\s*\bconst\s*&\s*' + _RE_PATTERN_IDENT + + r'|const\s+' + _RE_PATTERN_TYPE + r'\s*&\s*' + _RE_PATTERN_IDENT + r')') + + +def CheckLanguage(filename, clean_lines, linenum, file_extension, + include_state, nesting_state, error): + """Checks rules from the 'C++ language rules' section of cppguide.html. + + Some of these rules are hard to test (function overloading, using + uint32 inappropriately), but we do the best we can. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + file_extension: The extension (without the dot) of the filename. + include_state: An _IncludeState instance in which the headers are inserted. + nesting_state: A _NestingState instance which maintains information about + the current stack of nested blocks being parsed. 
+ error: The function to call with any errors found. + """ + # If the line is empty or consists of entirely a comment, no need to + # check it. + line = clean_lines.elided[linenum] + if not line: + return + + match = _RE_PATTERN_INCLUDE.search(line) + if match: + CheckIncludeLine(filename, clean_lines, linenum, include_state, error) + return + + # Reset include state across preprocessor directives. This is meant + # to silence warnings for conditional includes. + if Match(r'^\s*#\s*(?:ifdef|elif|else|endif)\b', line): + include_state.ResetSection() + + # Make Windows paths like Unix. + fullname = os.path.abspath(filename).replace('\\', '/') + + # TODO(unknown): figure out if they're using default arguments in fn proto. + + # Check to see if they're using a conversion function cast. + # I just try to capture the most common basic types, though there are more. + # Parameterless conversion functions, such as bool(), are allowed as they are + # probably a member operator declaration or default constructor. + match = Search( + r'(\bnew\s+)?\b' # Grab 'new' operator, if it's there + r'(int|float|double|bool|char|int32|uint32|int64|uint64)' + r'(\([^)].*)', line) + if match: + matched_new = match.group(1) + matched_type = match.group(2) + matched_funcptr = match.group(3) + + # gMock methods are defined using some variant of MOCK_METHODx(name, type) + # where type may be float(), int(string), etc. Without context they are + # virtually indistinguishable from int(x) casts. Likewise, gMock's + # MockCallback takes a template parameter of the form return_type(arg_type), + # which looks much like the cast we're trying to detect. + # + # std::function<> wrapper has a similar problem. + # + # Return types for function pointers also look like casts if they + # don't have an extra space.
+ if (matched_new is None and # If new operator, then this isn't a cast + not (Match(r'^\s*MOCK_(CONST_)?METHOD\d+(_T)?\(', line) or + Search(r'\bMockCallback<.*>', line) or + Search(r'\bstd::function<.*>', line)) and + not (matched_funcptr and + Match(r'\((?:[^() ]+::\s*\*\s*)?[^() ]+\)\s*\(', + matched_funcptr))): + # Try a bit harder to catch gmock lines: the only place where + # something looks like an old-style cast is where we declare the + # return type of the mocked method, and the only time when we + # are missing context is if MOCK_METHOD was split across + # multiple lines. The missing MOCK_METHOD is usually one or two + # lines back, so scan back one or two lines. + # + # It's not possible for gmock macros to appear in the first 2 + # lines, since the class head + section name takes up 2 lines. + if (linenum < 2 or + not (Match(r'^\s*MOCK_(?:CONST_)?METHOD\d+(?:_T)?\((?:\S+,)?\s*$', + clean_lines.elided[linenum - 1]) or + Match(r'^\s*MOCK_(?:CONST_)?METHOD\d+(?:_T)?\(\s*$', + clean_lines.elided[linenum - 2]))): + error(filename, linenum, 'readability/casting', 4, + 'Using deprecated casting style. ' + 'Use static_cast<%s>(...) instead' % + matched_type) + + CheckCStyleCast(filename, linenum, line, clean_lines.raw_lines[linenum], + 'static_cast', + r'\((int|float|double|bool|char|u?int(16|32|64))\)', error) + + # This doesn't catch all cases. Consider (const char * const)"hello". + # + # (char *) "foo" should always be a const_cast (reinterpret_cast won't + # compile). + if CheckCStyleCast(filename, linenum, line, clean_lines.raw_lines[linenum], + 'const_cast', r'\((char\s?\*+\s?)\)\s*"', error): + pass + else: + # Check pointer casts for other than string constants + CheckCStyleCast(filename, linenum, line, clean_lines.raw_lines[linenum], + 'reinterpret_cast', r'\((\w+\s?\*+\s?)\)', error) + + # In addition, we look for people taking the address of a cast. 
This + # is dangerous -- casts can assign to temporaries, so the pointer doesn't + # point where you think. + match = Search( + r'(?:&\(([^)]+)\)[\w(])|' + r'(?:&(static|dynamic|down|reinterpret)_cast\b)', line) + if match and match.group(1) != '*': + error(filename, linenum, 'runtime/casting', 4, + ('Are you taking an address of a cast? ' + 'This is dangerous: could be a temp var. ' + 'Take the address before doing the cast, rather than after')) + + # Create an extended_line, which is the concatenation of the current and + # next lines, for more effective checking of code that may span more than one + # line. + if linenum + 1 < clean_lines.NumLines(): + extended_line = line + clean_lines.elided[linenum + 1] + else: + extended_line = line + + # Check for people declaring static/global STL strings at the top level. + # This is dangerous because the C++ language does not guarantee that + # globals with constructors are initialized before the first access. + match = Match( + r'((?:|static +)(?:|const +))string +([a-zA-Z0-9_:]+)\b(.*)', + line) + # Make sure it's not a function. + # Function template specialization looks like: "string foo(...". + # Class template definitions look like: "string Foo::Method(...". + # + # Also ignore things that look like operators. These are matched separately + # because operator names cross non-word boundaries. If we change the pattern + # above, we would decrease the accuracy of matching identifiers. + if (match and + not Search(r'\boperator\W', line) and + not Match(r'\s*(<.*>)?(::[a-zA-Z0-9_]+)?\s*\(([^"]|$)', match.group(3))): + error(filename, linenum, 'runtime/string', 4, + 'For a static/global string constant, use a C style string instead: ' + '"%schar %s[]".' 
% + (match.group(1), match.group(2))) + + if Search(r'\b([A-Za-z0-9_]*_)\(\1\)', line): + error(filename, linenum, 'runtime/init', 4, + 'You seem to be initializing a member variable with itself.') + + if file_extension == 'h': + # TODO(unknown): check that 1-arg constructors are explicit. + # How to tell it's a constructor? + # (handled in CheckForNonStandardConstructs for now) + # TODO(unknown): check that classes have DISALLOW_EVIL_CONSTRUCTORS + # (level 1 error) + pass + + # Check if people are using the verboten C basic types. The only exception + # we regularly allow is "unsigned short port" for port. + if Search(r'\bshort port\b', line): + if not Search(r'\bunsigned short port\b', line): + error(filename, linenum, 'runtime/int', 4, + 'Use "unsigned short" for ports, not "short"') + else: + match = Search(r'\b(short|long(?! +double)|long long)\b', line) + if match: + error(filename, linenum, 'runtime/int', 4, + 'Use int16/int64/etc, rather than the C type %s' % match.group(1)) + + # When snprintf is used, the second argument shouldn't be a literal. + match = Search(r'snprintf\s*\(([^,]*),\s*([0-9]*)\s*,', line) + if match and match.group(2) != '0': + # If 2nd arg is zero, snprintf is used to calculate size. + error(filename, linenum, 'runtime/printf', 3, + 'If you can, use sizeof(%s) instead of %s as the 2nd arg ' + 'to snprintf.' % (match.group(1), match.group(2))) + + # Check if some verboten C functions are being used. + if Search(r'\bsprintf\b', line): + error(filename, linenum, 'runtime/printf', 5, + 'Never use sprintf. 
Use snprintf instead.') + match = Search(r'\b(strcpy|strcat)\b', line) + if match: + error(filename, linenum, 'runtime/printf', 4, + 'Almost always, snprintf is better than %s' % match.group(1)) + + # Check if some verboten operator overloading is going on + # TODO(unknown): catch out-of-line unary operator&: + # class X {}; + # int operator&(const X& x) { return 42; } // unary operator& + # The trick is it's hard to tell apart from binary operator&: + # class Y { int operator&(const Y& x) { return 23; } }; // binary operator& + if Search(r'\boperator\s*&\s*\(\s*\)', line): + error(filename, linenum, 'runtime/operator', 4, + 'Unary operator& is dangerous. Do not use it.') + + # Check for suspicious usage of "if" like + # } if (a == b) { + if Search(r'\}\s*if\s*\(', line): + error(filename, linenum, 'readability/braces', 4, + 'Did you mean "else if"? If not, start a new line for "if".') + + # Check for potential format string bugs like printf(foo). + # We constrain the pattern not to pick things like DocidForPrintf(foo). + # Not perfect but it can catch printf(foo.c_str()) and printf(foo->c_str()) + # TODO(sugawarayu): Catch the following case. Need to change the calling + # convention of the whole function to process multiple line to handle it. + # printf( + # boy_this_is_a_really_long_variable_that_cannot_fit_on_the_prev_line); + printf_args = _GetTextInside(line, r'(?i)\b(string)?printf\s*\(') + if printf_args: + match = Match(r'([\w.\->()]+)$', printf_args) + if match and match.group(1) != '__VA_ARGS__': + function_name = re.search(r'\b((?:string)?printf)\s*\(', + line, re.I).group(1) + error(filename, linenum, 'runtime/printf', 4, + 'Potential format string bug. Do %s("%%s", %s) instead.' + % (function_name, match.group(1))) + + # Check for potential memset bugs like memset(buf, sizeof(buf), 0). 
+ match = Search(r'memset\s*\(([^,]*),\s*([^,]*),\s*0\s*\)', line) + if match and not Match(r"^''|-?[0-9]+|0x[0-9A-Fa-f]$", match.group(2)): + error(filename, linenum, 'runtime/memset', 4, + 'Did you mean "memset(%s, 0, %s)"?' + % (match.group(1), match.group(2))) + + if Search(r'\busing namespace\b', line): + error(filename, linenum, 'build/namespaces', 5, + 'Do not use namespace using-directives. ' + 'Use using-declarations instead.') + + # Detect variable-length arrays. + match = Match(r'\s*(.+::)?(\w+) [a-z]\w*\[(.+)];', line) + if (match and match.group(2) != 'return' and match.group(2) != 'delete' and + match.group(3).find(']') == -1): + # Split the size using space and arithmetic operators as delimiters. + # If any of the resulting tokens are not compile time constants then + # report the error. + tokens = re.split(r'\s|\+|\-|\*|\/|<<|>>]', match.group(3)) + is_const = True + skip_next = False + for tok in tokens: + if skip_next: + skip_next = False + continue + + if Search(r'sizeof\(.+\)', tok): continue + if Search(r'arraysize\(\w+\)', tok): continue + + tok = tok.lstrip('(') + tok = tok.rstrip(')') + if not tok: continue + if Match(r'\d+', tok): continue + if Match(r'0[xX][0-9a-fA-F]+', tok): continue + if Match(r'k[A-Z0-9]\w*', tok): continue + if Match(r'(.+::)?k[A-Z0-9]\w*', tok): continue + if Match(r'(.+::)?[A-Z][A-Z0-9_]*', tok): continue + # A catch all for tricky sizeof cases, including 'sizeof expression', + # 'sizeof(*type)', 'sizeof(const type)', 'sizeof(struct StructName)' + # requires skipping the next token because we split on ' ' and '*'. + if tok.startswith('sizeof'): + skip_next = True + continue + is_const = False + break + if not is_const: + error(filename, linenum, 'runtime/arrays', 1, + 'Do not use variable-length arrays. 
Use an appropriately named ' + "('k' followed by CamelCase) compile-time constant for the size.") + + # If DISALLOW_EVIL_CONSTRUCTORS, DISALLOW_COPY_AND_ASSIGN, or + # DISALLOW_IMPLICIT_CONSTRUCTORS is present, then it should be the last thing + # in the class declaration. + match = Match( + (r'\s*' + r'(DISALLOW_(EVIL_CONSTRUCTORS|COPY_AND_ASSIGN|IMPLICIT_CONSTRUCTORS))' + r'\(.*\);$'), + line) + if match and linenum + 1 < clean_lines.NumLines(): + next_line = clean_lines.elided[linenum + 1] + # We allow some, but not all, declarations of variables to be present + # in the statement that defines the class. The [\w\*,\s]* fragment of + # the regular expression below allows users to declare instances of + # the class or pointers to instances, but not less common types such + # as function pointers or arrays. It's a tradeoff between allowing + # reasonable code and avoiding trying to parse more C++ using regexps. + if not Search(r'^\s*}[\w\*,\s]*;', next_line): + error(filename, linenum, 'readability/constructors', 3, + match.group(1) + ' should be the last thing in the class') + + # Check for use of unnamed namespaces in header files. Registration + # macros are typically OK, so we allow use of "namespace {" on lines + # that end with backslashes. + if (file_extension == 'h' + and Search(r'\bnamespace\s*{', line) + and line[-1] != '\\'): + error(filename, linenum, 'build/namespaces', 4, + 'Do not use unnamed namespaces in header files. See ' + 'http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Namespaces' + ' for more information.') + +def CheckForNonConstReference(filename, clean_lines, linenum, + nesting_state, error): + """Check for non-const references. + + Separate from CheckLanguage since it scans backwards from current + line, instead of scanning forward. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. 
+ nesting_state: A _NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: The function to call with any errors found. + """ + # Do nothing if there is no '&' on current line. + line = clean_lines.elided[linenum] + if '&' not in line: + return + + # Long type names may be broken across multiple lines, usually in one + # of these forms: + # LongType + # ::LongTypeContinued &identifier + # LongType:: + # LongTypeContinued &identifier + # LongType< + # ...>::LongTypeContinued &identifier + # + # If we detected a type split across two lines, join the previous + # line to current line so that we can match const references + # accordingly. + # + # Note that this only scans back one line, since scanning back + # arbitrary number of lines would be expensive. If you have a type + # that spans more than 2 lines, please use a typedef. + if linenum > 1: + previous = None + if Match(r'\s*::(?:[\w<>]|::)+\s*&\s*\S', line): + # previous_line\n + ::current_line + previous = Search(r'\b((?:const\s*)?(?:[\w<>]|::)+[\w<>])\s*$', + clean_lines.elided[linenum - 1]) + elif Match(r'\s*[a-zA-Z_]([\w<>]|::)+\s*&\s*\S', line): + # previous_line::\n + current_line + previous = Search(r'\b((?:const\s*)?(?:[\w<>]|::)+::)\s*$', + clean_lines.elided[linenum - 1]) + if previous: + line = previous.group(1) + line.lstrip() + else: + # Check for templated parameter that is split across multiple lines + endpos = line.rfind('>') + if endpos > -1: + (_, startline, startpos) = ReverseCloseExpression( + clean_lines, linenum, endpos) + if startpos > -1 and startline < linenum: + # Found the matching < on an earlier line, collect all + # pieces up to current line. + line = '' + for i in xrange(startline, linenum + 1): + line += clean_lines.elided[i].strip() + + # Check for non-const references in function parameters. 
A single '&' may be + # found in the following places: + # inside expression: binary & for bitwise AND + # inside expression: unary & for taking the address of something + # inside declarators: reference parameter + # We will exclude the first two cases by checking that we are not inside a + # function body, including one that was just introduced by a trailing '{'. + # TODO(unknown): Doesn't account for preprocessor directives. + # TODO(unknown): Doesn't account for 'catch(Exception& e)' [rare]. + check_params = False + if not nesting_state.stack: + check_params = True # top level + elif (isinstance(nesting_state.stack[-1], _ClassInfo) or + isinstance(nesting_state.stack[-1], _NamespaceInfo)): + check_params = True # within class or namespace + elif Match(r'.*{\s*$', line): + if (len(nesting_state.stack) == 1 or + isinstance(nesting_state.stack[-2], _ClassInfo) or + isinstance(nesting_state.stack[-2], _NamespaceInfo)): + check_params = True # just opened global/class/namespace block + # We allow non-const references in a few standard places, like functions + # called "swap()" or iostream operators like "<<" or ">>". Do not check + # those function parameters. + # + # We also accept & in static_assert, which looks like a function but + # it's actually a declaration expression. + whitelisted_functions = (r'(?:[sS]wap(?:<\w:+>)?|' + r'operator\s*[<>][<>]|' + r'static_assert|COMPILE_ASSERT' + r')\s*\(') + if Search(whitelisted_functions, line): + check_params = False + elif not Search(r'\S+\([^)]*$', line): + # Don't see a whitelisted function on this line. Actually we + # didn't see any function name on this line, so this is likely a + # multi-line parameter list. Try a bit harder to catch this case.
+ for i in xrange(2): + if (linenum > i and + Search(whitelisted_functions, clean_lines.elided[linenum - i - 1])): + check_params = False + break + + if check_params: + decls = ReplaceAll(r'{[^}]*}', ' ', line) # exclude function body + for parameter in re.findall(_RE_PATTERN_REF_PARAM, decls): + if not Match(_RE_PATTERN_CONST_REF_PARAM, parameter): + error(filename, linenum, 'runtime/references', 2, + 'Is this a non-const reference? ' + 'If so, make const or use a pointer: ' + + ReplaceAll(' *<', '<', parameter)) + + +def CheckCStyleCast(filename, linenum, line, raw_line, cast_type, pattern, + error): + """Checks for a C-style cast by looking for the pattern. + + Args: + filename: The name of the current file. + linenum: The number of the line to check. + line: The line of code to check. + raw_line: The raw line of code to check, with comments. + cast_type: The string for the C++ cast to recommend. This is either + reinterpret_cast, static_cast, or const_cast, depending. + pattern: The regular expression used to find C-style casts. + error: The function to call with any errors found. + + Returns: + True if an error was emitted. + False otherwise. + """ + match = Search(pattern, line) + if not match: + return False + + # Exclude lines with sizeof, since sizeof looks like a cast. + sizeof_match = Match(r'.*sizeof\s*$', line[0:match.start(1) - 1]) + if sizeof_match: + return False + + # operator++(int) and operator--(int) + if (line[0:match.start(1) - 1].endswith(' operator++') or + line[0:match.start(1) - 1].endswith(' operator--')): + return False + + # A single unnamed argument for a function tends to look like old + # style cast. If we see those, don't issue warnings for deprecated + # casts, instead issue warnings for unnamed arguments where + # appropriate. 
+ # + # These are things that we want warnings for, since the style guide + # explicitly requires all parameters to be named: + # Function(int); + # Function(int) { + # ConstMember(int) const; + # ConstMember(int) const { + # ExceptionMember(int) throw (...); + # ExceptionMember(int) throw (...) { + # PureVirtual(int) = 0; + # + # These are functions of some sort, where the compiler would be fine + # if they had named parameters, but people often omit those + # identifiers to reduce clutter: + # (FunctionPointer)(int); + # (FunctionPointer)(int) = value; + # Function((function_pointer_arg)(int)) + # <TemplateArgument(int)>; + # <(FunctionPointerTemplateArgument)(int)>; + remainder = line[match.end(0):] + if Match(r'^\s*(?:;|const\b|throw\b|=|>|\{|\))', remainder): + # Looks like an unnamed parameter. + + # Don't warn on any kind of template arguments. + if Match(r'^\s*>', remainder): + return False + + # Don't warn on assignments to function pointers, but keep warnings for + # unnamed parameters to pure virtual functions. Note that this pattern + # will also pass on assignments of "0" to function pointers, but the + # preferred values for those would be "nullptr" or "NULL". + matched_zero = Match(r'^\s=\s*(\S+)\s*;', remainder) + if matched_zero and matched_zero.group(1) != '0': + return False + + # Don't warn on function pointer declarations. For this we need + # to check what came before the "(type)" string. + if Match(r'.*\)\s*$', line[0:match.start(0)]): + return False + + # Don't warn if the parameter is named with block comments, e.g.: + # Function(int /*unused_param*/); + if '/*' in raw_line: + return False + + # Passed all filters, issue warning here. + error(filename, linenum, 'readability/function', 3, + 'All parameters should be named in a function') + return True + + # At this point, all that should be left is actual casts. + error(filename, linenum, 'readability/casting', 4, + 'Using C-style cast. Use %s<%s>(...)
 instead' % + (cast_type, match.group(1))) + + return True + + +_HEADERS_CONTAINING_TEMPLATES = ( + ('<deque>', ('deque',)), + ('<functional>', ('unary_function', 'binary_function', + 'plus', 'minus', 'multiplies', 'divides', 'modulus', + 'negate', + 'equal_to', 'not_equal_to', 'greater', 'less', + 'greater_equal', 'less_equal', + 'logical_and', 'logical_or', 'logical_not', + 'unary_negate', 'not1', 'binary_negate', 'not2', + 'bind1st', 'bind2nd', + 'pointer_to_unary_function', + 'pointer_to_binary_function', + 'ptr_fun', + 'mem_fun_t', 'mem_fun', 'mem_fun1_t', 'mem_fun1_ref_t', + 'mem_fun_ref_t', + 'const_mem_fun_t', 'const_mem_fun1_t', + 'const_mem_fun_ref_t', 'const_mem_fun1_ref_t', + 'mem_fun_ref', + )), + ('<limits>', ('numeric_limits',)), + ('<list>', ('list',)), + ('<map>', ('map', 'multimap',)), + ('<memory>', ('allocator',)), + ('<queue>', ('queue', 'priority_queue',)), + ('<set>', ('set', 'multiset',)), + ('<stack>', ('stack',)), + ('<string>', ('char_traits', 'basic_string',)), + ('<utility>', ('pair',)), + ('<vector>', ('vector',)), + + # gcc extensions. + # Note: std::hash is their hash, ::hash is our hash + ('<hash_map>', ('hash_map', 'hash_multimap',)), + ('<hash_set>', ('hash_set', 'hash_multiset',)), + ('<slist>', ('slist',)), + ) + +_RE_PATTERN_STRING = re.compile(r'\bstring\b') + +_re_pattern_algorithm_header = [] +for _template in ('copy', 'max', 'min', 'min_element', 'sort', 'swap', + 'transform'): + # Match max<type>(..., ...), max(..., ...), but not foo->max, foo.max or + # type::max(). + _re_pattern_algorithm_header.append( + (re.compile(r'[^>.]\b' + _template + r'(<.*?>)?\([^\)]'), + _template, + '<algorithm>')) + +_re_pattern_templates = [] +for _header, _templates in _HEADERS_CONTAINING_TEMPLATES: + for _template in _templates: + _re_pattern_templates.append( + (re.compile(r'(\<|\b)' + _template + r'\s*\<'), + _template + '<>', + _header)) + + +def FilesBelongToSameModule(filename_cc, filename_h): + """Check if these two filenames belong to the same module.
+ + The concept of a 'module' here is as follows: + foo.h, foo-inl.h, foo.cc, foo_test.cc and foo_unittest.cc belong to the + same 'module' if they are in the same directory. + some/path/public/xyzzy and some/path/internal/xyzzy are also considered + to belong to the same module here. + + If the filename_cc contains a longer path than the filename_h, for example, + '/absolute/path/to/base/sysinfo.cc', and this file would include + 'base/sysinfo.h', this function also produces the prefix needed to open the + header. This is used by the caller of this function to more robustly open the + header file. We don't have access to the real include paths in this context, + so we need this guesswork here. + + Known bugs: tools/base/bar.cc and base/bar.h belong to the same module + according to this implementation. Because of this, this function gives + some false positives. This should be sufficiently rare in practice. + + Args: + filename_cc: is the path for the .cc file + filename_h: is the path for the header file + + Returns: + Tuple with a bool and a string: + bool: True if filename_cc and filename_h belong to the same module. + string: the additional prefix needed to open the header file. 
+ """ + + if not filename_cc.endswith('.cc'): + return (False, '') + filename_cc = filename_cc[:-len('.cc')] + if filename_cc.endswith('_unittest'): + filename_cc = filename_cc[:-len('_unittest')] + elif filename_cc.endswith('_test'): + filename_cc = filename_cc[:-len('_test')] + filename_cc = filename_cc.replace('/public/', '/') + filename_cc = filename_cc.replace('/internal/', '/') + + if not filename_h.endswith('.h'): + return (False, '') + filename_h = filename_h[:-len('.h')] + if filename_h.endswith('-inl'): + filename_h = filename_h[:-len('-inl')] + filename_h = filename_h.replace('/public/', '/') + filename_h = filename_h.replace('/internal/', '/') + + files_belong_to_same_module = filename_cc.endswith(filename_h) + common_path = '' + if files_belong_to_same_module: + common_path = filename_cc[:-len(filename_h)] + return files_belong_to_same_module, common_path + + +def UpdateIncludeState(filename, include_state, io=codecs): + """Fill up the include_state with new includes found from the file. + + Args: + filename: the name of the header to read. + include_state: an _IncludeState instance in which the headers are inserted. + io: The io factory to use to read the file. Provided for testability. + + Returns: + True if a header was successfully added. False otherwise. + """ + headerfile = None + try: + headerfile = io.open(filename, 'r', 'utf8', 'replace') + except IOError: + return False + linenum = 0 + for line in headerfile: + linenum += 1 + clean_line = CleanseComments(line) + match = _RE_PATTERN_INCLUDE.search(clean_line) + if match: + include = match.group(2) + # The value formatting is cute, but not really used right now. + # What matters here is that the key is in include_state. + include_state.setdefault(include, '%s:%d' % (filename, linenum)) + return True + + +def CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error, + io=codecs): + """Reports for missing stl includes. 
+ + This function will output warnings to make sure you are including the headers + necessary for the stl containers and functions that you use. We only give one + reason to include a header. For example, if you use both equal_to<> and + less<> in a .h file, only one (the latter in the file) of these will be + reported as a reason to include the <functional>. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + include_state: An _IncludeState instance. + error: The function to call with any errors found. + io: The IO factory to use to read the header file. Provided for unittest + injection. + """ + required = {} # A map of header name to linenumber and the template entity. + # Example of required: { '<functional>': (1219, 'less<>') } + + for linenum in xrange(clean_lines.NumLines()): + line = clean_lines.elided[linenum] + if not line or line[0] == '#': + continue + + # String is special -- it is a non-templatized type in STL. + matched = _RE_PATTERN_STRING.search(line) + if matched: + # Don't warn about strings in non-STL namespaces: + # (We check only the first match per line; good enough.) + prefix = line[:matched.start()] + if prefix.endswith('std::') or not prefix.endswith('::'): + required['<string>'] = (linenum, 'string') + + for pattern, template, header in _re_pattern_algorithm_header: + if pattern.search(line): + required[header] = (linenum, template) + + # The following function is just a speed up, no semantics are changed. + if not '<' in line: # Reduces the cpu time usage by skipping lines. + continue + + for pattern, template, header in _re_pattern_templates: + if pattern.search(line): + required[header] = (linenum, template) + + # The policy is that if you #include something in foo.h you don't need to + # include it again in foo.cc. Here, we will look at possible includes. + # Let's copy the include_state so it is only messed up within this function. 
+ include_state = include_state.copy() + + # Did we find the header for this file (if any) and successfully load it? + header_found = False + + # Use the absolute path so that matching works properly. + abs_filename = FileInfo(filename).FullName() + + # For Emacs's flymake. + # If cpplint is invoked from Emacs's flymake, a temporary file is generated + # by flymake and that file name might end with '_flymake.cc'. In that case, + # restore original file name here so that the corresponding header file can be + # found. + # e.g. If the file name is 'foo_flymake.cc', we should search for 'foo.h' + # instead of 'foo_flymake.h' + abs_filename = re.sub(r'_flymake\.cc$', '.cc', abs_filename) + + # include_state is modified during iteration, so we iterate over a copy of + # the keys. + header_keys = include_state.keys() + for header in header_keys: + (same_module, common_path) = FilesBelongToSameModule(abs_filename, header) + fullpath = common_path + header + if same_module and UpdateIncludeState(fullpath, include_state, io): + header_found = True + + # If we can't find the header file for a .cc, assume it's because we don't + # know where to look. In that case we'll give up as we're not sure they + # didn't include it in the .h file. + # TODO(unknown): Do a better job of finding .h files so we are confident that + # not having the .h file means there isn't one. + if filename.endswith('.cc') and not header_found: + return + + # All the lines have been processed, report the errors found. 
+ for required_header_unstripped in required: + template = required[required_header_unstripped][1] + if required_header_unstripped.strip('<>"') not in include_state: + error(filename, required[required_header_unstripped][0], + 'build/include_what_you_use', 4, + 'Add #include ' + required_header_unstripped + ' for ' + template) + + +_RE_PATTERN_EXPLICIT_MAKEPAIR = re.compile(r'\bmake_pair\s*<') + + +def CheckMakePairUsesDeduction(filename, clean_lines, linenum, error): + """Check that make_pair's template arguments are deduced. + + G++ 4.6 in C++0x mode fails badly if make_pair's template arguments are + specified explicitly, and such use isn't intended in any case. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + match = _RE_PATTERN_EXPLICIT_MAKEPAIR.search(line) + if match: + error(filename, linenum, 'build/explicit_make_pair', + 4, # 4 = high confidence + 'For C++11-compatibility, omit template arguments from make_pair' + ' OR use pair directly OR if appropriate, construct a pair directly') + + +def ProcessLine(filename, file_extension, clean_lines, line, + include_state, function_state, nesting_state, error, + extra_check_functions=[]): + """Processes a single line in the file. + + Args: + filename: Filename of the file that is being processed. + file_extension: The extension (dot not included) of the file. + clean_lines: An array of strings, each representing a line of the file, + with comments stripped. + line: Number of line being processed. + include_state: An _IncludeState instance in which the headers are inserted. + function_state: A _FunctionState instance which counts function lines, etc. + nesting_state: A _NestingState instance which maintains information about + the current stack of nested blocks being parsed. 
+ error: A callable to which errors are reported, which takes 4 arguments: + filename, line number, error level, and message + extra_check_functions: An array of additional check functions that will be + run on each source line. Each function takes 4 + arguments: filename, clean_lines, line, error + """ + raw_lines = clean_lines.raw_lines + ParseNolintSuppressions(filename, raw_lines[line], line, error) + nesting_state.Update(filename, clean_lines, line, error) + if nesting_state.stack and nesting_state.stack[-1].inline_asm != _NO_ASM: + return + CheckForFunctionLengths(filename, clean_lines, line, function_state, error) + CheckForMultilineCommentsAndStrings(filename, clean_lines, line, error) + CheckStyle(filename, clean_lines, line, file_extension, nesting_state, error) + CheckLanguage(filename, clean_lines, line, file_extension, include_state, + nesting_state, error) + CheckForNonConstReference(filename, clean_lines, line, nesting_state, error) + CheckForNonStandardConstructs(filename, clean_lines, line, + nesting_state, error) + CheckVlogArguments(filename, clean_lines, line, error) + CheckCaffeAlternatives(filename, clean_lines, line, error) + CheckCaffeDataLayerSetUp(filename, clean_lines, line, error) + CheckCaffeRandom(filename, clean_lines, line, error) + CheckPosixThreading(filename, clean_lines, line, error) + CheckInvalidIncrement(filename, clean_lines, line, error) + CheckMakePairUsesDeduction(filename, clean_lines, line, error) + for check_fn in extra_check_functions: + check_fn(filename, clean_lines, line, error) + +def ProcessFileData(filename, file_extension, lines, error, + extra_check_functions=[]): + """Performs lint checks and reports any errors to the given error function. + + Args: + filename: Filename of the file that is being processed. + file_extension: The extension (dot not included) of the file. 
+ lines: An array of strings, each representing a line of the file, with the + last element being empty if the file is terminated with a newline. + error: A callable to which errors are reported, which takes 4 arguments: + filename, line number, error level, and message + extra_check_functions: An array of additional check functions that will be + run on each source line. Each function takes 4 + arguments: filename, clean_lines, line, error + """ + lines = (['// marker so line numbers and indices both start at 1'] + lines + + ['// marker so line numbers end in a known way']) + + include_state = _IncludeState() + function_state = _FunctionState() + nesting_state = _NestingState() + + ResetNolintSuppressions() + + CheckForCopyright(filename, lines, error) + + if file_extension == 'h': + CheckForHeaderGuard(filename, lines, error) + + RemoveMultiLineComments(filename, lines, error) + clean_lines = CleansedLines(lines) + for line in xrange(clean_lines.NumLines()): + ProcessLine(filename, file_extension, clean_lines, line, + include_state, function_state, nesting_state, error, + extra_check_functions) + nesting_state.CheckCompletedBlocks(filename, error) + + CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error) + + # We check here rather than inside ProcessLine so that we see raw + # lines rather than "cleaned" lines. + CheckForBadCharacters(filename, lines, error) + + CheckForNewlineAtEOF(filename, lines, error) + +def ProcessFile(filename, vlevel, extra_check_functions=[]): + """Does google-lint on a single file. + + Args: + filename: The name of the file to parse. + + vlevel: The level of errors to report. Every error of confidence + >= verbose_level will be reported. 0 is a good default. + + extra_check_functions: An array of additional check functions that will be + run on each source line. 
Each function takes 4 + arguments: filename, clean_lines, line, error + """ + + _SetVerboseLevel(vlevel) + + try: + # Support the UNIX convention of using "-" for stdin. Note that + # we are not opening the file with universal newline support + # (which codecs doesn't support anyway), so the resulting lines do + # contain trailing '\r' characters if we are reading a file that + # has CRLF endings. + # If after the split a trailing '\r' is present, it is removed + # below. If it is not expected to be present (i.e. os.linesep != + # '\r\n' as in Windows), a warning is issued below if this file + # is processed. + + if filename == '-': + lines = codecs.StreamReaderWriter(sys.stdin, + codecs.getreader('utf8'), + codecs.getwriter('utf8'), + 'replace').read().split('\n') + else: + lines = codecs.open(filename, 'r', 'utf8', 'replace').read().split('\n') + + carriage_return_found = False + # Remove trailing '\r'. + for linenum in range(len(lines)): + if lines[linenum].endswith('\r'): + lines[linenum] = lines[linenum].rstrip('\r') + carriage_return_found = True + + except IOError: + sys.stderr.write( + "Skipping input '%s': Can't open for reading\n" % filename) + return + + # Note, if no dot is found, this will give the entire filename as the ext. + file_extension = filename[filename.rfind('.') + 1:] + + # When reading from stdin, the extension is unknown, so no cpplint tests + # should rely on the extension. + if filename != '-' and file_extension not in _valid_extensions: + sys.stderr.write('Ignoring %s; not a valid file name ' + '(%s)\n' % (filename, ', '.join(_valid_extensions))) + else: + ProcessFileData(filename, file_extension, lines, Error, + extra_check_functions) + if carriage_return_found and os.linesep != '\r\n': + # Use 0 for linenum since outputting only one error for potentially + # several lines. 
+ Error(filename, 0, 'whitespace/newline', 1, + 'One or more unexpected \\r (^M) found; ' + 'better to use only a \\n') + + sys.stderr.write('Done processing %s\n' % filename) + + +def PrintUsage(message): + """Prints a brief usage string and exits, optionally with an error message. + + Args: + message: The optional error message. + """ + sys.stderr.write(_USAGE) + if message: + sys.exit('\nFATAL ERROR: ' + message) + else: + sys.exit(1) + + +def PrintCategories(): + """Prints a list of all the error-categories used by error messages. + + These are the categories used to filter messages via --filter. + """ + sys.stderr.write(''.join(' %s\n' % cat for cat in _ERROR_CATEGORIES)) + sys.exit(0) + + +def ParseArguments(args): + """Parses the command line arguments. + + This may set the output format and verbosity level as side-effects. + + Args: + args: The command line arguments. + + Returns: + The list of filenames to lint. + """ + try: + (opts, filenames) = getopt.getopt(args, '', ['help', 'output=', 'verbose=', + 'counting=', + 'filter=', + 'root=', + 'linelength=', + 'extensions=']) + except getopt.GetoptError: + PrintUsage('Invalid arguments.') + + verbosity = _VerboseLevel() + output_format = _OutputFormat() + filters = '' + counting_style = '' + + for (opt, val) in opts: + if opt == '--help': + PrintUsage(None) + elif opt == '--output': + if val not in ('emacs', 'vs7', 'eclipse'): + PrintUsage('The only allowed output formats are emacs, vs7 and eclipse.') + output_format = val + elif opt == '--verbose': + verbosity = int(val) + elif opt == '--filter': + filters = val + if not filters: + PrintCategories() + elif opt == '--counting': + if val not in ('total', 'toplevel', 'detailed'): + PrintUsage('Valid counting options are total, toplevel, and detailed') + counting_style = val + elif opt == '--root': + global _root + _root = val + elif opt == '--linelength': + global _line_length + try: + _line_length = int(val) + except ValueError: + PrintUsage('Line length must 
be digits.') + elif opt == '--extensions': + global _valid_extensions + try: + _valid_extensions = set(val.split(',')) + except ValueError: + PrintUsage('Extensions must be a comma-separated list.') + + if not filenames: + PrintUsage('No files were specified.') + + _SetOutputFormat(output_format) + _SetVerboseLevel(verbosity) + _SetFilters(filters) + _SetCountingStyle(counting_style) + + return filenames + + +def main(): + filenames = ParseArguments(sys.argv[1:]) + + # Change stderr to write with replacement characters so we don't die + # if we try to print something containing non-ASCII characters. + sys.stderr = codecs.StreamReaderWriter(sys.stderr, + codecs.getreader('utf8'), + codecs.getwriter('utf8'), + 'replace') + + _cpplint_state.ResetErrorCounts() + for filename in filenames: + ProcessFile(filename, _cpplint_state.verbose_level) + _cpplint_state.PrintErrorCounts() + + sys.exit(_cpplint_state.error_count > 0) + + +if __name__ == '__main__': + main() diff --git a/caffe-crfrnn/scripts/deploy_docs.sh b/caffe-crfrnn/scripts/deploy_docs.sh new file mode 100755 index 00000000..fdf97f71 --- /dev/null +++ b/caffe-crfrnn/scripts/deploy_docs.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# Publish documentation to the gh-pages site. + +# The remote for pushing the docs (defaults to origin). +# This is where you will submit the PR to BVLC:gh-pages from. +REMOTE=${1:-origin} + +echo "Generating docs and pushing to $REMOTE:gh-pages..." +echo "To build and view docs when not on master, simply do 'jekyll serve -s docs'." +echo + +REMOTE_URL=`git config --get remote.${REMOTE}.url` +BRANCH=`git rev-parse --abbrev-ref HEAD` +MSG=`git log --oneline -1` + +if [[ $BRANCH = 'master' ]]; then + # Find the docs dir, no matter where the script is called + DIR="$( cd "$(dirname "$0")" ; pwd -P )" + DOCS_SITE_DIR=$DIR/../docs/_site + + # Make sure that docs/_site tracks remote:gh-pages. + # If not, then we make a new repo and check out just that branch. 
+ mkdir -p $DOCS_SITE_DIR + cd $DOCS_SITE_DIR + SITE_REMOTE_URL=`git config --get remote.${REMOTE}.url` + SITE_BRANCH=`git rev-parse --abbrev-ref HEAD` + + echo $SITE_REMOTE_URL + echo $SITE_BRANCH + echo `pwd` + + if [[ ( $SITE_REMOTE_URL = $REMOTE_URL ) && ( $SITE_BRANCH = 'gh-pages' ) ]]; then + echo "Confirmed that docs/_site has same remote as main repo, and is on gh-pages." + else + echo "Checking out $REMOTE:gh-pages into docs/_site (will take a little time)." + git init . + git remote add -t gh-pages -f $REMOTE $REMOTE_URL + git checkout gh-pages + fi + + echo "Building the site into docs/_site, and committing the changes." + jekyll build -s .. -d . + git add --all . + git commit -m "$MSG" + git push $REMOTE gh-pages + + echo "All done!" + cd ../.. +else echo "You must run this deployment script from the 'master' branch." +fi diff --git a/caffe-crfrnn/scripts/download_model_binary.py b/caffe-crfrnn/scripts/download_model_binary.py new file mode 100755 index 00000000..48e9015f --- /dev/null +++ b/caffe-crfrnn/scripts/download_model_binary.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python +import os +import sys +import time +import yaml +import urllib +import hashlib +import argparse + +required_keys = ['caffemodel', 'caffemodel_url', 'sha1'] + + +def reporthook(count, block_size, total_size): + """ + From http://blog.moleculea.com/2012/10/04/urlretrieve-progres-indicator/ + """ + global start_time + if count == 0: + start_time = time.time() + return + duration = time.time() - start_time + progress_size = int(count * block_size) + speed = int(progress_size / (1024 * duration)) + percent = int(count * block_size * 100 / total_size) + sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % + (percent, progress_size / (1024 * 1024), speed, duration)) + sys.stdout.flush() + + +def parse_readme_frontmatter(dirname): + readme_filename = os.path.join(dirname, 'readme.md') + with open(readme_filename) as f: + lines = [line.strip() for line in f.readlines()] + top 
= lines.index('---') + bottom = lines.index('---', top + 1) + frontmatter = yaml.load('\n'.join(lines[top + 1:bottom])) + assert all(key in frontmatter for key in required_keys) + return dirname, frontmatter + + +def valid_dirname(dirname): + try: + return parse_readme_frontmatter(dirname) + except Exception as e: + print('ERROR: {}'.format(e)) + raise argparse.ArgumentTypeError( + 'Must be valid Caffe model directory with a correct readme.md') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Download trained model binary.') + parser.add_argument('dirname', type=valid_dirname) + args = parser.parse_args() + + # A tiny hack: the dirname validator also returns readme YAML frontmatter. + dirname = args.dirname[0] + frontmatter = args.dirname[1] + model_filename = os.path.join(dirname, frontmatter['caffemodel']) + + # Closure-d function for checking SHA1. + def model_checks_out(filename=model_filename, sha1=frontmatter['sha1']): + with open(filename, 'rb') as f: + return hashlib.sha1(f.read()).hexdigest() == sha1 + + # Check if model exists. + if os.path.exists(model_filename) and model_checks_out(): + print("Model already exists.") + sys.exit(0) + + # Download and verify model. + urllib.urlretrieve( + frontmatter['caffemodel_url'], model_filename, reporthook) + if not model_checks_out(): + print('ERROR: model did not download correctly! Run this again.') + sys.exit(1) diff --git a/caffe-crfrnn/scripts/download_model_from_gist.sh b/caffe-crfrnn/scripts/download_model_from_gist.sh new file mode 100755 index 00000000..a1dccf78 --- /dev/null +++ b/caffe-crfrnn/scripts/download_model_from_gist.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env sh + +GIST=$1 +DIRNAME=${2:-./models} + +if [ -z $GIST ]; then + echo "usage: download_model_from_gist.sh <gist_id> <dirname>" + exit +fi + +GIST_DIR=$(echo $GIST | tr '/' '-') +MODEL_DIR="$DIRNAME/$GIST_DIR" + +if [ -d $MODEL_DIR ]; then + echo "$MODEL_DIR already exists! Please make sure you're not overwriting anything important!" 
+ exit +fi + +echo "Downloading Caffe model info to $MODEL_DIR ..." +mkdir -p $MODEL_DIR +wget https://gist.github.com/$GIST/download -O $MODEL_DIR/gist.tar.gz +tar xzf $MODEL_DIR/gist.tar.gz --directory=$MODEL_DIR --strip-components=1 +rm $MODEL_DIR/gist.tar.gz +echo "Done" diff --git a/caffe-crfrnn/scripts/gather_examples.sh b/caffe-crfrnn/scripts/gather_examples.sh new file mode 100755 index 00000000..3fc72606 --- /dev/null +++ b/caffe-crfrnn/scripts/gather_examples.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Assemble documentation for the project into one directory via symbolic links. + +# Find the docs dir, no matter where the script is called +ROOT_DIR="$( cd "$(dirname "$0")"/.. ; pwd -P )" +cd $ROOT_DIR + +# Gather docs from examples/**/readme.md +GATHERED_DIR=docs/gathered +rm -r $GATHERED_DIR +mkdir $GATHERED_DIR +for README_FILENAME in $(find examples -iname "readme.md"); do + # Only use file if it is to be included in docs. + if grep -Fxq "include_in_docs: true" $README_FILENAME; then + # Make link to readme.md in docs/gathered/. + # Since everything is called readme.md, rename it by its dirname. + README_DIRNAME=`dirname $README_FILENAME` + DOCS_FILENAME=$GATHERED_DIR/$README_DIRNAME.md + mkdir -p `dirname $DOCS_FILENAME` + ln -s $ROOT_DIR/$README_FILENAME $DOCS_FILENAME + fi +done + +# Gather docs from examples/*.ipynb and add YAML front-matter. +for NOTEBOOK_FILENAME in $(find examples -depth -iname "*.ipynb"); do + DOCS_FILENAME=$GATHERED_DIR/$NOTEBOOK_FILENAME + mkdir -p `dirname $DOCS_FILENAME` + python scripts/copy_notebook.py $NOTEBOOK_FILENAME $DOCS_FILENAME +done diff --git a/caffe-crfrnn/scripts/travis/travis_build_and_test.sh b/caffe-crfrnn/scripts/travis/travis_build_and_test.sh new file mode 100755 index 00000000..53c6c341 --- /dev/null +++ b/caffe-crfrnn/scripts/travis/travis_build_and_test.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Script called by Travis to do a CPU-only build of and test Caffe. 
+ +set -e +MAKE="make --jobs=$NUM_THREADS --keep-going" + +if $WITH_CMAKE; then + mkdir build + cd build + cmake -DBUILD_PYTHON=ON -DBUILD_EXAMPLES=ON -DCMAKE_BUILD_TYPE=Release -DCPU_ONLY=ON .. + $MAKE + if ! $WITH_CUDA; then + $MAKE runtest + $MAKE lint + fi + $MAKE clean + cd - +else + if ! $WITH_CUDA; then + export CPU_ONLY=1 + fi + $MAKE all test pycaffe warn lint || true + if ! $WITH_CUDA; then + $MAKE runtest + fi + $MAKE all + $MAKE test + $MAKE pycaffe + $MAKE pytest + $MAKE warn + if ! $WITH_CUDA; then + $MAKE lint + fi +fi diff --git a/caffe-crfrnn/scripts/travis/travis_install.sh b/caffe-crfrnn/scripts/travis/travis_install.sh new file mode 100755 index 00000000..82f386cf --- /dev/null +++ b/caffe-crfrnn/scripts/travis/travis_install.sh @@ -0,0 +1,70 @@ +#!/bin/bash +# This script must be run with sudo. + +set -e + +MAKE="make --jobs=$NUM_THREADS" + +# Install apt packages where the Ubuntu 12.04 default and ppa works for Caffe + +# This ppa is for gflags and glog +add-apt-repository -y ppa:tuleu/precise-backports +apt-get -y update +apt-get install \ + wget git curl \ + python-dev python-numpy \ + libleveldb-dev libsnappy-dev libopencv-dev \ + libboost-dev libboost-system-dev libboost-python-dev libboost-thread-dev \ + libprotobuf-dev protobuf-compiler \ + libatlas-dev libatlas-base-dev \ + libhdf5-serial-dev libgflags-dev libgoogle-glog-dev \ + bc + +# Add a special apt-repository to install CMake 2.8.9 for CMake Caffe build, +# if needed. By default, Aptitude in Ubuntu 12.04 installs CMake 2.8.7, but +# Caffe requires a minimum CMake version of 2.8.8. 
+if $WITH_CMAKE; then + add-apt-repository -y ppa:ubuntu-sdk-team/ppa + apt-get -y update + apt-get -y install cmake +fi + +# Install CUDA, if needed +if $WITH_CUDA; then + CUDA_URL=http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1204/x86_64/cuda-repo-ubuntu1204_6.5-14_amd64.deb + CUDA_FILE=/tmp/cuda_install.deb + curl $CUDA_URL -o $CUDA_FILE + dpkg -i $CUDA_FILE + rm -f $CUDA_FILE + apt-get -y update + # Install the minimal CUDA subpackages required to test Caffe build. + # For a full CUDA installation, add 'cuda' to the list of packages. + apt-get -y install cuda-core-6-5 cuda-cublas-6-5 cuda-cublas-dev-6-5 cuda-cudart-6-5 cuda-cudart-dev-6-5 cuda-curand-6-5 cuda-curand-dev-6-5 + # Create CUDA symlink at /usr/local/cuda + # (This would normally be created by the CUDA installer, but we create it + # manually since we did a partial installation.) + ln -s /usr/local/cuda-6.5 /usr/local/cuda +fi + +# Install LMDB +LMDB_URL=ftp://ftp.openldap.org/pub/OpenLDAP/openldap-release/openldap-2.4.39.tgz +LMDB_FILE=/tmp/openldap.tgz +pushd . +curl $LMDB_URL -o $LMDB_FILE +tar -C /tmp -xzvf $LMDB_FILE +cd /tmp/openldap*/libraries/liblmdb/ +$MAKE +$MAKE install +popd +rm -f $LMDB_FILE + +# Install the Python runtime dependencies via miniconda (this is much faster +# than using pip for everything). 
+wget http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh +chmod +x miniconda.sh +./miniconda.sh -b +export PATH=/home/travis/miniconda/bin:$PATH +conda update --yes conda +conda install --yes numpy scipy matplotlib scikit-image pip +pip install protobuf +rm /home/travis/miniconda/lib/libm.* diff --git a/caffe-crfrnn/scripts/travis/travis_setup_makefile_config.sh b/caffe-crfrnn/scripts/travis/travis_setup_makefile_config.sh new file mode 100755 index 00000000..e8d85f9b --- /dev/null +++ b/caffe-crfrnn/scripts/travis/travis_setup_makefile_config.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +set -e + +mv Makefile.config.example Makefile.config + +if $WITH_CUDA; then + # Only generate compute_50. + GENCODE="-gencode arch=compute_50,code=sm_50" + GENCODE="$GENCODE -gencode arch=compute_50,code=compute_50" + echo "CUDA_ARCH := $GENCODE" >> Makefile.config +fi + +cat << 'EOF' >> Makefile.config +ANACONDA_HOME := $(HOME)/miniconda +PYTHON_INCLUDE := $(ANACONDA_HOME)/include \ + $(ANACONDA_HOME)/include/python2.7 \ + $(ANACONDA_HOME)/lib/python2.7/site-packages/numpy/core/include +PYTHON_LIB := $(ANACONDA_HOME)/lib +INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include +LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib +EOF diff --git a/caffe-crfrnn/scripts/upload_model_to_gist.sh b/caffe-crfrnn/scripts/upload_model_to_gist.sh new file mode 100755 index 00000000..3c4fd64e --- /dev/null +++ b/caffe-crfrnn/scripts/upload_model_to_gist.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# Check for valid directory +DIRNAME=$1 +if [ ! -f $DIRNAME/readme.md ]; then + echo "usage: upload_model_to_gist.sh <dirname>" + echo " <dirname>/readme.md must exist" +fi +cd $DIRNAME +FILES=`find . -maxdepth 1 -type f ! -name "*.caffemodel*" | xargs echo` + +# Check for gist tool. +gist -v >/dev/null 2>&1 || { echo >&2 "I require 'gist' but it's not installed. 
Do 'gem install gist'."; exit 1; } + +NAME=`sed -n 's/^name:[[:space:]]*//p' readme.md` +if [ -z "$NAME" ]; then + echo " <dirname>/readme.md must contain name field in the front-matter." +fi + +GIST=`sed -n 's/^gist_id:[[:space:]]*//p' readme.md` +if [ -z "$GIST" ]; then + echo "Uploading new Gist" + gist -p -d "$NAME" $FILES +else + echo "Updating existing Gist, id $GIST" + gist -u $GIST -d "$NAME" $FILES +fi + +RESULT=$? +if [ $RESULT -eq 0 ]; then + echo "You've uploaded your model!" + echo "Don't forget to add the gist_id field to your <dirname>/readme.md now!" + echo "Run the command again after you do that, to make sure the Gist id propagates." + echo "" + echo "And do share your model over at https://github.com/BVLC/caffe/wiki/Model-Zoo" +else + echo "Something went wrong!" +fi diff --git a/caffe-crfrnn/src/caffe/CMakeLists.txt b/caffe-crfrnn/src/caffe/CMakeLists.txt new file mode 100644 index 00000000..40e6c11f --- /dev/null +++ b/caffe-crfrnn/src/caffe/CMakeLists.txt @@ -0,0 +1,36 @@ +# generate protobuf sources +file(GLOB proto_files proto/*.proto) +caffe_protobuf_generate_cpp_py(${proto_gen_folder} proto_srcs proto_hdrs proto_python ${proto_files}) + +# include python files either to force generation +add_library(proto STATIC ${proto_hdrs} ${proto_srcs} ${proto_python}) +set(Caffe_LINKER_LIBS proto ${Caffe_LINKER_LIBS}) # note, crucial to prepend! 
+caffe_default_properties(proto) + +# --[ Caffe library + +# creates 'test_srcs', 'srcs', 'test_cuda', 'cuda' lists +caffe_pickup_caffe_sources(${PROJECT_SOURCE_DIR}) + +if(HAVE_CUDA) + caffe_cuda_compile(cuda_objs ${cuda}) + list(APPEND srcs ${cuda_objs} ${cuda}) +endif() + +add_library(caffe ${srcs}) +target_link_libraries(caffe proto ${Caffe_LINKER_LIBS}) +caffe_default_properties(caffe) + +# ---[ Tests + add_subdirectory(test) + +# ---[ Install +install(DIRECTORY ${Caffe_INCLUDE_DIR}/caffe DESTINATION include) +install(FILES ${proto_hdrs} DESTINATION include/caffe/proto) +install(TARGETS caffe proto EXPORT CaffeTargets DESTINATION lib) + +file(WRITE ${PROJECT_BINARY_DIR}/__init__.py) +list(APPEND proto_python ${PROJECT_BINARY_DIR}/__init__.py) +install(PROGRAMS ${proto_python} DESTINATION python/caffe/proto) + + diff --git a/caffe-crfrnn/src/caffe/blob.cpp b/caffe-crfrnn/src/caffe/blob.cpp new file mode 100644 index 00000000..cfffc379 --- /dev/null +++ b/caffe-crfrnn/src/caffe/blob.cpp @@ -0,0 +1,283 @@ +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/syncedmem.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template <typename Dtype> +void Blob<Dtype>::Reshape(const int num, const int channels, const int height, + const int width) { + CHECK_GE(num, 0); + CHECK_GE(channels, 0); + CHECK_GE(height, 0); + CHECK_GE(width, 0); + num_ = num; + channels_ = channels; + height_ = height; + width_ = width; + count_ = num_ * channels_ * height_ * width_; + if (count_ > capacity_) { + capacity_ = count_; + data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype))); + diff_.reset(new SyncedMemory(capacity_ * sizeof(Dtype))); + } +} + +template <typename Dtype> +void Blob<Dtype>::ReshapeLike(const Blob<Dtype>& other) { + Reshape(other.num(), other.channels(), other.height(), other.width()); +} + +template <typename Dtype> +Blob<Dtype>::Blob(const int num, const int channels, const int height, + const int width) + // capacity_ must be initialized before calling Reshape + : capacity_(0) { + Reshape(num, 
channels, height, width); +} + +template +const Dtype* Blob::cpu_data() const { + CHECK(data_); + return (const Dtype*)data_->cpu_data(); +} + +template +void Blob::set_cpu_data(Dtype* data) { + CHECK(data); + data_->set_cpu_data(data); +} + +template +const Dtype* Blob::gpu_data() const { + CHECK(data_); + return (const Dtype*)data_->gpu_data(); +} + +template +const Dtype* Blob::cpu_diff() const { + CHECK(diff_); + return (const Dtype*)diff_->cpu_data(); +} + +template +const Dtype* Blob::gpu_diff() const { + CHECK(diff_); + return (const Dtype*)diff_->gpu_data(); +} + +template +Dtype* Blob::mutable_cpu_data() { + CHECK(data_); + return static_cast(data_->mutable_cpu_data()); +} + +template +Dtype* Blob::mutable_gpu_data() { + CHECK(data_); + return static_cast(data_->mutable_gpu_data()); +} + +template +Dtype* Blob::mutable_cpu_diff() { + CHECK(diff_); + return static_cast(diff_->mutable_cpu_data()); +} + +template +Dtype* Blob::mutable_gpu_diff() { + CHECK(diff_); + return static_cast(diff_->mutable_gpu_data()); +} + +template +void Blob::ShareData(const Blob& other) { + CHECK_EQ(count_, other.count()); + data_ = other.data(); +} + +template +void Blob::ShareDiff(const Blob& other) { + CHECK_EQ(count_, other.count()); + diff_ = other.diff(); +} + +// The "update" method is used for parameter blobs in a Net, which are stored +// as Blob or Blob -- hence we do not define it for +// Blob or Blob. +template <> void Blob::Update() { NOT_IMPLEMENTED; } +template <> void Blob::Update() { NOT_IMPLEMENTED; } + +template +void Blob::Update() { + // We will perform update based on where the data is located. 
+ switch (data_->head()) { + case SyncedMemory::HEAD_AT_CPU: + // perform computation on CPU + caffe_axpy(count_, Dtype(-1), + static_cast(diff_->cpu_data()), + static_cast(data_->mutable_cpu_data())); + break; + case SyncedMemory::HEAD_AT_GPU: + case SyncedMemory::SYNCED: +#ifndef CPU_ONLY + // perform computation on GPU + caffe_gpu_axpy(count_, Dtype(-1), + static_cast(diff_->gpu_data()), + static_cast(data_->mutable_gpu_data())); +#else + NO_GPU; +#endif + break; + default: + LOG(FATAL) << "Syncedmem not initialized."; + } +} + +template <> unsigned int Blob::asum_data() const { + NOT_IMPLEMENTED; + return 0; +} + +template <> int Blob::asum_data() const { + NOT_IMPLEMENTED; + return 0; +} + +template +Dtype Blob::asum_data() const { + if (!data_) { return 0; } + switch (data_->head()) { + case SyncedMemory::HEAD_AT_CPU: + return caffe_cpu_asum(count_, cpu_data()); + case SyncedMemory::HEAD_AT_GPU: + case SyncedMemory::SYNCED: +#ifndef CPU_ONLY + { + Dtype asum; + caffe_gpu_asum(count_, gpu_data(), &asum); + return asum; + } +#else + NO_GPU; +#endif + case SyncedMemory::UNINITIALIZED: + return 0; + default: + LOG(FATAL) << "Unknown SyncedMemory head state: " << data_->head(); + } + return 0; +} + +template <> unsigned int Blob::asum_diff() const { + NOT_IMPLEMENTED; + return 0; +} + +template <> int Blob::asum_diff() const { + NOT_IMPLEMENTED; + return 0; +} + +template +Dtype Blob::asum_diff() const { + if (!diff_) { return 0; } + switch (diff_->head()) { + case SyncedMemory::HEAD_AT_CPU: + return caffe_cpu_asum(count_, cpu_diff()); + case SyncedMemory::HEAD_AT_GPU: + case SyncedMemory::SYNCED: +#ifndef CPU_ONLY + { + Dtype asum; + caffe_gpu_asum(count_, gpu_diff(), &asum); + return asum; + } +#else + NO_GPU; +#endif + case SyncedMemory::UNINITIALIZED: + return 0; + default: + LOG(FATAL) << "Unknown SyncedMemory head state: " << diff_->head(); + } + return 0; +} + +template +void Blob::CopyFrom(const Blob& source, bool copy_diff, bool reshape) { + if (num_ != 
source.num() || channels_ != source.channels() || + height_ != source.height() || width_ != source.width()) { + if (reshape) { + Reshape(source.num(), source.channels(), source.height(), source.width()); + } else { + LOG(FATAL) << "Trying to copy blobs of different sizes."; + } + } + switch (Caffe::mode()) { + case Caffe::GPU: + if (copy_diff) { + caffe_copy(count_, source.gpu_diff(), + static_cast(diff_->mutable_gpu_data())); + } else { + caffe_copy(count_, source.gpu_data(), + static_cast(data_->mutable_gpu_data())); + } + break; + case Caffe::CPU: + if (copy_diff) { + caffe_copy(count_, source.cpu_diff(), + static_cast(diff_->mutable_cpu_data())); + } else { + caffe_copy(count_, source.cpu_data(), + static_cast(data_->mutable_cpu_data())); + } + break; + default: + LOG(FATAL) << "Unknown caffe mode."; + } +} + +template +void Blob::FromProto(const BlobProto& proto) { + Reshape(proto.num(), proto.channels(), proto.height(), proto.width()); + // copy data + Dtype* data_vec = mutable_cpu_data(); + for (int i = 0; i < count_; ++i) { + data_vec[i] = proto.data(i); + } + if (proto.diff_size() > 0) { + Dtype* diff_vec = mutable_cpu_diff(); + for (int i = 0; i < count_; ++i) { + diff_vec[i] = proto.diff(i); + } + } +} + +template +void Blob::ToProto(BlobProto* proto, bool write_diff) const { + proto->set_num(num_); + proto->set_channels(channels_); + proto->set_height(height_); + proto->set_width(width_); + proto->clear_data(); + proto->clear_diff(); + const Dtype* data_vec = cpu_data(); + for (int i = 0; i < count_; ++i) { + proto->add_data(data_vec[i]); + } + if (write_diff) { + const Dtype* diff_vec = cpu_diff(); + for (int i = 0; i < count_; ++i) { + proto->add_diff(diff_vec[i]); + } + } +} + +INSTANTIATE_CLASS(Blob); +template class Blob; +template class Blob; + +} // namespace caffe + diff --git a/caffe-crfrnn/src/caffe/common.cpp b/caffe-crfrnn/src/caffe/common.cpp new file mode 100644 index 00000000..834d5694 --- /dev/null +++ b/caffe-crfrnn/src/caffe/common.cpp 
@@ -0,0 +1,271 @@ +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/util/rng.hpp" + +namespace caffe { + +shared_ptr Caffe::singleton_; + +// random seeding +int64_t cluster_seedgen(void) { + int64_t s, seed, pid; + FILE* f = fopen("/dev/urandom", "rb"); + if (f && fread(&seed, 1, sizeof(seed), f) == sizeof(seed)) { + fclose(f); + return seed; + } + + LOG(INFO) << "System entropy source not available, " + "using fallback algorithm to generate seed instead."; + if (f) + fclose(f); + + pid = getpid(); + s = time(NULL); + seed = abs(((s * 181) * ((pid - 83) * 359)) % 104729); + return seed; +} + + +void GlobalInit(int* pargc, char*** pargv) { + // Google flags. + ::gflags::ParseCommandLineFlags(pargc, pargv, true); + // Google logging. + ::google::InitGoogleLogging(*(pargv)[0]); + // Provide a backtrace on segfault. + ::google::InstallFailureSignalHandler(); +} + +#ifdef CPU_ONLY // CPU-only Caffe. + +Caffe::Caffe() + : random_generator_(), mode_(Caffe::CPU), phase_(Caffe::TRAIN) { } + +Caffe::~Caffe() { } + +void Caffe::set_random_seed(const unsigned int seed) { + // RNG seed + Get().random_generator_.reset(new RNG(seed)); +} + +void Caffe::SetDevice(const int device_id) { + NO_GPU; +} + +void Caffe::DeviceQuery() { + NO_GPU; +} + + +class Caffe::RNG::Generator { + public: + Generator() : rng_(new caffe::rng_t(cluster_seedgen())) {} + explicit Generator(unsigned int seed) : rng_(new caffe::rng_t(seed)) {} + caffe::rng_t* rng() { return rng_.get(); } + private: + shared_ptr rng_; +}; + +Caffe::RNG::RNG() : generator_(new Generator()) { } + +Caffe::RNG::RNG(unsigned int seed) : generator_(new Generator(seed)) { } + +Caffe::RNG& Caffe::RNG::operator=(const RNG& other) { + generator_ = other.generator_; + return *this; +} + +void* Caffe::RNG::generator() { + return static_cast(generator_->rng()); +} + +#else // Normal GPU + CPU Caffe. 
+ +Caffe::Caffe() + : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(), + mode_(Caffe::CPU), phase_(Caffe::TRAIN) { + // Try to create a cublas handler, and report an error if failed (but we will + // keep the program running as one might just want to run CPU code). + if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) { + LOG(ERROR) << "Cannot create Cublas handle. Cublas won't be available."; + } + // Try to create a curand handler. + if (curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT) + != CURAND_STATUS_SUCCESS || + curandSetPseudoRandomGeneratorSeed(curand_generator_, cluster_seedgen()) + != CURAND_STATUS_SUCCESS) { + LOG(ERROR) << "Cannot create Curand generator. Curand won't be available."; + } +} + +Caffe::~Caffe() { + if (cublas_handle_) CUBLAS_CHECK(cublasDestroy(cublas_handle_)); + if (curand_generator_) { + CURAND_CHECK(curandDestroyGenerator(curand_generator_)); + } +} + +void Caffe::set_random_seed(const unsigned int seed) { + // Curand seed + static bool g_curand_availability_logged = false; + if (Get().curand_generator_) { + CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(curand_generator(), + seed)); + CURAND_CHECK(curandSetGeneratorOffset(curand_generator(), 0)); + } else { + if (!g_curand_availability_logged) { + LOG(ERROR) << + "Curand not available. Skipping setting the curand seed."; + g_curand_availability_logged = true; + } + } + // RNG seed + Get().random_generator_.reset(new RNG(seed)); +} + +void Caffe::SetDevice(const int device_id) { + int current_device; + CUDA_CHECK(cudaGetDevice(&current_device)); + if (current_device == device_id) { + return; + } + // The call to cudaSetDevice must come before any calls to Get, which + // may perform initialization using the GPU.
+ CUDA_CHECK(cudaSetDevice(device_id)); + if (Get().cublas_handle_) CUBLAS_CHECK(cublasDestroy(Get().cublas_handle_)); + if (Get().curand_generator_) { + CURAND_CHECK(curandDestroyGenerator(Get().curand_generator_)); + } + CUBLAS_CHECK(cublasCreate(&Get().cublas_handle_)); + CURAND_CHECK(curandCreateGenerator(&Get().curand_generator_, + CURAND_RNG_PSEUDO_DEFAULT)); + CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(Get().curand_generator_, + cluster_seedgen())); +} + +void Caffe::DeviceQuery() { + cudaDeviceProp prop; + int device; + if (cudaSuccess != cudaGetDevice(&device)) { + printf("No cuda device present.\n"); + return; + } + CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); + LOG(INFO) << "Device id: " << device; + LOG(INFO) << "Major revision number: " << prop.major; + LOG(INFO) << "Minor revision number: " << prop.minor; + LOG(INFO) << "Name: " << prop.name; + LOG(INFO) << "Total global memory: " << prop.totalGlobalMem; + LOG(INFO) << "Total shared memory per block: " << prop.sharedMemPerBlock; + LOG(INFO) << "Total registers per block: " << prop.regsPerBlock; + LOG(INFO) << "Warp size: " << prop.warpSize; + LOG(INFO) << "Maximum memory pitch: " << prop.memPitch; + LOG(INFO) << "Maximum threads per block: " << prop.maxThreadsPerBlock; + LOG(INFO) << "Maximum dimension of block: " + << prop.maxThreadsDim[0] << ", " << prop.maxThreadsDim[1] << ", " + << prop.maxThreadsDim[2]; + LOG(INFO) << "Maximum dimension of grid: " + << prop.maxGridSize[0] << ", " << prop.maxGridSize[1] << ", " + << prop.maxGridSize[2]; + LOG(INFO) << "Clock rate: " << prop.clockRate; + LOG(INFO) << "Total constant memory: " << prop.totalConstMem; + LOG(INFO) << "Texture alignment: " << prop.textureAlignment; + LOG(INFO) << "Concurrent copy and execution: " + << (prop.deviceOverlap ? "Yes" : "No"); + LOG(INFO) << "Number of multiprocessors: " << prop.multiProcessorCount; + LOG(INFO) << "Kernel execution timeout: " + << (prop.kernelExecTimeoutEnabled ? 
"Yes" : "No"); + return; +} + + +class Caffe::RNG::Generator { + public: + Generator() : rng_(new caffe::rng_t(cluster_seedgen())) {} + explicit Generator(unsigned int seed) : rng_(new caffe::rng_t(seed)) {} + caffe::rng_t* rng() { return rng_.get(); } + private: + shared_ptr rng_; +}; + +Caffe::RNG::RNG() : generator_(new Generator()) { } + +Caffe::RNG::RNG(unsigned int seed) : generator_(new Generator(seed)) { } + +Caffe::RNG& Caffe::RNG::operator=(const RNG& other) { + generator_.reset(other.generator_.get()); + return *this; +} + +void* Caffe::RNG::generator() { + return static_cast(generator_->rng()); +} + +const char* cublasGetErrorString(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; +#if CUDA_VERSION >= 6000 + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; +#endif +#if CUDA_VERSION >= 6050 + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; +#endif + } + return "Unknown cublas status"; +} + +const char* curandGetErrorString(curandStatus_t error) { + switch (error) { + case CURAND_STATUS_SUCCESS: + return "CURAND_STATUS_SUCCESS"; + case CURAND_STATUS_VERSION_MISMATCH: + return "CURAND_STATUS_VERSION_MISMATCH"; + case CURAND_STATUS_NOT_INITIALIZED: + return "CURAND_STATUS_NOT_INITIALIZED"; + case CURAND_STATUS_ALLOCATION_FAILED: + return "CURAND_STATUS_ALLOCATION_FAILED"; + case CURAND_STATUS_TYPE_ERROR: + return 
"CURAND_STATUS_TYPE_ERROR"; + case CURAND_STATUS_OUT_OF_RANGE: + return "CURAND_STATUS_OUT_OF_RANGE"; + case CURAND_STATUS_LENGTH_NOT_MULTIPLE: + return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; + case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: + return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; + case CURAND_STATUS_LAUNCH_FAILURE: + return "CURAND_STATUS_LAUNCH_FAILURE"; + case CURAND_STATUS_PREEXISTING_FAILURE: + return "CURAND_STATUS_PREEXISTING_FAILURE"; + case CURAND_STATUS_INITIALIZATION_FAILED: + return "CURAND_STATUS_INITIALIZATION_FAILED"; + case CURAND_STATUS_ARCH_MISMATCH: + return "CURAND_STATUS_ARCH_MISMATCH"; + case CURAND_STATUS_INTERNAL_ERROR: + return "CURAND_STATUS_INTERNAL_ERROR"; + } + return "Unknown curand status"; +} + +#endif // CPU_ONLY + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/data_transformer.cpp b/caffe-crfrnn/src/caffe/data_transformer.cpp new file mode 100644 index 00000000..023396ce --- /dev/null +++ b/caffe-crfrnn/src/caffe/data_transformer.cpp @@ -0,0 +1,405 @@ +#ifndef OSX +#include +#endif + +#include +#include + +#include "caffe/data_transformer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/util/rng.hpp" + +namespace caffe { + +template +DataTransformer::DataTransformer(const TransformationParameter& param) + : param_(param) { + phase_ = Caffe::phase(); + // check if we want to use mean_file + if (param_.has_mean_file()) { + CHECK_EQ(param_.mean_value_size(), 0) << + "Cannot specify mean_file and mean_value at the same time"; + const string& mean_file = param.mean_file(); + LOG(INFO) << "Loading mean file from" << mean_file; + BlobProto blob_proto; + ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto); + data_mean_.FromProto(blob_proto); + } + // check if we want to use mean_value + if (param_.mean_value_size() > 0) { + CHECK(param_.has_mean_file() == false) << + "Cannot specify mean_file and mean_value at the same time"; + for (int c = 0; c < 
param_.mean_value_size(); ++c) { + mean_values_.push_back(param_.mean_value(c)); + } + } +} + +template +void DataTransformer::Transform(const Datum& datum, + Dtype* transformed_data) { + const string& data = datum.data(); + const int datum_channels = datum.channels(); + const int datum_height = datum.height(); + const int datum_width = datum.width(); + + const int crop_size = param_.crop_size(); + const Dtype scale = param_.scale(); + const bool do_mirror = param_.mirror() && Rand(2); + const bool has_mean_file = param_.has_mean_file(); + const bool has_uint8 = data.size() > 0; + const bool has_mean_values = mean_values_.size() > 0; + + CHECK_GT(datum_channels, 0); + CHECK_GE(datum_height, crop_size); + CHECK_GE(datum_width, crop_size); + + Dtype* mean = NULL; + if (has_mean_file) { + CHECK_EQ(datum_channels, data_mean_.channels()); + CHECK_EQ(datum_height, data_mean_.height()); + CHECK_EQ(datum_width, data_mean_.width()); + mean = data_mean_.mutable_cpu_data(); + } + if (has_mean_values) { + CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) << + "Specify either 1 mean_value or as many as channels: " << datum_channels; + if (datum_channels > 1 && mean_values_.size() == 1) { + // Replicate the mean_value for simplicity + for (int c = 1; c < datum_channels; ++c) { + mean_values_.push_back(mean_values_[0]); + } + } + } + + int height = datum_height; + int width = datum_width; + + int h_off = 0; + int w_off = 0; + if (crop_size) { + height = crop_size; + width = crop_size; + // We only do random crop when we do training. 
+ if (phase_ == Caffe::TRAIN) { + h_off = Rand(datum_height - crop_size + 1); + w_off = Rand(datum_width - crop_size + 1); + } else { + h_off = (datum_height - crop_size) / 2; + w_off = (datum_width - crop_size) / 2; + } + } + + Dtype datum_element; + int top_index, data_index; + for (int c = 0; c < datum_channels; ++c) { + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + data_index = (c * datum_height + h_off + h) * datum_width + w_off + w; + if (do_mirror) { + top_index = (c * height + h) * width + (width - 1 - w); + } else { + top_index = (c * height + h) * width + w; + } + if (has_uint8) { + datum_element = + static_cast(static_cast(data[data_index])); + } else { + datum_element = datum.float_data(data_index); + } + if (has_mean_file) { + transformed_data[top_index] = + (datum_element - mean[data_index]) * scale; + } else { + if (has_mean_values) { + transformed_data[top_index] = + (datum_element - mean_values_[c]) * scale; + } else { + transformed_data[top_index] = datum_element * scale; + } + } + } + } + } +} + +template +void DataTransformer::Transform(const Datum& datum, + Blob* transformed_blob) { + const int datum_channels = datum.channels(); + const int datum_height = datum.height(); + const int datum_width = datum.width(); + + const int channels = transformed_blob->channels(); + const int height = transformed_blob->height(); + const int width = transformed_blob->width(); + const int num = transformed_blob->num(); + + CHECK_EQ(channels, datum_channels); + CHECK_LE(height, datum_height); + CHECK_LE(width, datum_width); + CHECK_GE(num, 1); + + const int crop_size = param_.crop_size(); + + if (crop_size) { + CHECK_EQ(crop_size, height); + CHECK_EQ(crop_size, width); + } else { + CHECK_EQ(datum_height, height); + CHECK_EQ(datum_width, width); + } + + Dtype* transformed_data = transformed_blob->mutable_cpu_data(); + Transform(datum, transformed_data); +} + +template +void DataTransformer::Transform(const vector & datum_vector, + Blob* 
transformed_blob) { + const int datum_num = datum_vector.size(); + const int num = transformed_blob->num(); + const int channels = transformed_blob->channels(); + const int height = transformed_blob->height(); + const int width = transformed_blob->width(); + + CHECK_GT(datum_num, 0) << "There is no datum to add"; + CHECK_LE(datum_num, num) << + "The size of datum_vector must be smaller than transformed_blob->num()"; + Blob uni_blob(1, channels, height, width); + for (int item_id = 0; item_id < datum_num; ++item_id) { + int offset = transformed_blob->offset(item_id); + uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset); + Transform(datum_vector[item_id], &uni_blob); + } +} + +#ifndef OSX +template +void DataTransformer::Transform(const cv::Mat& cv_img, + Blob* transformed_blob) { + const int img_channels = cv_img.channels(); + const int img_height = cv_img.rows; + const int img_width = cv_img.cols; + + const int channels = transformed_blob->channels(); + const int height = transformed_blob->height(); + const int width = transformed_blob->width(); + const int num = transformed_blob->num(); + + CHECK_EQ(channels, img_channels); + CHECK_LE(height, img_height); + CHECK_LE(width, img_width); + CHECK_GE(num, 1); + + CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte"; + + const int crop_size = param_.crop_size(); + const Dtype scale = param_.scale(); + const bool do_mirror = param_.mirror() && Rand(2); + const bool has_mean_file = param_.has_mean_file(); + const bool has_mean_values = mean_values_.size() > 0; + + CHECK_GT(img_channels, 0); + CHECK_GE(img_height, crop_size); + CHECK_GE(img_width, crop_size); + + Dtype* mean = NULL; + if (has_mean_file) { + CHECK_EQ(img_channels, data_mean_.channels()); + CHECK_EQ(img_height, data_mean_.height()); + CHECK_EQ(img_width, data_mean_.width()); + mean = data_mean_.mutable_cpu_data(); + } + if (has_mean_values) { + CHECK(mean_values_.size() == 1 || mean_values_.size() == img_channels) << + 
"Specify either 1 mean_value or as many as channels: " << img_channels; + if (img_channels > 1 && mean_values_.size() == 1) { + // Replicate the mean_value for simplicity + for (int c = 1; c < img_channels; ++c) { + mean_values_.push_back(mean_values_[0]); + } + } + } + + int h_off = 0; + int w_off = 0; + cv::Mat cv_cropped_img = cv_img; + if (crop_size) { + CHECK_EQ(crop_size, height); + CHECK_EQ(crop_size, width); + // We only do random crop when we do training. + if (phase_ == Caffe::TRAIN) { + h_off = Rand(img_height - crop_size + 1); + w_off = Rand(img_width - crop_size + 1); + } else { + h_off = (img_height - crop_size) / 2; + w_off = (img_width - crop_size) / 2; + } + cv::Rect roi(w_off, h_off, crop_size, crop_size); + cv_cropped_img = cv_img(roi); + } else { + CHECK_EQ(img_height, height); + CHECK_EQ(img_width, width); + } + + CHECK(cv_cropped_img.data); + + Dtype* transformed_data = transformed_blob->mutable_cpu_data(); + int top_index; + for (int h = 0; h < height; ++h) { + const uchar* ptr = cv_cropped_img.ptr(h); + int img_index = 0; + for (int w = 0; w < width; ++w) { + for (int c = 0; c < img_channels; ++c) { + if (do_mirror) { + top_index = (c * height + h) * width + (width - 1 - w); + } else { + top_index = (c * height + h) * width + w; + } + // int top_index = (c * height + h) * width + w; + Dtype pixel = static_cast(ptr[img_index++]); + if (has_mean_file) { + int mean_index = (c * img_height + h_off + h) * img_width + w_off + w; + transformed_data[top_index] = + (pixel - mean[mean_index]) * scale; + } else { + if (has_mean_values) { + transformed_data[top_index] = + (pixel - mean_values_[c]) * scale; + } else { + transformed_data[top_index] = pixel * scale; + } + } + } + } + } +} +#endif + +template +void DataTransformer::Transform(Blob* input_blob, + Blob* transformed_blob) { + const int input_num = input_blob->num(); + const int input_channels = input_blob->channels(); + const int input_height = input_blob->height(); + const int input_width = 
input_blob->width(); + + const int num = transformed_blob->num(); + const int channels = transformed_blob->channels(); + const int height = transformed_blob->height(); + const int width = transformed_blob->width(); + const int size = transformed_blob->count(); + + CHECK_LE(input_num, num); + CHECK_EQ(input_channels, channels); + CHECK_GE(input_height, height); + CHECK_GE(input_width, width); + + const int crop_size = param_.crop_size(); + const Dtype scale = param_.scale(); + const bool do_mirror = param_.mirror() && Rand(2); + const bool has_mean_file = param_.has_mean_file(); + const bool has_mean_values = mean_values_.size() > 0; + + int h_off = 0; + int w_off = 0; + if (crop_size) { + CHECK_EQ(crop_size, height); + CHECK_EQ(crop_size, width); + // We only do random crop when we do training. + if (phase_ == Caffe::TRAIN) { + h_off = Rand(input_height - crop_size + 1); + w_off = Rand(input_width - crop_size + 1); + } else { + h_off = (input_height - crop_size) / 2; + w_off = (input_width - crop_size) / 2; + } + } else { + CHECK_EQ(input_height, height); + CHECK_EQ(input_width, width); + } + + Dtype* input_data = input_blob->mutable_cpu_data(); + if (has_mean_file) { + CHECK_EQ(input_channels, data_mean_.channels()); + CHECK_EQ(input_height, data_mean_.height()); + CHECK_EQ(input_width, data_mean_.width()); + for (int n = 0; n < input_num; ++n) { + int offset = input_blob->offset(n); + caffe_sub(data_mean_.count(), input_data + offset, + data_mean_.cpu_data(), input_data + offset); + } + } + + if (has_mean_values) { + CHECK(mean_values_.size() == 1 || mean_values_.size() == input_channels) << + "Specify either 1 mean_value or as many as channels: " << input_channels; + if (mean_values_.size() == 1) { + caffe_add_scalar(input_blob->count(), -(mean_values_[0]), input_data); + } else { + for (int n = 0; n < input_num; ++n) { + for (int c = 0; c < input_channels; ++c) { + int offset = input_blob->offset(n, c); + caffe_add_scalar(input_height * input_width, 
-(mean_values_[c]), + input_data + offset); + } + } + } + } + + Dtype* transformed_data = transformed_blob->mutable_cpu_data(); + + for (int n = 0; n < input_num; ++n) { + int top_index_n = n * channels; + int data_index_n = n * channels; + for (int c = 0; c < channels; ++c) { + int top_index_c = (top_index_n + c) * height; + int data_index_c = (data_index_n + c) * input_height + h_off; + for (int h = 0; h < height; ++h) { + int top_index_h = (top_index_c + h) * width; + int data_index_h = (data_index_c + h) * input_width + w_off; + if (do_mirror) { + int top_index_w = top_index_h + width - 1; + for (int w = 0; w < width; ++w) { + transformed_data[top_index_w-w] = input_data[data_index_h + w]; + } + } else { + for (int w = 0; w < width; ++w) { + transformed_data[top_index_h + w] = input_data[data_index_h + w]; + } + } + } + } + } + if (scale != Dtype(1)) { + DLOG(INFO) << "Scale: " << scale; + caffe_scal(size, scale, transformed_data); + } +} + +template +void DataTransformer::InitRand() { + const bool needs_rand = param_.mirror() || + (phase_ == Caffe::TRAIN && param_.crop_size()); + if (needs_rand) { + const unsigned int rng_seed = caffe_rng_rand(); + rng_.reset(new Caffe::RNG(rng_seed)); + } else { + rng_.reset(); + } +} + +template +int DataTransformer::Rand(int n) { + CHECK(rng_); + CHECK_GT(n, 0); + caffe::rng_t* rng = + static_cast(rng_->generator()); + return ((*rng)() % n); +} + +INSTANTIATE_CLASS(DataTransformer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/dataset_factory.cpp b/caffe-crfrnn/src/caffe/dataset_factory.cpp new file mode 100644 index 00000000..3313de3c --- /dev/null +++ b/caffe-crfrnn/src/caffe/dataset_factory.cpp @@ -0,0 +1,50 @@ +#include +#include +#include + +#include "caffe/dataset_factory.hpp" +#include "caffe/leveldb_dataset.hpp" +#include "caffe/lmdb_dataset.hpp" + +namespace caffe { + +template +shared_ptr > DatasetFactory(const DataParameter_DB& type) { + switch (type) { + case DataParameter_DB_LEVELDB: + return 
shared_ptr >(new LeveldbDataset()); + case DataParameter_DB_LMDB: + return shared_ptr >(new LmdbDataset()); + default: + LOG(FATAL) << "Unknown dataset type " << type; + return shared_ptr >(); + } +} + +template +shared_ptr > DatasetFactory(const string& type) { + if ("leveldb" == type) { + return DatasetFactory(DataParameter_DB_LEVELDB); + } else if ("lmdb" == type) { + return DatasetFactory(DataParameter_DB_LMDB); + } else { + LOG(FATAL) << "Unknown dataset type " << type; + return shared_ptr >(); + } +} + +#define REGISTER_DATASET(key_type, value_type) \ + template shared_ptr > \ + DatasetFactory(const string& type); \ + template shared_ptr > \ + DatasetFactory(const DataParameter_DB& type); \ + +REGISTER_DATASET(string, string); +REGISTER_DATASET(string, vector); +REGISTER_DATASET(string, Datum); + +#undef REGISTER_DATASET + +} // namespace caffe + + diff --git a/caffe-crfrnn/src/caffe/internal_thread.cpp b/caffe-crfrnn/src/caffe/internal_thread.cpp new file mode 100644 index 00000000..d7b6ae20 --- /dev/null +++ b/caffe-crfrnn/src/caffe/internal_thread.cpp @@ -0,0 +1,39 @@ +#include "caffe/internal_thread.hpp" + +#include "caffe/util/thread.hpp" + +namespace caffe { + +InternalThread::~InternalThread() { + WaitForInternalThreadToExit(); + if (thread_ != NULL) { + delete thread_; + } +} + +bool InternalThread::StartInternalThread() { + if (!WaitForInternalThreadToExit()) { + return false; + } + try { + thread_ = new caffe::Thread + (&InternalThread::InternalThreadEntry, this); + } catch (...) { + return false; + } + return true; +} + +/** Will not return until the internal thread has exited. */ +bool InternalThread::WaitForInternalThreadToExit() { + if (is_started()) { + try { + thread_->join(); + } catch (...) 
{ + return false; + } + } + return true; +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layer_factory.cpp b/caffe-crfrnn/src/caffe/layer_factory.cpp new file mode 100644 index 00000000..5a286cd4 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layer_factory.cpp @@ -0,0 +1,158 @@ +#include + +#include "caffe/layer.hpp" +#include "caffe/layer_factory.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +// Get convolution layer according to engine. +template <typename Dtype> +Layer<Dtype>* GetConvolutionLayer( + const LayerParameter& param) { + ConvolutionParameter_Engine engine = param.convolution_param().engine(); + if (engine == ConvolutionParameter_Engine_DEFAULT) { + engine = ConvolutionParameter_Engine_CAFFE; +#ifdef USE_CUDNN + engine = ConvolutionParameter_Engine_CUDNN; +#endif + } + if (engine == ConvolutionParameter_Engine_CAFFE) { + return new ConvolutionLayer<Dtype>(param); +#ifdef USE_CUDNN + } else if (engine == ConvolutionParameter_Engine_CUDNN) { + return new CuDNNConvolutionLayer<Dtype>(param); +#endif + } else { + LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + } +} + +REGISTER_LAYER_CREATOR(CONVOLUTION, GetConvolutionLayer); + +// Get pooling layer according to engine. +template <typename Dtype> +Layer<Dtype>* GetPoolingLayer(const LayerParameter& param) { + PoolingParameter_Engine engine = param.pooling_param().engine(); + if (engine == PoolingParameter_Engine_DEFAULT) { + engine = PoolingParameter_Engine_CAFFE; +#ifdef USE_CUDNN + engine = PoolingParameter_Engine_CUDNN; +#endif + } + if (engine == PoolingParameter_Engine_CAFFE) { + return new PoolingLayer<Dtype>(param); +#ifdef USE_CUDNN + } else if (engine == PoolingParameter_Engine_CUDNN) { + PoolingParameter p_param = param.pooling_param(); + if (p_param.pad() || p_param.pad_h() || p_param.pad_w() || + param.top_size() > 1) { + LOG(INFO) << "CUDNN does not support padding or multiple tops. 
" + << "Using Caffe's own pooling layer."; + return new PoolingLayer<Dtype>(param); + } + return new CuDNNPoolingLayer<Dtype>(param); +#endif + } else { + LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + } +} + +REGISTER_LAYER_CREATOR(POOLING, GetPoolingLayer); + +// Get relu layer according to engine. +template <typename Dtype> +Layer<Dtype>* GetReLULayer(const LayerParameter& param) { + ReLUParameter_Engine engine = param.relu_param().engine(); + if (engine == ReLUParameter_Engine_DEFAULT) { + engine = ReLUParameter_Engine_CAFFE; +#ifdef USE_CUDNN + engine = ReLUParameter_Engine_CUDNN; +#endif + } + if (engine == ReLUParameter_Engine_CAFFE) { + return new ReLULayer<Dtype>(param); +#ifdef USE_CUDNN + } else if (engine == ReLUParameter_Engine_CUDNN) { + return new CuDNNReLULayer<Dtype>(param); +#endif + } else { + LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + } +} + +REGISTER_LAYER_CREATOR(RELU, GetReLULayer); + +// Get sigmoid layer according to engine. +template <typename Dtype> +Layer<Dtype>* GetSigmoidLayer(const LayerParameter& param) { + SigmoidParameter_Engine engine = param.sigmoid_param().engine(); + if (engine == SigmoidParameter_Engine_DEFAULT) { + engine = SigmoidParameter_Engine_CAFFE; +#ifdef USE_CUDNN + engine = SigmoidParameter_Engine_CUDNN; +#endif + } + if (engine == SigmoidParameter_Engine_CAFFE) { + return new SigmoidLayer<Dtype>(param); +#ifdef USE_CUDNN + } else if (engine == SigmoidParameter_Engine_CUDNN) { + return new CuDNNSigmoidLayer<Dtype>(param); +#endif + } else { + LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + } +} + +REGISTER_LAYER_CREATOR(SIGMOID, GetSigmoidLayer); + +// Get softmax layer according to engine. 
+template <typename Dtype> +Layer<Dtype>* GetSoftmaxLayer(const LayerParameter& param) { + SoftmaxParameter_Engine engine = param.softmax_param().engine(); + if (engine == SoftmaxParameter_Engine_DEFAULT) { + engine = SoftmaxParameter_Engine_CAFFE; +#ifdef USE_CUDNN + engine = SoftmaxParameter_Engine_CUDNN; +#endif + } + if (engine == SoftmaxParameter_Engine_CAFFE) { + return new SoftmaxLayer<Dtype>(param); +#ifdef USE_CUDNN + } else if (engine == SoftmaxParameter_Engine_CUDNN) { + return new CuDNNSoftmaxLayer<Dtype>(param); +#endif + } else { + LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + } +} + +REGISTER_LAYER_CREATOR(SOFTMAX, GetSoftmaxLayer); + +// Get tanh layer according to engine. +template <typename Dtype> +Layer<Dtype>* GetTanHLayer(const LayerParameter& param) { + TanHParameter_Engine engine = param.tanh_param().engine(); + if (engine == TanHParameter_Engine_DEFAULT) { + engine = TanHParameter_Engine_CAFFE; +#ifdef USE_CUDNN + engine = TanHParameter_Engine_CUDNN; +#endif + } + if (engine == TanHParameter_Engine_CAFFE) { + return new TanHLayer<Dtype>(param); +#ifdef USE_CUDNN + } else if (engine == TanHParameter_Engine_CUDNN) { + return new CuDNNTanHLayer<Dtype>(param); +#endif + } else { + LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + } +} + +REGISTER_LAYER_CREATOR(TANH, GetTanHLayer); + +// Layers that use their constructor as their default creator should be +// registered in their corresponding cpp files. Do not register them here. 
+} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/absval_layer.cpp b/caffe-crfrnn/src/caffe/layers/absval_layer.cpp new file mode 100644 index 00000000..0d054ee5 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/absval_layer.cpp @@ -0,0 +1,44 @@ +#include + +#include "caffe/layer.hpp" +#include "caffe/neuron_layers.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void AbsValLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + NeuronLayer::LayerSetUp(bottom, top); + CHECK_NE(top[0], bottom[0]) << this->type_name() << " Layer does not " + "allow in-place computation."; +} + +template +void AbsValLayer::Forward_cpu( + const vector*>& bottom, const vector*>& top) { + const int count = top[0]->count(); + Dtype* top_data = top[0]->mutable_cpu_data(); + caffe_abs(count, bottom[0]->cpu_data(), top_data); +} + +template +void AbsValLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const int count = top[0]->count(); + const Dtype* top_diff = top[0]->cpu_diff(); + if (propagate_down[0]) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + caffe_cpu_sign(count, bottom_data, bottom_diff); + caffe_mul(count, bottom_diff, top_diff, bottom_diff); + } +} + +#ifdef CPU_ONLY +STUB_GPU(AbsValLayer); +#endif + +INSTANTIATE_CLASS(AbsValLayer); +REGISTER_LAYER_CLASS(ABSVAL, AbsValLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/absval_layer.cu b/caffe-crfrnn/src/caffe/layers/absval_layer.cu new file mode 100644 index 00000000..91f3c77f --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/absval_layer.cu @@ -0,0 +1,34 @@ +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void AbsValLayer::Forward_gpu( + const vector*>& bottom, const vector*>& top) { + const int count = top[0]->count(); + Dtype* 
top_data = top[0]->mutable_gpu_data(); + caffe_gpu_abs(count, bottom[0]->gpu_data(), top_data); +} + +template +void AbsValLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const int count = top[0]->count(); + const Dtype* top_data = top[0]->gpu_data(); + const Dtype* top_diff = top[0]->gpu_diff(); + if (propagate_down[0]) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + caffe_gpu_sign(count, bottom_data, bottom_diff); + caffe_gpu_mul(count, bottom_diff, top_diff, bottom_diff); + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(AbsValLayer); + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/accuracy_layer.cpp b/caffe-crfrnn/src/caffe/layers/accuracy_layer.cpp new file mode 100644 index 00000000..c61e38f0 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/accuracy_layer.cpp @@ -0,0 +1,83 @@ +#include +#include +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void AccuracyLayer::LayerSetUp( + const vector*>& bottom, const vector*>& top) { + top_k_ = this->layer_param_.accuracy_param().top_k(); +} + +template +void AccuracyLayer::Reshape( + const vector*>& bottom, const vector*>& top) { + CHECK_EQ(bottom[0]->num(), bottom[1]->num()) + << "The data and label should have the same number."; + CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num()) + << "top_k must be less than or equal to the number of classes."; + CHECK_EQ(bottom[1]->channels(), 1); + //CHECK_EQ(bottom[1]->height(), 1); + //CHECK_EQ(bottom[1]->width(), 1); + CHECK_EQ(bottom[1]->height(), bottom[0]->height()); + CHECK_EQ(bottom[1]->width(), bottom[0]->width()); + top[0]->Reshape(1, 1, 1, 1); +} + +template +void AccuracyLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + Dtype accuracy = 0; + const Dtype* bottom_data = 
bottom[0]->cpu_data(); + const Dtype* bottom_label = bottom[1]->cpu_data(); + int num = bottom[0]->num(); + int dim = bottom[0]->channels(); + int height = bottom[0]->height(); + int width = bottom[0]->width(); + int gt_label; + int nKnownPixels=0; + vector maxval(top_k_+1); + vector max_id(top_k_+1); + + for (int i = 0; i < num; ++i) { + // Top-k accuracy + for (int h = 0; h < height; ++h){ + for (int w = 0; w < width; ++w){ + gt_label=static_cast(bottom_label[ (i * height + h) * width + w ]); + if (gt_label==255) + continue; + + ++nKnownPixels; + std::vector > bottom_data_vector; + for (int j = 0; j < dim; ++j) { + bottom_data_vector.push_back( + std::make_pair(bottom_data[((i * dim + j) * height + h)*width + w], j)); + } + std::partial_sort( + bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_, + bottom_data_vector.end(), std::greater >()); + // check if true label is in top k predictions + for (int k = 0; k < top_k_; k++) { + if (bottom_data_vector[k].second == gt_label) { + ++accuracy; + break; + } + } + } + } + } + // LOG(INFO) << "Accuracy: " << accuracy; + top[0]->mutable_cpu_data()[0] = accuracy / Dtype(nKnownPixels); + // Accuracy layer should not be used as a loss function. 
+} + +INSTANTIATE_CLASS(AccuracyLayer); +REGISTER_LAYER_CLASS(ACCURACY, AccuracyLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/argmax_layer.cpp b/caffe-crfrnn/src/caffe/layers/argmax_layer.cpp new file mode 100644 index 00000000..15e199eb --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/argmax_layer.cpp @@ -0,0 +1,63 @@ +#include +#include +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void ArgMaxLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + out_max_val_ = this->layer_param_.argmax_param().out_max_val(); + top_k_ = this->layer_param_.argmax_param().top_k(); + CHECK_GE(top_k_, 1) << " top k must not be less than 1."; + CHECK_LE(top_k_, bottom[0]->count() / bottom[0]->num()) + << "top_k must be less than or equal to the number of classes."; +} + +template +void ArgMaxLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + if (out_max_val_) { + // Produces max_ind and max_val + top[0]->Reshape(bottom[0]->num(), 2, top_k_, 1); + } else { + // Produces only max_ind + top[0]->Reshape(bottom[0]->num(), 1, top_k_, 1); + } +} + +template +void ArgMaxLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + int num = bottom[0]->num(); + int dim = bottom[0]->count() / bottom[0]->num(); + for (int i = 0; i < num; ++i) { + std::vector > bottom_data_vector; + for (int j = 0; j < dim; ++j) { + bottom_data_vector.push_back( + std::make_pair(bottom_data[i * dim + j], j)); + } + std::partial_sort( + bottom_data_vector.begin(), bottom_data_vector.begin() + top_k_, + bottom_data_vector.end(), std::greater >()); + for (int j = 0; j < top_k_; ++j) { + top_data[top[0]->offset(i, 0, j)] = bottom_data_vector[j].second; + } + if (out_max_val_) { + for (int j = 0; j < top_k_; ++j) { + top_data[top[0]->offset(i, 1, j)] = bottom_data_vector[j].first; + 
} + } + } +} + +INSTANTIATE_CLASS(ArgMaxLayer); +REGISTER_LAYER_CLASS(ARGMAX, ArgMaxLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/base_conv_layer.cpp b/caffe-crfrnn/src/caffe/layers/base_conv_layer.cpp new file mode 100755 index 00000000..c53fc4a3 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/base_conv_layer.cpp @@ -0,0 +1,296 @@ +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { +//memory reduced change +template +Blob BaseConvolutionLayer::col_buffer_; +//end memory reduced change +template +void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + // Configure the kernel size, padding, stride, and inputs. + ConvolutionParameter conv_param = this->layer_param_.convolution_param(); + CHECK(!conv_param.has_kernel_size() != + !(conv_param.has_kernel_h() && conv_param.has_kernel_w())) + << "Filter size is kernel_size OR kernel_h and kernel_w; not both"; + CHECK(conv_param.has_kernel_size() || + (conv_param.has_kernel_h() && conv_param.has_kernel_w())) + << "For non-square filters both kernel_h and kernel_w are required."; + CHECK((!conv_param.has_pad() && conv_param.has_pad_h() + && conv_param.has_pad_w()) + || (!conv_param.has_pad_h() && !conv_param.has_pad_w())) + << "pad is pad OR pad_h and pad_w are required."; + CHECK((!conv_param.has_stride() && conv_param.has_stride_h() + && conv_param.has_stride_w()) + || (!conv_param.has_stride_h() && !conv_param.has_stride_w())) + << "Stride is stride OR stride_h and stride_w are required."; + if (conv_param.has_kernel_size()) { + kernel_h_ = kernel_w_ = conv_param.kernel_size(); + } else { + kernel_h_ = conv_param.kernel_h(); + kernel_w_ = conv_param.kernel_w(); + } + CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; + CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; + if 
(!conv_param.has_pad_h()) { + pad_h_ = pad_w_ = conv_param.pad(); + } else { + pad_h_ = conv_param.pad_h(); + pad_w_ = conv_param.pad_w(); + } + if (!conv_param.has_stride_h()) { + stride_h_ = stride_w_ = conv_param.stride(); + } else { + stride_h_ = conv_param.stride_h(); + stride_w_ = conv_param.stride_w(); + } + // Special case: im2col is the identity for 1x1 convolution with stride 1 + // and no padding, so flag for skipping the buffer and transformation. + is_1x1_ = kernel_w_ == 1 && kernel_h_ == 1 + && stride_h_ == 1 && stride_w_ == 1 && pad_h_ == 0 && pad_w_ == 0; + // Configure output channels and groups. + channels_ = bottom[0]->channels(); + num_output_ = this->layer_param_.convolution_param().num_output(); + CHECK_GT(num_output_, 0); + group_ = this->layer_param_.convolution_param().group(); + CHECK_EQ(channels_ % group_, 0); + CHECK_EQ(num_output_ % group_, 0) + << "Number of output should be multiples of group."; + if (reverse_dimensions()) { + conv_out_channels_ = channels_; + conv_in_channels_ = num_output_; + } else { + conv_out_channels_ = num_output_; + conv_in_channels_ = channels_; + } + // Handle the parameters: weights and biases. 
+ // - blobs_[0] holds the filter weights + // - blobs_[1] holds the biases (optional) + bias_term_ = this->layer_param_.convolution_param().bias_term(); + if (this->blobs_.size() > 0) { + LOG(INFO) << "Skipping parameter initialization"; + } else { + if (bias_term_) { + this->blobs_.resize(2); + } else { + this->blobs_.resize(1); + } + // Initialize and fill the weights: + // output channels x input channels per-group x kernel height x kernel width + this->blobs_[0].reset(new Blob( + conv_out_channels_, conv_in_channels_ / group_, kernel_h_, kernel_w_)); + shared_ptr > weight_filler(GetFiller( + this->layer_param_.convolution_param().weight_filler())); + weight_filler->Fill(this->blobs_[0].get()); + // If necessary, initialize and fill the biases: + // 1 x 1 x 1 x output channels + if (bias_term_) { + this->blobs_[1].reset(new Blob(1, 1, 1, num_output_)); + shared_ptr > bias_filler(GetFiller( + this->layer_param_.convolution_param().bias_filler())); + bias_filler->Fill(this->blobs_[1].get()); + } + } + // Propagate gradients to the parameters (as directed by backward pass). + this->param_propagate_down_.resize(this->blobs_.size(), true); +} + +template +void BaseConvolutionLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + num_ = bottom[0]->num(); + height_ = bottom[0]->height(); + width_ = bottom[0]->width(); + CHECK_EQ(bottom[0]->channels(), channels_) << "Input size incompatible with" + " convolution kernel."; + // TODO: generalize to handle inputs of different shapes. + for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) { + CHECK_EQ(num_, bottom[bottom_id]->num()) << "Inputs must have same num."; + CHECK_EQ(channels_, bottom[bottom_id]->channels()) + << "Inputs must have same channels."; + CHECK_EQ(height_, bottom[bottom_id]->height()) + << "Inputs must have same height."; + CHECK_EQ(width_, bottom[bottom_id]->width()) + << "Inputs must have same width."; + } + // Shape the tops. 
+ compute_output_shape(); + for (int top_id = 0; top_id < top.size(); ++top_id) { + top[top_id]->Reshape(num_, num_output_, height_out_, width_out_); + } + if (reverse_dimensions()) { + conv_in_height_ = height_out_; + conv_in_width_ = width_out_; + conv_out_spatial_dim_ = height_ * width_; + } else { + conv_in_height_ = height_; + conv_in_width_ = width_; + conv_out_spatial_dim_ = height_out_ * width_out_; + } + kernel_dim_ = conv_in_channels_ * kernel_h_ * kernel_w_; + weight_offset_ = conv_out_channels_ * kernel_dim_ / group_ / group_; + col_offset_ = kernel_dim_ * conv_out_spatial_dim_ / group_; + output_offset_ = conv_out_channels_ * conv_out_spatial_dim_ / group_; + // The im2col result buffer will only hold one image at a time to avoid + // overly large memory usage. In the special case of 1x1 convolution + // it goes lazily unused to save memory. + if (reverse_dimensions()) { + col_buffer_.Reshape(1, kernel_dim_, height_, width_); + } else { + col_buffer_.Reshape(1, kernel_dim_, height_out_, width_out_); + } + // Set up the all ones "bias multiplier" for adding biases by BLAS + if (bias_term_) { + bias_multiplier_.Reshape(1, 1, 1, height_out_ * width_out_); + caffe_set(bias_multiplier_.count(), Dtype(1), + bias_multiplier_.mutable_cpu_data()); + } +} + +template +void BaseConvolutionLayer::forward_cpu_gemm(const Dtype* input, + const Dtype* weights, Dtype* output, bool skip_im2col) { + const Dtype* col_buff = input; + if (!is_1x1_) { + if (!skip_im2col) { + conv_im2col_cpu(input, col_buffer_.mutable_cpu_data()); + } + col_buff = col_buffer_.cpu_data(); + } + for (int g = 0; g < group_; ++g) { + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, conv_out_channels_ / + group_, conv_out_spatial_dim_, kernel_dim_ / group_, + (Dtype)1., weights + weight_offset_ * g, col_buff + col_offset_ * g, + (Dtype)0., output + output_offset_ * g); + } +} + +template +void BaseConvolutionLayer::forward_cpu_bias(Dtype* output, + const Dtype* bias) { + caffe_cpu_gemm(CblasNoTrans, 
CblasNoTrans, num_output_, + height_out_ * width_out_, 1, (Dtype)1., bias, bias_multiplier_.cpu_data(), + (Dtype)1., output); +} + +template +void BaseConvolutionLayer::backward_cpu_gemm(const Dtype* output, + const Dtype* weights, Dtype* input) { + Dtype* col_buff = col_buffer_.mutable_cpu_data(); + if (is_1x1_) { + col_buff = input; + } + for (int g = 0; g < group_; ++g) { + caffe_cpu_gemm(CblasTrans, CblasNoTrans, kernel_dim_ / group_, + conv_out_spatial_dim_, conv_out_channels_ / group_, + (Dtype)1., weights + weight_offset_ * g, output + output_offset_ * g, + (Dtype)0., col_buff + col_offset_ * g); + } + if (!is_1x1_) { + conv_col2im_cpu(col_buff, input); + } +} + +template +void BaseConvolutionLayer::weight_cpu_gemm(const Dtype* input, + const Dtype* output, Dtype* weights) { + const Dtype* col_buff = input; + if (!is_1x1_) { + conv_im2col_cpu(input, col_buffer_.mutable_cpu_data()); + col_buff = col_buffer_.cpu_data(); + } + for (int g = 0; g < group_; ++g) { + caffe_cpu_gemm(CblasNoTrans, CblasTrans, conv_out_channels_ / group_, + kernel_dim_ / group_, conv_out_spatial_dim_, + (Dtype)1., output + output_offset_ * g, col_buff + col_offset_ * g, + (Dtype)1., weights + weight_offset_ * g); + } +} + +template +void BaseConvolutionLayer::backward_cpu_bias(Dtype* bias, + const Dtype* input) { + caffe_cpu_gemv(CblasNoTrans, num_output_, height_out_ * width_out_, 1., + input, bias_multiplier_.cpu_data(), 1., bias); +} + +#ifndef CPU_ONLY + +template +void BaseConvolutionLayer::forward_gpu_gemm(const Dtype* input, + const Dtype* weights, Dtype* output, bool skip_im2col) { + const Dtype* col_buff = input; + if (!is_1x1_) { + if (!skip_im2col) { + conv_im2col_gpu(input, col_buffer_.mutable_gpu_data()); + } + col_buff = col_buffer_.gpu_data(); + } + for (int g = 0; g < group_; ++g) { + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, conv_out_channels_ / + group_, conv_out_spatial_dim_, kernel_dim_ / group_, + (Dtype)1., weights + weight_offset_ * g, col_buff + col_offset_ * 
g, + (Dtype)0., output + output_offset_ * g); + } +} + +template +void BaseConvolutionLayer::forward_gpu_bias(Dtype* output, + const Dtype* bias) { + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num_output_, + height_out_ * width_out_, 1, (Dtype)1., bias, bias_multiplier_.gpu_data(), + (Dtype)1., output); +} + +template +void BaseConvolutionLayer::backward_gpu_gemm(const Dtype* output, + const Dtype* weights, Dtype* input) { + Dtype* col_buff = col_buffer_.mutable_gpu_data(); + if (is_1x1_) { + col_buff = input; + } + for (int g = 0; g < group_; ++g) { + caffe_gpu_gemm(CblasTrans, CblasNoTrans, kernel_dim_ / group_, + conv_out_spatial_dim_, conv_out_channels_ / group_, + (Dtype)1., weights + weight_offset_ * g, output + output_offset_ * g, + (Dtype)0., col_buff + col_offset_ * g); + } + if (!is_1x1_) { + conv_col2im_gpu(col_buff, input); + } +} + +template +void BaseConvolutionLayer::weight_gpu_gemm(const Dtype* input, + const Dtype* output, Dtype* weights) { + const Dtype* col_buff = input; + if (!is_1x1_) { + conv_im2col_gpu(input, col_buffer_.mutable_gpu_data()); + col_buff = col_buffer_.gpu_data(); + } + for (int g = 0; g < group_; ++g) { + caffe_gpu_gemm(CblasNoTrans, CblasTrans, conv_out_channels_ / group_, + kernel_dim_ / group_, conv_out_spatial_dim_, + (Dtype)1., output + output_offset_ * g, col_buff + col_offset_ * g, + (Dtype)1., weights + weight_offset_ * g); + } +} + +template +void BaseConvolutionLayer::backward_gpu_bias(Dtype* bias, + const Dtype* input) { + caffe_gpu_gemv(CblasNoTrans, num_output_, height_out_ * width_out_, 1., + input, bias_multiplier_.gpu_data(), 1., bias); +} + +#endif // !CPU_ONLY + +INSTANTIATE_CLASS(BaseConvolutionLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/base_data_layer.cpp b/caffe-crfrnn/src/caffe/layers/base_data_layer.cpp new file mode 100644 index 00000000..eb0aaf82 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/base_data_layer.cpp @@ -0,0 +1,84 @@ +#include +#include + +#include 
"caffe/data_layers.hpp" +#include "caffe/util/io.hpp" + +namespace caffe { + +template +BaseDataLayer::BaseDataLayer(const LayerParameter& param) + : Layer(param), + transform_param_(param.transform_param()), + data_transformer_(transform_param_) { +} + +template +void BaseDataLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + if (top.size() == 1) { + output_labels_ = false; + } else { + output_labels_ = true; + } + // The subclasses should setup the size of bottom and top + DataLayerSetUp(bottom, top); + data_transformer_.InitRand(); +} + +template +void BasePrefetchingDataLayer::LayerSetUp( + const vector*>& bottom, const vector*>& top) { + BaseDataLayer::LayerSetUp(bottom, top); + // Now, start the prefetch thread. Before calling prefetch, we make two + // cpu_data calls so that the prefetch thread does not accidentally make + // simultaneous cudaMalloc calls when the main thread is running. In some + // GPUs this seems to cause failures if we do not do so. + this->prefetch_data_.mutable_cpu_data(); + if (this->output_labels_) { + this->prefetch_label_.mutable_cpu_data(); + } + DLOG(INFO) << "Initializing prefetch"; + this->CreatePrefetchThread(); + DLOG(INFO) << "Prefetch initialized."; +} + +template +void BasePrefetchingDataLayer::CreatePrefetchThread() { + this->phase_ = Caffe::phase(); + this->data_transformer_.InitRand(); + CHECK(StartInternalThread()) << "Thread execution failed"; +} + +template +void BasePrefetchingDataLayer::JoinPrefetchThread() { + CHECK(WaitForInternalThreadToExit()) << "Thread joining failed"; +} + +template +void BasePrefetchingDataLayer::Forward_cpu( + const vector*>& bottom, const vector*>& top) { + // First, join the thread + JoinPrefetchThread(); + DLOG(INFO) << "Thread joined"; + // Copy the data + caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(), + top[0]->mutable_cpu_data()); + DLOG(INFO) << "Prefetch copied"; + if (this->output_labels_) { + caffe_copy(prefetch_label_.count(), 
prefetch_label_.cpu_data(), + top[1]->mutable_cpu_data()); + } + // Start a new prefetch thread + DLOG(INFO) << "CreatePrefetchThread"; + CreatePrefetchThread(); +} + +#ifdef CPU_ONLY +STUB_GPU_FORWARD(BasePrefetchingDataLayer, Forward); +#endif + +INSTANTIATE_CLASS(BaseDataLayer); +INSTANTIATE_CLASS(BasePrefetchingDataLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/base_data_layer.cu b/caffe-crfrnn/src/caffe/layers/base_data_layer.cu new file mode 100644 index 00000000..204a16d2 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/base_data_layer.cu @@ -0,0 +1,25 @@ +#include + +#include "caffe/data_layers.hpp" + +namespace caffe { + +template +void BasePrefetchingDataLayer::Forward_gpu( + const vector*>& bottom, const vector*>& top) { + // First, join the thread + JoinPrefetchThread(); + // Copy the data + caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(), + top[0]->mutable_gpu_data()); + if (this->output_labels_) { + caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(), + top[1]->mutable_gpu_data()); + } + // Start a new prefetch thread + CreatePrefetchThread(); +} + +INSTANTIATE_LAYER_GPU_FORWARD(BasePrefetchingDataLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/bnll_layer.cpp b/caffe-crfrnn/src/caffe/layers/bnll_layer.cpp new file mode 100644 index 00000000..cb3583ae --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/bnll_layer.cpp @@ -0,0 +1,47 @@ +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +const float kBNLL_THRESHOLD = 50.; + +template +void BNLLLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + const int count = bottom[0]->count(); + for (int i = 0; i < count; ++i) { + top_data[i] = bottom_data[i] > 0 ? + bottom_data[i] + log(1. + exp(-bottom_data[i])) : + log(1. 
+ exp(bottom_data[i])); + } +} + +template +void BNLLLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + if (propagate_down[0]) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + const Dtype* top_diff = top[0]->cpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + const int count = bottom[0]->count(); + Dtype expval; + for (int i = 0; i < count; ++i) { + expval = exp(std::min(bottom_data[i], Dtype(kBNLL_THRESHOLD))); + bottom_diff[i] = top_diff[i] * expval / (expval + 1.); + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(BNLLLayer); +#endif + +INSTANTIATE_CLASS(BNLLLayer); +REGISTER_LAYER_CLASS(BNLL, BNLLLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/bnll_layer.cu b/caffe-crfrnn/src/caffe/layers/bnll_layer.cu new file mode 100644 index 00000000..d963d068 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/bnll_layer.cu @@ -0,0 +1,60 @@ +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +const float kBNLL_THRESHOLD = 50.; + +template +__global__ void BNLLForward(const int n, const Dtype* in, Dtype* out) { + CUDA_KERNEL_LOOP(index, n) { + out[index] = in[index] > 0 ? + in[index] + log(1. + exp(-in[index])) : + log(1. 
+ exp(in[index])); + } +} + +template +void BNLLLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + const int count = bottom[0]->count(); + // NOLINT_NEXT_LINE(whitespace/operators) + BNLLForward<<>>( + count, bottom_data, top_data); + CUDA_POST_KERNEL_CHECK; +} + +template +__global__ void BNLLBackward(const int n, const Dtype* in_diff, + const Dtype* in_data, Dtype* out_diff) { + CUDA_KERNEL_LOOP(index, n) { + Dtype expval = exp(min(in_data[index], Dtype(kBNLL_THRESHOLD))); + out_diff[index] = in_diff[index] * expval / (expval + 1.); + } +} + +template +void BNLLLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + if (propagate_down[0]) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + const Dtype* top_diff = top[0]->gpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + const int count = bottom[0]->count(); + // NOLINT_NEXT_LINE(whitespace/operators) + BNLLBackward<<>>( + count, top_diff, bottom_data, bottom_diff); + CUDA_POST_KERNEL_CHECK; + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(BNLLLayer); + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/concat_layer.cpp b/caffe-crfrnn/src/caffe/layers/concat_layer.cpp new file mode 100644 index 00000000..42082195 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/concat_layer.cpp @@ -0,0 +1,109 @@ +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void ConcatLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + concat_dim_ = this->layer_param_.concat_param().concat_dim(); + CHECK_GE(concat_dim_, 0) << + "concat_dim should be >= 0"; + CHECK_LE(concat_dim_, 1) << + "For now concat_dim <=1, it can only concat num and channels"; +} + +template +void ConcatLayer::Reshape(const vector*>& bottom, + const vector*>& top) 
{ + // Initialize with the first blob. + count_ = bottom[0]->count(); + num_ = bottom[0]->num(); + channels_ = bottom[0]->channels(); + height_ = bottom[0]->height(); + width_ = bottom[0]->width(); + for (int i = 1; i < bottom.size(); ++i) { + count_ += bottom[i]->count(); + if (concat_dim_== 0) { + num_ += bottom[i]->num(); + } else if (concat_dim_ == 1) { + channels_ += bottom[i]->channels(); + } else if (concat_dim_ == 2) { + height_ += bottom[i]->height(); + } else if (concat_dim_ == 3) { + width_ += bottom[i]->width(); + } + } + top[0]->Reshape(num_, channels_, height_, width_); + CHECK_EQ(count_, top[0]->count()); +} + +template +void ConcatLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + Dtype* top_data = top[0]->mutable_cpu_data(); + if (concat_dim_== 0) { + int offset_num = 0; + for (int i = 0; i < bottom.size(); ++i) { + const Dtype* bottom_data = bottom[i]->cpu_data(); + int num_elem = bottom[i]->count(); + caffe_copy(num_elem, bottom_data, top_data+top[0]->offset(offset_num)); + offset_num += bottom[i]->num(); + } + } else if (concat_dim_ == 1) { + int offset_channel = 0; + for (int i = 0; i < bottom.size(); ++i) { + const Dtype* bottom_data = bottom[i]->cpu_data(); + int num_elem = + bottom[i]->channels()*bottom[i]->height()*bottom[i]->width(); + for (int n = 0; n < num_; ++n) { + caffe_copy(num_elem, bottom_data+bottom[i]->offset(n), + top_data+top[0]->offset(n, offset_channel)); + } + offset_channel += bottom[i]->channels(); + } // concat_dim_ is guaranteed to be 0 or 1 by LayerSetUp. 
+ } +} + +template +void ConcatLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* top_diff = top[0]->cpu_diff(); + if (concat_dim_ == 0) { + int offset_num = 0; + for (int i = 0; i < bottom.size(); ++i) { + Blob* blob = bottom[i]; + if (propagate_down[i]) { + Dtype* bottom_diff = blob->mutable_cpu_diff(); + caffe_copy(blob->count(), top_diff + top[0]->offset(offset_num), + bottom_diff); + } + offset_num += blob->num(); + } + } else if (concat_dim_ == 1) { + int offset_channel = 0; + for (int i = 0; i < bottom.size(); ++i) { + Blob* blob = bottom[i]; + if (propagate_down[i]) { + Dtype* bottom_diff = blob->mutable_cpu_diff(); + int num_elem = blob->channels()*blob->height()*blob->width(); + for (int n = 0; n < num_; ++n) { + caffe_copy(num_elem, top_diff + top[0]->offset(n, offset_channel), + bottom_diff + blob->offset(n)); + } + } + offset_channel += blob->channels(); + } + } // concat_dim_ is guaranteed to be 0 or 1 by LayerSetUp. 
+} + +#ifdef CPU_ONLY +STUB_GPU(ConcatLayer); +#endif + +INSTANTIATE_CLASS(ConcatLayer); +REGISTER_LAYER_CLASS(CONCAT, ConcatLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/concat_layer.cu b/caffe-crfrnn/src/caffe/layers/concat_layer.cu new file mode 100644 index 00000000..88fc0900 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/concat_layer.cu @@ -0,0 +1,76 @@ +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void ConcatLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + Dtype* top_data = top[0]->mutable_gpu_data(); + if (concat_dim_ == 0) { + int offset_num = 0; + for (int i = 0; i < bottom.size(); ++i) { + const Dtype* bottom_data = bottom[i]->gpu_data(); + caffe_copy(bottom[i]->count(), bottom_data, + top_data + top[0]->offset(offset_num)); + offset_num += bottom[i]->num(); + } + } else if (concat_dim_ == 1) { + int offset_channel = 0; + for (int i = 0; i < bottom.size(); ++i) { + const Dtype* bottom_data = bottom[i]->gpu_data(); + int num_elem = + bottom[i]->channels() * bottom[i]->height() * bottom[i]->width(); + for (int n = 0; n < num_; ++n) { + caffe_copy(num_elem, bottom_data+bottom[i]->offset(n), + top_data + top[0]->offset(n, offset_channel)); + } + offset_channel += bottom[i]->channels(); + } + } else { + LOG(FATAL) << "concat_dim along dim" << concat_dim_ << + " not implemented yet"; + } +} + +template +void ConcatLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* top_diff = top[0]->gpu_diff(); + if (concat_dim_ == 0) { + int offset_num = 0; + for (int i = 0; i < bottom.size(); ++i) { + Blob* blob = bottom[i]; + if (propagate_down[i]) { + Dtype* bottom_diff = blob->mutable_gpu_diff(); + caffe_copy(blob->count(), top_diff + top[0]->offset(offset_num), + bottom_diff); + } + offset_num += blob->num(); + } + } else if (concat_dim_ == 1) { + int 
offset_channel = 0; + for (int i = 0; i < bottom.size(); ++i) { + Blob* blob = bottom[i]; + if (propagate_down[i]) { + Dtype* bottom_diff = blob->mutable_gpu_diff(); + int num_elem = blob->channels()*blob->height()*blob->width(); + for (int n = 0; n < num_; ++n) { + caffe_copy(num_elem, top_diff + top[0]->offset(n, offset_channel), + bottom_diff + blob->offset(n)); + } + } + offset_channel += blob->channels(); + } + } else { + LOG(FATAL) << "concat_dim along dim" << concat_dim_ << + " not implemented yet"; + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(ConcatLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/contrastive_loss_layer.cpp b/caffe-crfrnn/src/caffe/layers/contrastive_loss_layer.cpp new file mode 100644 index 00000000..0d0b443b --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/contrastive_loss_layer.cpp @@ -0,0 +1,101 @@ +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/loss_layers.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void ContrastiveLossLayer::LayerSetUp( + const vector*>& bottom, const vector*>& top) { + LossLayer::LayerSetUp(bottom, top); + CHECK_EQ(bottom[0]->channels(), bottom[1]->channels()); + CHECK_EQ(bottom[0]->height(), 1); + CHECK_EQ(bottom[0]->width(), 1); + CHECK_EQ(bottom[1]->height(), 1); + CHECK_EQ(bottom[1]->width(), 1); + CHECK_EQ(bottom[2]->channels(), 1); + CHECK_EQ(bottom[2]->height(), 1); + CHECK_EQ(bottom[2]->width(), 1); + diff_.Reshape(bottom[0]->num(), bottom[0]->channels(), 1, 1); + diff_sq_.Reshape(bottom[0]->num(), bottom[0]->channels(), 1, 1); + dist_sq_.Reshape(bottom[0]->num(), 1, 1, 1); + // vector of ones used to sum along channels + summer_vec_.Reshape(bottom[0]->channels(), 1, 1, 1); + for (int i = 0; i < bottom[0]->channels(); ++i) + summer_vec_.mutable_cpu_data()[i] = Dtype(1); +} + +template +void ContrastiveLossLayer::Forward_cpu( + const vector*>& bottom, + const vector*>& top) { + int count = 
bottom[0]->count(); + caffe_sub( + count, + bottom[0]->cpu_data(), // a + bottom[1]->cpu_data(), // b + diff_.mutable_cpu_data()); // a_i-b_i + const int channels = bottom[0]->channels(); + Dtype margin = this->layer_param_.contrastive_loss_param().margin(); + Dtype loss(0.0); + for (int i = 0; i < bottom[0]->num(); ++i) { + dist_sq_.mutable_cpu_data()[i] = caffe_cpu_dot(channels, + diff_.cpu_data() + (i*channels), diff_.cpu_data() + (i*channels)); + if (static_cast(bottom[2]->cpu_data()[i])) { // similar pairs + loss += dist_sq_.cpu_data()[i]; + } else { // dissimilar pairs + loss += std::max(margin-dist_sq_.cpu_data()[i], Dtype(0.0)); + } + } + loss = loss / static_cast(bottom[0]->num()) / Dtype(2); + top[0]->mutable_cpu_data()[0] = loss; +} + +template +void ContrastiveLossLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + Dtype margin = this->layer_param_.contrastive_loss_param().margin(); + for (int i = 0; i < 2; ++i) { + if (propagate_down[i]) { + const Dtype sign = (i == 0) ? 
1 : -1; + const Dtype alpha = sign * top[0]->cpu_diff()[0] / + static_cast(bottom[i]->num()); + int num = bottom[i]->num(); + int channels = bottom[i]->channels(); + for (int j = 0; j < num; ++j) { + Dtype* bout = bottom[i]->mutable_cpu_diff(); + if (static_cast(bottom[2]->cpu_data()[j])) { // similar pairs + caffe_cpu_axpby( + channels, + alpha, + diff_.cpu_data() + (j*channels), + Dtype(0.0), + bout + (j*channels)); + } else { // dissimilar pairs + if ((margin-dist_sq_.cpu_data()[j]) > Dtype(0.0)) { + caffe_cpu_axpby( + channels, + -alpha, + diff_.cpu_data() + (j*channels), + Dtype(0.0), + bout + (j*channels)); + } else { + caffe_set(channels, Dtype(0), bout + (j*channels)); + } + } + } + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(ContrastiveLossLayer); +#endif + +INSTANTIATE_CLASS(ContrastiveLossLayer); +REGISTER_LAYER_CLASS(CONTRASTIVE_LOSS, ContrastiveLossLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/contrastive_loss_layer.cu b/caffe-crfrnn/src/caffe/layers/contrastive_loss_layer.cu new file mode 100644 index 00000000..78a55995 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/contrastive_loss_layer.cu @@ -0,0 +1,91 @@ +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void ContrastiveLossLayer::Forward_gpu( + const vector*>& bottom, const vector*>& top) { + const int count = bottom[0]->count(); + caffe_gpu_sub( + count, + bottom[0]->gpu_data(), // a + bottom[1]->gpu_data(), // b + diff_.mutable_gpu_data()); // a_i-b_i + caffe_gpu_powx( + count, + diff_.mutable_gpu_data(), // a_i-b_i + Dtype(2), + diff_sq_.mutable_gpu_data()); // (a_i-b_i)^2 + caffe_gpu_gemv( + CblasNoTrans, + bottom[0]->num(), + bottom[0]->channels(), + Dtype(1.0), + diff_sq_.gpu_data(), // (a_i-b_i)^2 + summer_vec_.gpu_data(), + Dtype(0.0), + dist_sq_.mutable_gpu_data()); // \Sum (a_i-b_i)^2 + Dtype margin = 
this->layer_param_.contrastive_loss_param().margin(); + Dtype loss(0.0); + for (int i = 0; i < bottom[0]->num(); ++i) { + if (static_cast(bottom[2]->cpu_data()[i])) { // similar pairs + loss += dist_sq_.cpu_data()[i]; + } else { // dissimilar pairs + loss += std::max(margin-dist_sq_.cpu_data()[i], Dtype(0.0)); + } + } + loss = loss / static_cast(bottom[0]->num()) / Dtype(2); + top[0]->mutable_cpu_data()[0] = loss; +} + +template +__global__ void CLLForward(const int count, const int channels, + const Dtype margin, const Dtype alpha, + const Dtype* y, const Dtype* diff, const Dtype* dist_sq, + Dtype *bottom_diff) { + CUDA_KERNEL_LOOP(i, count) { + int n = i / channels; // the num index, to access y and dist_sq + if (static_cast(y[n])) { // similar pairs + bottom_diff[i] = alpha * diff[i]; + } else { // dissimilar pairs + if ((margin-dist_sq[n]) > 0.0) { + bottom_diff[i] = -alpha * diff[i]; + } else { + bottom_diff[i] = 0; + } + } + } +} + +template +void ContrastiveLossLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + for (int i = 0; i < 2; ++i) { + if (propagate_down[i]) { + const int count = bottom[0]->count(); + const int channels = bottom[0]->channels(); + Dtype margin = this->layer_param_.contrastive_loss_param().margin(); + const Dtype sign = (i == 0) ? 
1 : -1; + const Dtype alpha = sign * top[0]->cpu_diff()[0] / + static_cast(bottom[0]->num()); + // NOLINT_NEXT_LINE(whitespace/operators) + CLLForward<<>>( + count, channels, margin, alpha, + bottom[2]->gpu_data(), // pair similarity 0 or 1 + diff_.gpu_data(), // the cached eltwise difference between a and b + dist_sq_.gpu_data(), // the cached square distance between a and b + bottom[i]->mutable_gpu_diff()); + CUDA_POST_KERNEL_CHECK; + } + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(ContrastiveLossLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/conv_layer.cpp b/caffe-crfrnn/src/caffe/layers/conv_layer.cpp new file mode 100644 index 00000000..d965ef75 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/conv_layer.cpp @@ -0,0 +1,75 @@ +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void ConvolutionLayer::compute_output_shape() { + this->height_out_ = (this->height_ + 2 * this->pad_h_ - this->kernel_h_) + / this->stride_h_ + 1; + this->width_out_ = (this->width_ + 2 * this->pad_w_ - this->kernel_w_) + / this->stride_w_ + 1; +} + +template +void ConvolutionLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* weight = this->blobs_[0]->cpu_data(); + for (int i = 0; i < bottom.size(); ++i) { + const Dtype* bottom_data = bottom[i]->cpu_data(); + Dtype* top_data = top[i]->mutable_cpu_data(); + for (int n = 0; n < this->num_; ++n) { + this->forward_cpu_gemm(bottom_data + bottom[i]->offset(n), weight, + top_data + top[i]->offset(n)); + if (this->bias_term_) { + const Dtype* bias = this->blobs_[1]->cpu_data(); + this->forward_cpu_bias(top_data + top[i]->offset(n), bias); + } + } + } +} + +template +void ConvolutionLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* weight = this->blobs_[0]->cpu_data(); + 
Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff(); + for (int i = 0; i < top.size(); ++i) { + const Dtype* top_diff = top[i]->cpu_diff(); + const Dtype* bottom_data = bottom[i]->cpu_data(); + Dtype* bottom_diff = bottom[i]->mutable_cpu_diff(); + // Bias gradient, if necessary. + if (this->bias_term_ && this->param_propagate_down_[1]) { + Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff(); + for (int n = 0; n < this->num_; ++n) { + this->backward_cpu_bias(bias_diff, top_diff + top[i]->offset(n)); + } + } + if (this->param_propagate_down_[0] || propagate_down[i]) { + for (int n = 0; n < this->num_; ++n) { + // gradient w.r.t. weight. Note that we will accumulate diffs. + if (this->param_propagate_down_[0]) { + this->weight_cpu_gemm(bottom_data + bottom[i]->offset(n), + top_diff + top[i]->offset(n), weight_diff); + } + // gradient w.r.t. bottom data, if necessary. + if (propagate_down[i]) { + this->backward_cpu_gemm(top_diff + top[i]->offset(n), weight, + bottom_diff + bottom[i]->offset(n)); + } + } + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(ConvolutionLayer); +#endif + +INSTANTIATE_CLASS(ConvolutionLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/conv_layer.cu b/caffe-crfrnn/src/caffe/layers/conv_layer.cu new file mode 100644 index 00000000..b8a98ff7 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/conv_layer.cu @@ -0,0 +1,64 @@ +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void ConvolutionLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* weight = this->blobs_[0]->gpu_data(); + for (int i = 0; i < bottom.size(); ++i) { + const Dtype* bottom_data = bottom[i]->gpu_data(); + Dtype* top_data = top[i]->mutable_gpu_data(); + for (int n = 0; n < this->num_; ++n) { + this->forward_gpu_gemm(bottom_data + bottom[i]->offset(n), weight, + top_data 
+ top[i]->offset(n)); + if (this->bias_term_) { + const Dtype* bias = this->blobs_[1]->gpu_data(); + this->forward_gpu_bias(top_data + top[i]->offset(n), bias); + } + } + } +} + +template +void ConvolutionLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* weight = this->blobs_[0]->gpu_data(); + Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff(); + for (int i = 0; i < top.size(); ++i) { + const Dtype* top_diff = top[i]->gpu_diff(); + // Bias gradient, if necessary. + if (this->bias_term_ && this->param_propagate_down_[1]) { + Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff(); + for (int n = 0; n < this->num_; ++n) { + this->backward_gpu_bias(bias_diff, top_diff + top[i]->offset(n)); + } + } + if (this->param_propagate_down_[0] || propagate_down[i]) { + const Dtype* bottom_data = bottom[i]->gpu_data(); + Dtype* bottom_diff = bottom[i]->mutable_gpu_diff(); + for (int n = 0; n < this->num_; ++n) { + // gradient w.r.t. weight. Note that we will accumulate diffs. + if (this->param_propagate_down_[0]) { + this->weight_gpu_gemm(bottom_data + bottom[i]->offset(n), + top_diff + top[i]->offset(n), weight_diff); + } + // gradient w.r.t. bottom data, if necessary. 
+ if (propagate_down[i]) { + this->backward_gpu_gemm(top_diff + top[i]->offset(n), weight, + bottom_diff + bottom[i]->offset(n)); + } + } + } + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(ConvolutionLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/crop_layer.cpp b/caffe-crfrnn/src/caffe/layers/crop_layer.cpp new file mode 100644 index 00000000..f07012f2 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/crop_layer.cpp @@ -0,0 +1,126 @@ +#include +#include +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/net.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CropLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + // Construct a map from top blobs to layer inds, skipping over in-place + // connections. + map*, int> down_map; + for (int layer_ind = 0; layer_ind < this->net_->top_vecs().size(); + ++layer_ind) { + vector*> tops = this->net_->top_vecs()[layer_ind]; + for (int top_ind = 0; top_ind < tops.size(); ++top_ind) { + if (down_map.find(tops[top_ind]) == down_map.end()) { + down_map[tops[top_ind]] = layer_ind; + } + } + } + // Walk back from the first bottom, keeping track of all the blobs we pass. + set*> path_blobs; + Blob* blob = bottom[0]; + int layer_ind; + // TODO this logic can be simplified if all blobs are tops + path_blobs.insert(blob); + while (down_map.find(blob) != down_map.end()) { + layer_ind = down_map[blob]; + if (this->net_->bottom_vecs()[layer_ind].size() == 0) { + break; + } + blob = this->net_->bottom_vecs()[layer_ind][0]; + path_blobs.insert(blob); + } + // Now walk back from the second bottom, until we find a blob of intersection. 
+ Blob* inter_blob = bottom[1]; + while (path_blobs.find(inter_blob) == path_blobs.end()) { + CHECK(down_map.find(inter_blob) != down_map.end()) + << "Cannot align apparently disconnected blobs."; + layer_ind = down_map[inter_blob]; + CHECK_GT(this->net_->bottom_vecs()[layer_ind].size(), 0) + << "Cannot align apparently disconnected blobs."; + inter_blob = this->net_->bottom_vecs()[layer_ind][0]; + } + // Compute the coord map from the blob of intersection to each bottom. + vector > coord_maps(2, + DiagonalAffineMap::identity(2)); + for (int i = 0; i < 2; ++i) { + for (Blob* blob = bottom[i]; blob != inter_blob; + blob = this->net_->bottom_vecs()[down_map[blob]][0]) { + shared_ptr > layer = this->net_->layers()[down_map[blob]]; + coord_maps[i] = coord_maps[i].compose(layer->coord_map()); + } + } + // Compute the mapping from first bottom coordinates to second. + DiagonalAffineMap crop_map = + coord_maps[1].compose(coord_maps[0].inv()); + for (int i = 0; i < 2; ++i) { + // Check for scale mismatch (unfortunately, CHECK_DOUBLE_EQ does not + // support a message like the other CHECKs). + CHECK_DOUBLE_EQ(crop_map.coefs()[i].first, 1); + CHECK_LE(crop_map.coefs()[i].second, 0) << "Negative crop width."; + // Check that the crop width is an integer. 
+ CHECK_DOUBLE_EQ(crop_map.coefs()[i].second, + round(crop_map.coefs()[i].second)); + } + crop_h_ = - round(crop_map.coefs()[0].second); + crop_w_ = - round(crop_map.coefs()[1].second); +} + +template +void CropLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(), bottom[1]->height(), + bottom[1]->width()); +} + +template +void CropLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + for (int n = 0; n < top[0]->num(); ++n) { + for (int c = 0; c < top[0]->channels(); ++c) { + for (int h = 0; h < top[0]->height(); ++h) { + caffe_copy(top[0]->width(), + bottom_data + bottom[0]->offset(n, c, crop_h_ + h, crop_w_), + top_data + top[0]->offset(n, c, h)); + } + } + } +} + +template +void CropLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* top_diff = top[0]->cpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + if (propagate_down[0]) { + caffe_set(bottom[0]->count(), static_cast(0), bottom_diff); + for (int n = 0; n < top[0]->num(); ++n) { + for (int c = 0; c < top[0]->channels(); ++c) { + for (int h = 0; h < top[0]->height(); ++h) { + caffe_copy(top[0]->width(), + top_diff + top[0]->offset(n, c, h), + bottom_diff + bottom[0]->offset(n, c, crop_h_ + h, crop_w_)); + } + } + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(CropLayer); +#endif + +INSTANTIATE_CLASS(CropLayer); +REGISTER_LAYER_CLASS(CROP, CropLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/crop_layer.cu b/caffe-crfrnn/src/caffe/layers/crop_layer.cu new file mode 100644 index 00000000..2dd3ff95 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/crop_layer.cu @@ -0,0 +1,60 @@ +#include + +#include "caffe/vision_layers.hpp" + +namespace caffe { + +// Copy (one line per thread) from one array to another, with arbitrary +// strides in the 
last two dimensions. +template +__global__ void copy_kernel(const int n, const int height, const int width, + const int src_outer_stride, const int src_inner_stride, + const int dest_outer_stride, const int dest_inner_stride, + const Dtype* src, Dtype* dest) { + CUDA_KERNEL_LOOP(index, n) { + int src_start = index / height * src_outer_stride + + index % height * src_inner_stride; + int dest_start = index / height * dest_outer_stride + + index % height * dest_inner_stride; + for (int i = 0; i < width; ++i) { + dest[dest_start + i] = src[src_start + i]; + } + } +} + +template +void CropLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + const int lines = top[0]->count() / top[0]->width(); + + // NOLINT_NEXT_LINE(whitespace/operators) + copy_kernel<<>>( + lines, top[0]->height(), top[0]->width(), + bottom[0]->height() * bottom[0]->width(), bottom[0]->width(), + top[0]->height() * top[0]->width(), top[0]->width(), + bottom_data + bottom[0]->offset(0, 0, crop_h_, crop_w_), top_data); +} + +template +void CropLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* top_diff = top[0]->gpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + const int lines = top[0]->count() / top[0]->width(); + + if (propagate_down[0]) { + caffe_gpu_set(bottom[0]->count(), static_cast(0), bottom_diff); + // NOLINT_NEXT_LINE(whitespace/operators) + copy_kernel<<>>( + lines, top[0]->height(), top[0]->width(), + top[0]->height() * top[0]->width(), top[0]->width(), + bottom[0]->height() * bottom[0]->width(), bottom[0]->width(), + top_diff, bottom_diff + bottom[0]->offset(0, 0, crop_h_, crop_w_)); + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(CropLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/cudnn_conv_layer.cpp b/caffe-crfrnn/src/caffe/layers/cudnn_conv_layer.cpp new file mode 100644 index 
00000000..5f3a7773 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/cudnn_conv_layer.cpp @@ -0,0 +1,123 @@ +#ifdef USE_CUDNN +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +// Set to three for the benefit of the backward pass, which +// can use separate streams for calculating the gradient w.r.t. +// bias, filter weights, and bottom data for each group independently +#define CUDNN_STREAMS_PER_GROUP 3 + +/** + * TODO(dox) explain cuDNN interface + */ +template +void CuDNNConvolutionLayer::LayerSetUp( + const vector*>& bottom, const vector*>& top) { + ConvolutionLayer::LayerSetUp(bottom, top); + // Initialize CUDA streams and cuDNN. + stream_ = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP]; + handle_ = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP]; + + for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) { + CUDA_CHECK(cudaStreamCreate(&stream_[g])); + CUDNN_CHECK(cudnnCreate(&handle_[g])); + CUDNN_CHECK(cudnnSetStream(handle_[g], stream_[g])); + } + + // Set the indexing parameters. + weight_offset_ = (this->num_output_ / this->group_) + * (this->channels_ / this->group_) * this->kernel_h_ * this->kernel_w_; + bias_offset_ = (this->num_output_ / this->group_); + + // Create filter descriptor. + cudnn::createFilterDesc(&filter_desc_, + this->num_output_ / this->group_, this->channels_ / this->group_, + this->kernel_h_, this->kernel_w_); + + // Create tensor descriptor(s) for data and corresponding convolution(s). 
+ for (int i = 0; i < bottom.size(); i++) { + cudnnTensorDescriptor_t bottom_desc; + cudnn::createTensor4dDesc(&bottom_desc); + bottom_descs_.push_back(bottom_desc); + cudnnTensorDescriptor_t top_desc; + cudnn::createTensor4dDesc(&top_desc); + top_descs_.push_back(top_desc); + cudnnConvolutionDescriptor_t conv_desc; + cudnn::createConvolutionDesc(&conv_desc); + conv_descs_.push_back(conv_desc); + } + + // Tensor descriptor for bias. + if (this->bias_term_) { + cudnn::createTensor4dDesc(&bias_desc_); + } +} + +template +void CuDNNConvolutionLayer::Reshape( + const vector*>& bottom, const vector*>& top) { + ConvolutionLayer::Reshape(bottom, top); + bottom_offset_ = (this->channels_ / this->group_) + * this->height_ * this->width_; + top_offset_ = (this->num_output_ / this->group_) + * this->height_out_ * this->width_out_; + + for (int i = 0; i < bottom.size(); i++) { + cudnn::setTensor4dDesc(&bottom_descs_[i], + this->num_, + this->channels_ / this->group_, + this->height_, this->width_, + this->channels_ * this->height_ * this->width_, + this->height_ * this->width_, + this->width_, 1); + cudnn::setTensor4dDesc(&top_descs_[i], + this->num_, + this->num_output_ / this->group_, + this->height_out_, this->width_out_, + this->num_output_ * this->height_out_ * this->width_out_, + this->height_out_ * this->width_out_, + this->width_out_, 1); + cudnn::setConvolutionDesc(&conv_descs_[i], bottom_descs_[i], + filter_desc_, this->pad_h_, this->pad_w_, + this->stride_h_, this->stride_w_); + } + + // Tensor descriptor for bias. 
+ if (this->bias_term_) { + cudnn::setTensor4dDesc(&bias_desc_, + 1, this->num_output_ / this->group_, 1, 1); + } +} + +template +CuDNNConvolutionLayer::~CuDNNConvolutionLayer() { + for (int i = 0; i < bottom_descs_.size(); i++) { + cudnnDestroyTensorDescriptor(bottom_descs_[i]); + cudnnDestroyTensorDescriptor(top_descs_[i]); + cudnnDestroyConvolutionDescriptor(conv_descs_[i]); + } + if (this->bias_term_) { + cudnnDestroyTensorDescriptor(bias_desc_); + } + cudnnDestroyFilterDescriptor(filter_desc_); + + for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) { + cudaStreamDestroy(stream_[g]); + cudnnDestroy(handle_[g]); + } + + delete [] stream_; + delete [] handle_; +} + +INSTANTIATE_CLASS(CuDNNConvolutionLayer); + +} // namespace caffe +#endif diff --git a/caffe-crfrnn/src/caffe/layers/cudnn_conv_layer.cu b/caffe-crfrnn/src/caffe/layers/cudnn_conv_layer.cu new file mode 100644 index 00000000..5ac30019 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/cudnn_conv_layer.cu @@ -0,0 +1,149 @@ +#ifdef USE_CUDNN +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +__global__ void sync_conv_groups() { } + +template +void CuDNNConvolutionLayer::Forward_gpu( + const vector*>& bottom, const vector*>& top) { + for (int i = 0; i < bottom.size(); ++i) { + const Dtype* bottom_data = bottom[i]->gpu_data(); + Dtype* top_data = top[i]->mutable_gpu_data(); + const Dtype* weight = this->blobs_[0]->gpu_data(); + + // Forward through cuDNN in parallel over groups. 
+ for (int g = 0; g < this->group_; g++) { + cudnnConvolutionFwdAlgo_t algo; + + // get the desired convolution algorithm + CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[g], + bottom_descs_[i], + filter_desc_, + conv_descs_[i], + top_descs_[i], + CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, + 0, // memoryLimitInBytes, + &algo)); + + // get minimum size of the workspace needed for the desired algorithm + size_t workspaceSizeInBytes_temp = 0; + + CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[g], + bottom_descs_[i], + filter_desc_, + conv_descs_[i], + top_descs_[i], + algo, + &workspaceSizeInBytes_temp)); + + if (workspaceSizeInBytes_temp > workspaceSizeInBytes) { + workspaceSizeInBytes = workspaceSizeInBytes_temp; + // free the existing workspace and allocate a new (larger) one + cudaFree(this->workspace); + cudaMalloc(&(this->workspace), workspaceSizeInBytes); + } + + // Filters. + CUDNN_CHECK(cudnnConvolutionForward(handle_[g], + cudnn::dataType::one, + bottom_descs_[i], bottom_data + bottom_offset_ * g, + filter_desc_, weight + weight_offset_ * g, + conv_descs_[i], + algo, workspace, workspaceSizeInBytes, + cudnn::dataType::zero, + top_descs_[i], top_data + top_offset_ * g)); + + // Bias. + if (this->bias_term_) { + const Dtype* bias_data = this->blobs_[1]->gpu_data(); + CUDNN_CHECK(cudnnAddTensor(handle_[g], CUDNN_ADD_SAME_C, + cudnn::dataType::one, + bias_desc_, bias_data + bias_offset_ * g, + cudnn::dataType::one, + top_descs_[i], top_data + top_offset_ * g)); + } + } + + // Synchronize the work across groups, each of which went into its own + // stream, by launching an empty kernel into the default (null) stream. 
+ // NOLINT_NEXT_LINE(whitespace/operators) + sync_conv_groups<<<1, 1>>>(); + } +} + +template +void CuDNNConvolutionLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* weight = NULL; + Dtype* weight_diff = NULL; + if (this->param_propagate_down_[0]) { + weight = this->blobs_[0]->gpu_data(); + weight_diff = this->blobs_[0]->mutable_gpu_diff(); + caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff); + } + Dtype* bias_diff = NULL; + if (this->bias_term_ && this->param_propagate_down_[1]) { + bias_diff = this->blobs_[1]->mutable_gpu_diff(); + caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), bias_diff); + } + for (int i = 0; i < top.size(); ++i) { + const Dtype* top_diff = top[i]->gpu_diff(); + // Backward through cuDNN in parallel over groups and gradients. + for (int g = 0; g < this->group_; g++) { + // Gradient w.r.t. bias. + if (this->bias_term_ && this->param_propagate_down_[1]) { + CUDNN_CHECK(cudnnConvolutionBackwardBias(handle_[0*this->group_ + g], + cudnn::dataType::one, + top_descs_[i], top_diff + top_offset_ * g, + cudnn::dataType::one, + bias_desc_, bias_diff + bias_offset_ * g)); + } + + // Gradient w.r.t. weights. + if (this->param_propagate_down_[0]) { + const Dtype* bottom_data = bottom[i]->gpu_data(); + CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle_[1*this->group_ + g], + cudnn::dataType::one, + bottom_descs_[i], bottom_data + bottom_offset_ * g, + top_descs_[i], top_diff + top_offset_ * g, + conv_descs_[i], + cudnn::dataType::one, + filter_desc_, weight_diff + weight_offset_ * g)); + } + + // Gradient w.r.t. bottom data. 
+ if (propagate_down[i]) { + if (weight == NULL) { + weight = this->blobs_[0]->gpu_data(); + } + Dtype* bottom_diff = bottom[i]->mutable_gpu_diff(); + CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g], + cudnn::dataType::one, + filter_desc_, weight + weight_offset_ * g, + top_descs_[i], top_diff + top_offset_ * g, + conv_descs_[i], + cudnn::dataType::zero, + bottom_descs_[i], bottom_diff + bottom_offset_ * g)); + } + } + + // Synchronize the work across groups, each of which went into its own + // stream, by launching an empty kernel into the default (null) stream. + // NOLINT_NEXT_LINE(whitespace/operators) + sync_conv_groups<<<1, 1>>>(); + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(CuDNNConvolutionLayer); + +} // namespace caffe +#endif + diff --git a/caffe-crfrnn/src/caffe/layers/cudnn_pooling_layer.cpp b/caffe-crfrnn/src/caffe/layers/cudnn_pooling_layer.cpp new file mode 100644 index 00000000..3e63ab3b --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/cudnn_pooling_layer.cpp @@ -0,0 +1,46 @@ +#ifdef USE_CUDNN +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNPoolingLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + PoolingLayer::LayerSetUp(bottom, top); + CUDNN_CHECK(cudnnCreate(&handle_)); + cudnn::createTensor4dDesc(&bottom_desc_); + cudnn::createTensor4dDesc(&top_desc_); + cudnn::createPoolingDesc(&pooling_desc_, + this->layer_param_.pooling_param().pool(), &mode_, + this->kernel_h_, this->kernel_w_, this->pad_h_, this->pad_w_, + this->stride_h_, this->stride_w_); +} + +template +void CuDNNPoolingLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + PoolingLayer::Reshape(bottom, top); + cudnn::setTensor4dDesc(&bottom_desc_, bottom[0]->num(), + this->channels_, this->height_, this->width_); + cudnn::setTensor4dDesc(&top_desc_, bottom[0]->num(), + 
this->channels_, this->pooled_height_, this->pooled_width_); +} + +template +CuDNNPoolingLayer::~CuDNNPoolingLayer() { + cudnnDestroyTensorDescriptor(bottom_desc_); + cudnnDestroyTensorDescriptor(top_desc_); + cudnnDestroyPoolingDescriptor(pooling_desc_); + cudnnDestroy(handle_); +} + +INSTANTIATE_CLASS(CuDNNPoolingLayer); + +} // namespace caffe +#endif diff --git a/caffe-crfrnn/src/caffe/layers/cudnn_pooling_layer.cu b/caffe-crfrnn/src/caffe/layers/cudnn_pooling_layer.cu new file mode 100644 index 00000000..54eb9461 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/cudnn_pooling_layer.cu @@ -0,0 +1,46 @@ +#ifdef USE_CUDNN +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNPoolingLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + CUDNN_CHECK(cudnnPoolingForward(handle_, pooling_desc_, + cudnn::dataType::one, + bottom_desc_, bottom_data, + cudnn::dataType::zero, + top_desc_, top_data)); +} + +template +void CuDNNPoolingLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (!propagate_down[0]) { + return; + } + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* top_data = top[0]->gpu_data(); + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + CUDNN_CHECK(cudnnPoolingBackward(handle_, pooling_desc_, + cudnn::dataType::one, + top_desc_, top_data, top_desc_, top_diff, + bottom_desc_, bottom_data, + cudnn::dataType::zero, + bottom_desc_, bottom_diff)); +} + +INSTANTIATE_LAYER_GPU_FUNCS(CuDNNPoolingLayer); + +} // namespace caffe +#endif + diff --git a/caffe-crfrnn/src/caffe/layers/cudnn_relu_layer.cpp b/caffe-crfrnn/src/caffe/layers/cudnn_relu_layer.cpp new file mode 
100644 index 00000000..783e71be --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/cudnn_relu_layer.cpp @@ -0,0 +1,42 @@ +#ifdef USE_CUDNN +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNReLULayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + ReLULayer::LayerSetUp(bottom, top); + // initialize cuDNN + CUDNN_CHECK(cudnnCreate(&handle_)); + cudnn::createTensor4dDesc(&bottom_desc_); + cudnn::createTensor4dDesc(&top_desc_); +} + +template +void CuDNNReLULayer::Reshape(const vector*>& bottom, + const vector*>& top) { + ReLULayer::Reshape(bottom, top); + const int N = bottom[0]->num(); + const int K = bottom[0]->channels(); + const int H = bottom[0]->height(); + const int W = bottom[0]->width(); + cudnn::setTensor4dDesc(&bottom_desc_, N, K, H, W); + cudnn::setTensor4dDesc(&top_desc_, N, K, H, W); +} + +template +CuDNNReLULayer::~CuDNNReLULayer() { + cudnnDestroyTensorDescriptor(this->bottom_desc_); + cudnnDestroyTensorDescriptor(this->top_desc_); + cudnnDestroy(this->handle_); +} + +INSTANTIATE_CLASS(CuDNNReLULayer); + +} // namespace caffe +#endif diff --git a/caffe-crfrnn/src/caffe/layers/cudnn_relu_layer.cu b/caffe-crfrnn/src/caffe/layers/cudnn_relu_layer.cu new file mode 100644 index 00000000..b9d1067a --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/cudnn_relu_layer.cu @@ -0,0 +1,66 @@ +#ifdef USE_CUDNN +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNReLULayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + // Fallback to standard Caffe for leaky ReLU. 
+ if (ReLULayer::layer_param_.relu_param().negative_slope() != 0) { + return ReLULayer::Forward_gpu(bottom, top); + } + + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + + CUDNN_CHECK(cudnnActivationForward(this->handle_, + CUDNN_ACTIVATION_RELU, + cudnn::dataType::one, + this->bottom_desc_, bottom_data, + cudnn::dataType::zero, + this->top_desc_, top_data)); +} + +template +void CuDNNReLULayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + if (!propagate_down[0]) { + return; + } + + // Fallback to standard Caffe for leaky ReLU. + if (ReLULayer::layer_param_.relu_param().negative_slope() != 0) { + return ReLULayer::Backward_gpu(top, propagate_down, bottom); + } + + const Dtype* top_data = top[0]->gpu_data(); + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + + CUDNN_CHECK(cudnnActivationBackward(this->handle_, + CUDNN_ACTIVATION_RELU, + cudnn::dataType::one, + this->top_desc_, top_data, this->top_desc_, top_diff, + this->bottom_desc_, bottom_data, + cudnn::dataType::zero, + this->bottom_desc_, bottom_diff)); +} + +INSTANTIATE_LAYER_GPU_FUNCS(CuDNNReLULayer); + +} // namespace caffe +#endif + diff --git a/caffe-crfrnn/src/caffe/layers/cudnn_sigmoid_layer.cpp b/caffe-crfrnn/src/caffe/layers/cudnn_sigmoid_layer.cpp new file mode 100644 index 00000000..69d749aa --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/cudnn_sigmoid_layer.cpp @@ -0,0 +1,42 @@ +#ifdef USE_CUDNN +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNSigmoidLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + SigmoidLayer::LayerSetUp(bottom, top); + // initialize cuDNN + CUDNN_CHECK(cudnnCreate(&handle_)); + 
cudnn::createTensor4dDesc(&bottom_desc_); + cudnn::createTensor4dDesc(&top_desc_); +} + +template +void CuDNNSigmoidLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + SigmoidLayer::Reshape(bottom, top); + const int N = bottom[0]->num(); + const int K = bottom[0]->channels(); + const int H = bottom[0]->height(); + const int W = bottom[0]->width(); + cudnn::setTensor4dDesc(&bottom_desc_, N, K, H, W); + cudnn::setTensor4dDesc(&top_desc_, N, K, H, W); +} + +template +CuDNNSigmoidLayer::~CuDNNSigmoidLayer() { + cudnnDestroyTensorDescriptor(this->bottom_desc_); + cudnnDestroyTensorDescriptor(this->top_desc_); + cudnnDestroy(this->handle_); +} + +INSTANTIATE_CLASS(CuDNNSigmoidLayer); + +} // namespace caffe +#endif diff --git a/caffe-crfrnn/src/caffe/layers/cudnn_sigmoid_layer.cu b/caffe-crfrnn/src/caffe/layers/cudnn_sigmoid_layer.cu new file mode 100644 index 00000000..010c09ce --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/cudnn_sigmoid_layer.cu @@ -0,0 +1,56 @@ +#ifdef USE_CUDNN +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNSigmoidLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + + CUDNN_CHECK(cudnnActivationForward(this->handle_, + CUDNN_ACTIVATION_SIGMOID, + cudnn::dataType::one, + this->bottom_desc_, bottom_data, + cudnn::dataType::zero, + this->top_desc_, top_data)); +} + +template +void CuDNNSigmoidLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + if (!propagate_down[0]) { + return; + } + + const Dtype* top_data = top[0]->gpu_data(); + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + + 
CUDNN_CHECK(cudnnActivationBackward(this->handle_, + CUDNN_ACTIVATION_SIGMOID, + cudnn::dataType::one, + this->top_desc_, top_data, this->top_desc_, top_diff, + this->bottom_desc_, bottom_data, + cudnn::dataType::zero, + this->bottom_desc_, bottom_diff)); +} + +INSTANTIATE_LAYER_GPU_FUNCS(CuDNNSigmoidLayer); + +} // namespace caffe +#endif + diff --git a/caffe-crfrnn/src/caffe/layers/cudnn_softmax_layer.cpp b/caffe-crfrnn/src/caffe/layers/cudnn_softmax_layer.cpp new file mode 100644 index 00000000..3e80c959 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/cudnn_softmax_layer.cpp @@ -0,0 +1,46 @@ +#ifdef USE_CUDNN +#include +#include +#include + +#include "thrust/device_vector.h" + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNSoftmaxLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + SoftmaxLayer::LayerSetUp(bottom, top); + // Initialize CUDNN. + CUDNN_CHECK(cudnnCreate(&handle_)); + cudnn::createTensor4dDesc(&bottom_desc_); + cudnn::createTensor4dDesc(&top_desc_); +} + +template +void CuDNNSoftmaxLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + SoftmaxLayer::Reshape(bottom, top); + int N = bottom[0]->num(); + int K = bottom[0]->channels(); + int H = bottom[0]->height(); + int W = bottom[0]->width(); + cudnn::setTensor4dDesc(&bottom_desc_, N, K, H, W); + cudnn::setTensor4dDesc(&top_desc_, N, K, H, W); +} + +template +CuDNNSoftmaxLayer::~CuDNNSoftmaxLayer() { + cudnnDestroyTensorDescriptor(bottom_desc_); + cudnnDestroyTensorDescriptor(top_desc_); + cudnnDestroy(handle_); +} + +INSTANTIATE_CLASS(CuDNNSoftmaxLayer); + +} // namespace caffe +#endif diff --git a/caffe-crfrnn/src/caffe/layers/cudnn_softmax_layer.cu b/caffe-crfrnn/src/caffe/layers/cudnn_softmax_layer.cu new file mode 100644 index 00000000..b5e4a73d --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/cudnn_softmax_layer.cu @@ -0,0 +1,56 @@ +#ifdef USE_CUDNN +#include 
+#include +#include + +#include "thrust/device_vector.h" + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNSoftmaxLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + + CUDNN_CHECK(cudnnSoftmaxForward(handle_, CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_CHANNEL, + cudnn::dataType::one, + bottom_desc_, bottom_data, + cudnn::dataType::zero, + top_desc_, top_data)); +} + +template +void CuDNNSoftmaxLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (propagate_down[0]) { + const Dtype* top_data = top[0]->gpu_data(); + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + + CUDNN_CHECK(cudnnSoftmaxBackward(handle_, CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_CHANNEL, + cudnn::dataType::one, + top_desc_, top_data, top_desc_, top_diff, + cudnn::dataType::zero, + bottom_desc_, bottom_diff)); + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(CuDNNSoftmaxLayer); + +} // namespace caffe +#endif + diff --git a/caffe-crfrnn/src/caffe/layers/cudnn_tanh_layer.cpp b/caffe-crfrnn/src/caffe/layers/cudnn_tanh_layer.cpp new file mode 100644 index 00000000..98d3d0f5 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/cudnn_tanh_layer.cpp @@ -0,0 +1,42 @@ +#ifdef USE_CUDNN +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNTanHLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + TanHLayer::LayerSetUp(bottom, top); + // initialize cuDNN + CUDNN_CHECK(cudnnCreate(&handle_)); + cudnn::createTensor4dDesc(&bottom_desc_); + cudnn::createTensor4dDesc(&top_desc_); 
+} + +template +void CuDNNTanHLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + TanHLayer::Reshape(bottom, top); + const int N = bottom[0]->num(); + const int K = bottom[0]->channels(); + const int H = bottom[0]->height(); + const int W = bottom[0]->width(); + cudnn::setTensor4dDesc(&bottom_desc_, N, K, H, W); + cudnn::setTensor4dDesc(&top_desc_, N, K, H, W); +} + +template +CuDNNTanHLayer::~CuDNNTanHLayer() { + cudnnDestroyTensorDescriptor(this->bottom_desc_); + cudnnDestroyTensorDescriptor(this->top_desc_); + cudnnDestroy(this->handle_); +} + +INSTANTIATE_CLASS(CuDNNTanHLayer); + +} // namespace caffe +#endif diff --git a/caffe-crfrnn/src/caffe/layers/cudnn_tanh_layer.cu b/caffe-crfrnn/src/caffe/layers/cudnn_tanh_layer.cu new file mode 100644 index 00000000..32878646 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/cudnn_tanh_layer.cu @@ -0,0 +1,56 @@ +#ifdef USE_CUDNN +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNTanHLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + + CUDNN_CHECK(cudnnActivationForward(this->handle_, + CUDNN_ACTIVATION_TANH, + cudnn::dataType::one, + this->bottom_desc_, bottom_data, + cudnn::dataType::zero, + this->top_desc_, top_data)); +} + +template +void CuDNNTanHLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + if (!propagate_down[0]) { + return; + } + + const Dtype* top_data = top[0]->gpu_data(); + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + + Dtype alpha = 1.0; + Dtype beta = 0.0; + + CUDNN_CHECK(cudnnActivationBackward(this->handle_, + CUDNN_ACTIVATION_TANH, + cudnn::dataType::one, + this->top_desc_, top_data, 
this->top_desc_, top_diff, + this->bottom_desc_, bottom_data, + cudnn::dataType::zero, + this->bottom_desc_, bottom_diff)); +} + +INSTANTIATE_LAYER_GPU_FUNCS(CuDNNTanHLayer); + +} // namespace caffe +#endif + diff --git a/caffe-crfrnn/src/caffe/layers/data_layer.cpp b/caffe-crfrnn/src/caffe/layers/data_layer.cpp new file mode 100644 index 00000000..36968f36 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/data_layer.cpp @@ -0,0 +1,158 @@ +#include + +#include +#include + +#include "caffe/common.hpp" +#include "caffe/data_layers.hpp" +#include "caffe/dataset_factory.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/benchmark.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/util/rng.hpp" + +namespace caffe { + +template +DataLayer::~DataLayer() { + this->JoinPrefetchThread(); + // clean up the dataset resources + dataset_->close(); +} + +template +void DataLayer::DataLayerSetUp(const vector*>& bottom, + const vector*>& top) { + // Initialize DB + dataset_ = DatasetFactory( + this->layer_param_.data_param().backend()); + const string& source = this->layer_param_.data_param().source(); + LOG(INFO) << "Opening dataset " << source; + CHECK(dataset_->open(source, Dataset::ReadOnly)); + iter_ = dataset_->begin(); + + // Check if we would need to randomly skip a few data points + if (this->layer_param_.data_param().rand_skip()) { + unsigned int skip = caffe_rng_rand() % + this->layer_param_.data_param().rand_skip(); + LOG(INFO) << "Skipping first " << skip << " data points."; + while (skip-- > 0) { + if (++iter_ == dataset_->end()) { + iter_ = dataset_->begin(); + } + } + } + // Read a data point, and use it to initialize the top blob. 
+ CHECK(iter_ != dataset_->end()); + Datum datum = iter_->value; + + if (DecodeDatum(&datum)) { + LOG(INFO) << "Decoding Datum"; + } + // image + int crop_size = this->layer_param_.transform_param().crop_size(); + if (crop_size > 0) { + top[0]->Reshape(this->layer_param_.data_param().batch_size(), + datum.channels(), crop_size, crop_size); + this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(), + datum.channels(), crop_size, crop_size); + this->transformed_data_.Reshape(1, datum.channels(), crop_size, crop_size); + } else { + top[0]->Reshape( + this->layer_param_.data_param().batch_size(), datum.channels(), + datum.height(), datum.width()); + this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(), + datum.channels(), datum.height(), datum.width()); + this->transformed_data_.Reshape(1, datum.channels(), + datum.height(), datum.width()); + } + LOG(INFO) << "output data size: " << top[0]->num() << "," + << top[0]->channels() << "," << top[0]->height() << "," + << top[0]->width(); + // label + if (this->output_labels_) { + top[1]->Reshape(this->layer_param_.data_param().batch_size(), 1, 1, 1); + this->prefetch_label_.Reshape(this->layer_param_.data_param().batch_size(), + 1, 1, 1); + } +} + +template <typename Dtype> +void DataLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top) { + // Reshape on single input batches for inputs of varying dimension. + if (this->prefetch_data_.num() == 1) { + top[0]->Reshape(1, this->prefetch_data_.channels(), + this->prefetch_data_.height(), this->prefetch_data_.width()); + } +} + +// This function is used to create a thread that prefetches the data.
+template <typename Dtype> +void DataLayer<Dtype>::InternalThreadEntry() { + CPUTimer batch_timer; + batch_timer.Start(); + double read_time = 0; + double trans_time = 0; + CPUTimer timer; + CHECK(this->prefetch_data_.count()); + CHECK(this->transformed_data_.count()); + + const int batch_size = this->layer_param_.data_param().batch_size(); + // Reshape on single input batches for inputs of varying dimension. + if (batch_size == 1) { + Datum datum = iter_->value; + this->prefetch_data_.Reshape(1, datum.channels(), + datum.height(), datum.width()); + this->transformed_data_.Reshape(1, datum.channels(), + datum.height(), datum.width()); + } + + Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); + Dtype* top_label = NULL; // suppress warnings about uninitialized variables + + if (this->output_labels_) { + top_label = this->prefetch_label_.mutable_cpu_data(); + } + for (int item_id = 0; item_id < batch_size; ++item_id) { + timer.Start(); + // get a blob + CHECK(iter_ != dataset_->end()); + const Datum& datum = iter_->value; + + cv::Mat cv_img; + if (datum.encoded()) { + cv_img = DecodeDatumToCVMat(datum); + } + read_time += timer.MicroSeconds(); + timer.Start(); + + // Apply data transformations (mirror, scale, crop...)
+ int offset = this->prefetch_data_.offset(item_id); + this->transformed_data_.set_cpu_data(top_data + offset); + if (datum.encoded()) { + this->data_transformer_.Transform(cv_img, &(this->transformed_data_)); + } else { + this->data_transformer_.Transform(datum, &(this->transformed_data_)); + } + if (this->output_labels_) { + top_label[item_id] = datum.label(); + } + trans_time += timer.MicroSeconds(); + // go to the next iter + ++iter_; + if (iter_ == dataset_->end()) { + iter_ = dataset_->begin(); + } + } + batch_timer.Stop(); + DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms."; + DLOG(INFO) << " Read time: " << read_time / 1000 << " ms."; + DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms."; +} + +INSTANTIATE_CLASS(DataLayer); +REGISTER_LAYER_CLASS(DATA, DataLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/deconv_layer.cpp b/caffe-crfrnn/src/caffe/layers/deconv_layer.cpp new file mode 100644 index 00000000..80a1d602 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/deconv_layer.cpp @@ -0,0 +1,78 @@ +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void DeconvolutionLayer::compute_output_shape() { + this->height_out_ = this->stride_h_ * (this->height_ - 1) + this->kernel_h_ + - 2 * this->pad_h_; + this->width_out_ = this->stride_w_ * (this->width_ - 1) + this->kernel_w_ + - 2 * this->pad_w_; +} + +template +void DeconvolutionLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* weight = this->blobs_[0]->cpu_data(); + for (int i = 0; i < bottom.size(); ++i) { + const Dtype* bottom_data = bottom[i]->cpu_data(); + Dtype* top_data = top[i]->mutable_cpu_data(); + for (int n = 0; n < this->num_; ++n) { + this->backward_cpu_gemm(bottom_data + bottom[i]->offset(n), weight, + top_data + top[i]->offset(n)); + if 
(this->bias_term_) { + const Dtype* bias = this->blobs_[1]->cpu_data(); + this->forward_cpu_bias(top_data + top[i]->offset(n), bias); + } + } + } +} + +template <typename Dtype> +void DeconvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { + const Dtype* weight = this->blobs_[0]->cpu_data(); + Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff(); + for (int i = 0; i < top.size(); ++i) { + const Dtype* top_diff = top[i]->cpu_diff(); + const Dtype* bottom_data = bottom[i]->cpu_data(); + Dtype* bottom_diff = bottom[i]->mutable_cpu_diff(); + // Bias gradient, if necessary. + if (this->bias_term_ && this->param_propagate_down_[1]) { + Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff(); + for (int n = 0; n < this->num_; ++n) { + this->backward_cpu_bias(bias_diff, top_diff + top[i]->offset(n)); + } + } + if (this->param_propagate_down_[0] || propagate_down[i]) { + for (int n = 0; n < this->num_; ++n) { + // Gradient w.r.t. weight. Note that we will accumulate diffs. + if (this->param_propagate_down_[0]) { + this->weight_cpu_gemm(top_diff + top[i]->offset(n), + bottom_data + bottom[i]->offset(n), weight_diff); + } + // Gradient w.r.t. bottom data, if necessary, reusing the column buffer + // we might have just computed above.
+ if (propagate_down[i]) { + this->forward_cpu_gemm(top_diff + top[i]->offset(n), weight, + bottom_diff + bottom[i]->offset(n), + this->param_propagate_down_[0]); + } + } + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(DeconvolutionLayer); +#endif + +INSTANTIATE_CLASS(DeconvolutionLayer); +REGISTER_LAYER_CLASS(DECONVOLUTION, DeconvolutionLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/deconv_layer.cu b/caffe-crfrnn/src/caffe/layers/deconv_layer.cu new file mode 100644 index 00000000..39bc4de8 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/deconv_layer.cu @@ -0,0 +1,64 @@ +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void DeconvolutionLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* weight = this->blobs_[0]->gpu_data(); + for (int i = 0; i < bottom.size(); ++i) { + const Dtype* bottom_data = bottom[i]->gpu_data(); + Dtype* top_data = top[i]->mutable_gpu_data(); + for (int n = 0; n < this->num_; ++n) { + this->backward_gpu_gemm(bottom_data + bottom[i]->offset(n), weight, + top_data + top[i]->offset(n)); + if (this->bias_term_) { + const Dtype* bias = this->blobs_[1]->gpu_data(); + this->forward_gpu_bias(top_data + top[i]->offset(n), bias); + } + } + } +} + +template +void DeconvolutionLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* weight = this->blobs_[0]->gpu_data(); + Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff(); + for (int i = 0; i < top.size(); ++i) { + const Dtype* top_diff = top[i]->gpu_diff(); + const Dtype* bottom_data = bottom[i]->gpu_data(); + Dtype* bottom_diff = bottom[i]->mutable_gpu_diff(); + // Bias gradient, if necessary. 
+ if (this->bias_term_ && this->param_propagate_down_[1]) { + Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff(); + for (int n = 0; n < this->num_; ++n) { + this->backward_gpu_bias(bias_diff, top_diff + top[i]->offset(n)); + } + } + if (this->param_propagate_down_[0] || propagate_down[i]) { + for (int n = 0; n < this->num_; ++n) { + // gradient w.r.t. weight. Note that we will accumulate diffs. + if (this->param_propagate_down_[0]) { + this->weight_gpu_gemm(top_diff + top[i]->offset(n), + bottom_data + bottom[i]->offset(n), weight_diff); + } + // gradient w.r.t. bottom data, if necessary. + if (propagate_down[i]) { + this->forward_gpu_gemm(top_diff + top[i]->offset(n), weight, + bottom_diff + bottom[i]->offset(n)); + } + } + } + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(DeconvolutionLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/dropout_layer.cpp b/caffe-crfrnn/src/caffe/layers/dropout_layer.cpp new file mode 100644 index 00000000..8c8936a7 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/dropout_layer.cpp @@ -0,0 +1,77 @@ +// TODO (sergeyk): effect should not be dependent on phase. wasted memcpy. + +#include + +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/syncedmem.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void DropoutLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + NeuronLayer::LayerSetUp(bottom, top); + threshold_ = this->layer_param_.dropout_param().dropout_ratio(); + DCHECK(threshold_ > 0.); + DCHECK(threshold_ < 1.); + scale_ = 1. / (1. 
- threshold_); + uint_thres_ = static_cast(UINT_MAX * threshold_); +} + +template +void DropoutLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + NeuronLayer::Reshape(bottom, top); + // Set up the cache for random number generation + rand_vec_.Reshape(bottom[0]->num(), bottom[0]->channels(), + bottom[0]->height(), bottom[0]->width()); +} + +template +void DropoutLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + unsigned int* mask = rand_vec_.mutable_cpu_data(); + const int count = bottom[0]->count(); + if (Caffe::phase() == Caffe::TRAIN) { + // Create random numbers + caffe_rng_bernoulli(count, 1. - threshold_, mask); + for (int i = 0; i < count; ++i) { + top_data[i] = bottom_data[i] * mask[i] * scale_; + } + } else { + caffe_copy(bottom[0]->count(), bottom_data, top_data); + } +} + +template +void DropoutLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + if (propagate_down[0]) { + const Dtype* top_diff = top[0]->cpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + if (Caffe::phase() == Caffe::TRAIN) { + const unsigned int* mask = rand_vec_.cpu_data(); + const int count = bottom[0]->count(); + for (int i = 0; i < count; ++i) { + bottom_diff[i] = top_diff[i] * mask[i] * scale_; + } + } else { + caffe_copy(top[0]->count(), top_diff, bottom_diff); + } + } +} + + +#ifdef CPU_ONLY +STUB_GPU(DropoutLayer); +#endif + +INSTANTIATE_CLASS(DropoutLayer); +REGISTER_LAYER_CLASS(DROPOUT, DropoutLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/dropout_layer.cu b/caffe-crfrnn/src/caffe/layers/dropout_layer.cu new file mode 100644 index 00000000..df13d8ec --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/dropout_layer.cu @@ -0,0 +1,77 @@ +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/syncedmem.hpp" 
+#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + + +template <typename Dtype> +__global__ void DropoutForward(const int n, const Dtype* in, + const unsigned int* mask, const unsigned int threshold, const float scale, + Dtype* out) { + CUDA_KERNEL_LOOP(index, n) { + out[index] = in[index] * (mask[index] > threshold) * scale; + } +} + +template <typename Dtype> +void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + const int count = bottom[0]->count(); + if (Caffe::phase() == Caffe::TRAIN) { + unsigned int* mask = + static_cast<unsigned int*>(rand_vec_.mutable_gpu_data()); + caffe_gpu_rng_uniform(count, mask); + // set thresholds + // NOLINT_NEXT_LINE(whitespace/operators) + DropoutForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>( + count, bottom_data, mask, uint_thres_, scale_, top_data); + CUDA_POST_KERNEL_CHECK; + } else { + caffe_copy(count, bottom_data, top_data); + } +} + +template <typename Dtype> +__global__ void DropoutBackward(const int n, const Dtype* in_diff, + const unsigned int* mask, const unsigned int threshold, const float scale, + Dtype* out_diff) { + CUDA_KERNEL_LOOP(index, n) { + out_diff[index] = in_diff[index] * scale * (mask[index] > threshold); + } +} + +template <typename Dtype> +void DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, + const vector<Blob<Dtype>*>& bottom) { + if (propagate_down[0]) { + const Dtype* top_diff = top[0]->gpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + if (Caffe::phase() == Caffe::TRAIN) { + const unsigned int* mask = + static_cast<const unsigned int*>(rand_vec_.gpu_data()); + const int count = bottom[0]->count(); + // NOLINT_NEXT_LINE(whitespace/operators) + DropoutBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>( + count, top_diff, mask, uint_thres_, scale_, bottom_diff); + CUDA_POST_KERNEL_CHECK; + } else { + caffe_copy(top[0]->count(), top_diff, bottom_diff); + } + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(DropoutLayer); + + +} // namespace caffe diff --git
a/caffe-crfrnn/src/caffe/layers/dummy_data_layer.cpp b/caffe-crfrnn/src/caffe/layers/dummy_data_layer.cpp new file mode 100644 index 00000000..15cf5a58 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/dummy_data_layer.cpp @@ -0,0 +1,97 @@ +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void DummyDataLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + const int num_top = top.size(); + const DummyDataParameter& param = this->layer_param_.dummy_data_param(); + const int num_data_filler = param.data_filler_size(); + CHECK(num_data_filler == 0 || num_data_filler == 1 || + num_data_filler == num_top) + << "Number of data fillers must be 0, 1 or equal to the number of tops: " + << num_top << "; you specified " << num_data_filler << " data fillers."; + CHECK(param.num_size() == 1 || param.num_size() == num_top) + << "Must specify either a single (1) 'num' or one for each top blob " + << "(" << num_top << "); you specified " << param.num_size() << "."; + CHECK(param.channels_size() == 1 || param.channels_size() == num_top) + << "Must specify either a single (1) 'channels' or one for each top blob " + << "(" << num_top << "); you specified " << param.channels_size() << "."; + CHECK(param.height_size() == 1 || param.height_size() == num_top) + << "Must specify either a single (1) 'height' or one for each top blob " + << "(" << num_top << "); you specified " << param.height_size() << "."; + CHECK(param.width_size() == 1 || param.width_size() == num_top) + << "Must specify either a single (1) 'width' or one for each top blob " + << "(" << num_top << "); you specified " << param.width_size() << "."; + // refill_[i] tells Forward i whether or not to actually refill top Blob i. + // If refill_[i] is false, Forward does nothing for Blob i. We use this to + // avoid wastefully refilling "constant" Blobs in every forward pass. 
+ // We first fill refill_ in with the INVERSE of its final values. + // The first time we run Forward from the LayerSetUp method, we'll fill only + // Blobs for which refill_ is normally false. These Blobs will never be + // filled again. + refill_.clear(); + fillers_.clear(); + if (num_data_filler <= 1) { + FillerParameter filler_param; + if (num_data_filler == 0) { + filler_param.set_type("constant"); + filler_param.set_value(0); + } else { + filler_param.CopyFrom(param.data_filler(0)); + } + // Refill on each iteration iff not using a constant filler, + // but use the inverse of this rule for the first run. + refill_.resize(1); + refill_[0] = (strcmp(filler_param.type().c_str(), "constant") == 0); + fillers_.resize(1); + fillers_[0].reset(GetFiller<Dtype>(filler_param)); + } else { + refill_.resize(num_top); + fillers_.resize(num_top); + for (int i = 0; i < num_top; ++i) { + fillers_[i].reset(GetFiller<Dtype>(param.data_filler(i))); + // Refill on each iteration iff not using a constant filler, + // but use the inverse of this rule for the first run. + refill_[i] = + (strcmp(param.data_filler(i).type().c_str(), "constant") == 0); + } + } + for (int i = 0; i < num_top; ++i) { + const int num = (param.num_size() == 1) ? param.num(0) : param.num(i); + const int channels = + (param.channels_size() == 1) ? param.channels(0) : param.channels(i); + const int height = + (param.height_size() == 1) ? param.height(0) : param.height(i); + const int width = + (param.width_size() == 1) ? param.width(0) : param.width(i); + top[i]->Reshape(num, channels, height, width); + } + // Run Forward once, with refill_ inverted, to fill the constant Blobs. + this->Forward(bottom, top); + // Invert the inverted refill_ values to refill the desired (non-constant) + // Blobs in every usual forward pass.
+ for (int i = 0; i < refill_.size(); ++i) { + refill_[i] = !refill_[i]; + } +} + +template +void DummyDataLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + for (int i = 0; i < top.size(); ++i) { + const int filler_id = (fillers_.size() > 1) ? i : 0; + if (refill_[filler_id]) { + fillers_[filler_id]->Fill(top[i]); + } + } +} + +INSTANTIATE_CLASS(DummyDataLayer); +REGISTER_LAYER_CLASS(DUMMY_DATA, DummyDataLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/eltwise_layer.cpp b/caffe-crfrnn/src/caffe/layers/eltwise_layer.cpp new file mode 100644 index 00000000..0c239f4a --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/eltwise_layer.cpp @@ -0,0 +1,167 @@ +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void EltwiseLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + CHECK(this->layer_param().eltwise_param().coeff_size() == 0 + || this->layer_param().eltwise_param().coeff_size() == bottom.size()) << + "Eltwise Layer takes one coefficient per bottom blob."; + CHECK(!(this->layer_param().eltwise_param().operation() + == EltwiseParameter_EltwiseOp_PROD + && this->layer_param().eltwise_param().coeff_size())) << + "Eltwise layer only takes coefficients for summation."; + op_ = this->layer_param_.eltwise_param().operation(); + // Blob-wise coefficients for the elementwise operation. 
+ coeffs_ = vector<Dtype>(bottom.size(), 1); + if (this->layer_param().eltwise_param().coeff_size()) { + for (int i = 0; i < bottom.size(); ++i) { + coeffs_[i] = this->layer_param().eltwise_param().coeff(i); + } + } + stable_prod_grad_ = this->layer_param_.eltwise_param().stable_prod_grad(); +} + +template <typename Dtype> +void EltwiseLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top) { + const int num = bottom[0]->num(); + const int channels = bottom[0]->channels(); + const int height = bottom[0]->height(); + const int width = bottom[0]->width(); + for (int i = 1; i < bottom.size(); ++i) { + CHECK_EQ(num, bottom[i]->num()); + CHECK_EQ(channels, bottom[i]->channels()); + CHECK_EQ(height, bottom[i]->height()); + CHECK_EQ(width, bottom[i]->width()); + } + top[0]->Reshape(num, channels, height, width); + // If max operation, we will initialize the vector index part. + if (this->layer_param_.eltwise_param().operation() == + EltwiseParameter_EltwiseOp_MAX && top.size() == 1) { + max_idx_.Reshape(bottom[0]->num(), channels, height, width); + } +} + +template <typename Dtype> +void EltwiseLayer<Dtype>::Forward_cpu( + const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { + int* mask = NULL; + const Dtype* bottom_data_a = NULL; + const Dtype* bottom_data_b = NULL; + const int count = top[0]->count(); + Dtype* top_data = top[0]->mutable_cpu_data(); + switch (op_) { + case EltwiseParameter_EltwiseOp_PROD: + caffe_mul(count, bottom[0]->cpu_data(), bottom[1]->cpu_data(), top_data); + for (int i = 2; i < bottom.size(); ++i) { + caffe_mul(count, top_data, bottom[i]->cpu_data(), top_data); + } + break; + case EltwiseParameter_EltwiseOp_SUM: + caffe_set(count, Dtype(0), top_data); + // TODO(shelhamer) does BLAS optimize to sum for coeff = 1?
+ for (int i = 0; i < bottom.size(); ++i) { + caffe_axpy(count, coeffs_[i], bottom[i]->cpu_data(), top_data); + } + break; + case EltwiseParameter_EltwiseOp_MAX: + // Initialize + mask = max_idx_.mutable_cpu_data(); + caffe_set(count, -1, mask); + caffe_set(count, Dtype(-FLT_MAX), top_data); + // bottom 0 & 1 + bottom_data_a = bottom[0]->cpu_data(); + bottom_data_b = bottom[1]->cpu_data(); + for (int idx = 0; idx < count; ++idx) { + if (bottom_data_a[idx] > bottom_data_b[idx]) { + top_data[idx] = bottom_data_a[idx]; // maxval + mask[idx] = 0; // maxid + } else { + top_data[idx] = bottom_data_b[idx]; // maxval + mask[idx] = 1; // maxid + } + } + // bottom 2++ + for (int blob_idx = 2; blob_idx < bottom.size(); ++blob_idx) { + bottom_data_b = bottom[blob_idx]->cpu_data(); + for (int idx = 0; idx < count; ++idx) { + if (bottom_data_b[idx] > top_data[idx]) { + top_data[idx] = bottom_data_b[idx]; // maxval + mask[idx] = blob_idx; // maxid + } + } + } + break; + default: + LOG(FATAL) << "Unknown elementwise operation."; + } +} + +template <typename Dtype> +void EltwiseLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { + const int* mask = NULL; + const int count = top[0]->count(); + const Dtype* top_data = top[0]->cpu_data(); + const Dtype* top_diff = top[0]->cpu_diff(); + for (int i = 0; i < bottom.size(); ++i) { + if (propagate_down[i]) { + const Dtype* bottom_data = bottom[i]->cpu_data(); + Dtype* bottom_diff = bottom[i]->mutable_cpu_diff(); + switch (op_) { + case EltwiseParameter_EltwiseOp_PROD: + if (stable_prod_grad_) { + bool initialized = false; + for (int j = 0; j < bottom.size(); ++j) { + if (i == j) { continue; } + if (!initialized) { + caffe_copy(count, bottom[j]->cpu_data(), bottom_diff); + initialized = true; + } else { + caffe_mul(count, bottom[j]->cpu_data(), bottom_diff, + bottom_diff); + } + } + } else { + caffe_div(count, top_data, bottom_data, bottom_diff); + } + caffe_mul(count, bottom_diff, top_diff, bottom_diff); +
break; + case EltwiseParameter_EltwiseOp_SUM: + if (coeffs_[i] == Dtype(1)) { + caffe_copy(count, top_diff, bottom_diff); + } else { + caffe_cpu_scale(count, coeffs_[i], top_diff, bottom_diff); + } + break; + case EltwiseParameter_EltwiseOp_MAX: + mask = max_idx_.cpu_data(); + for (int index = 0; index < count; ++index) { + Dtype gradient = 0; + if (mask[index] == i) { + gradient += top_diff[index]; + } + bottom_diff[index] = gradient; + } + break; + default: + LOG(FATAL) << "Unknown elementwise operation."; + } + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(EltwiseLayer); +#endif + +INSTANTIATE_CLASS(EltwiseLayer); +REGISTER_LAYER_CLASS(ELTWISE, EltwiseLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/eltwise_layer.cu b/caffe-crfrnn/src/caffe/layers/eltwise_layer.cu new file mode 100644 index 00000000..2247870d --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/eltwise_layer.cu @@ -0,0 +1,135 @@ +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +__global__ void MaxForward(const int nthreads, const Dtype* bottom_data_a, + const Dtype* bottom_data_b, const int blob_idx, Dtype* top_data, + int* mask) { + CUDA_KERNEL_LOOP(index, nthreads) { + Dtype maxval = -FLT_MAX; + int maxidx = -1; + if (bottom_data_a[index] > bottom_data_b[index]) { + // only update for very first bottom_data blob (blob_idx == 0) + if (blob_idx == 0) { + maxval = bottom_data_a[index]; + top_data[index] = maxval; + maxidx = blob_idx; + mask[index] = maxidx; + } + } else { + maxval = bottom_data_b[index]; + top_data[index] = maxval; + maxidx = blob_idx + 1; + mask[index] = maxidx; + } + } +} + +template +void EltwiseLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + int* mask = NULL; + const int count = top[0]->count(); + Dtype* top_data = top[0]->mutable_gpu_data(); + switch (op_) { + case EltwiseParameter_EltwiseOp_PROD: + caffe_gpu_mul(count, 
bottom[0]->gpu_data(), bottom[1]->gpu_data(), + top_data); + for (int i = 2; i < bottom.size(); ++i) { + caffe_gpu_mul(count, top_data, bottom[i]->gpu_data(), top_data); + } + break; + case EltwiseParameter_EltwiseOp_SUM: + caffe_gpu_set(count, Dtype(0.), top_data); + // TODO(shelhamer) does cuBLAS optimize to sum for coeff = 1? + for (int i = 0; i < bottom.size(); ++i) { + caffe_gpu_axpy(count, coeffs_[i], bottom[i]->gpu_data(), top_data); + } + break; + case EltwiseParameter_EltwiseOp_MAX: + mask = max_idx_.mutable_gpu_data(); + // NOLINT_NEXT_LINE(whitespace/operators) + MaxForward <<>>( + count, bottom[0]->gpu_data(), bottom[1]->gpu_data(), 0, top_data, mask); + for (int i = 2; i < bottom.size(); ++i) { + // NOLINT_NEXT_LINE(whitespace/operators) + MaxForward<<>>( + count, top_data, bottom[i]->gpu_data(), i-1, top_data, mask); + } + break; + default: + LOG(FATAL) << "Unknown elementwise operation."; + } +} + +template +__global__ void MaxBackward(const int nthreads, const Dtype* top_diff, + const int blob_idx, const int* mask, Dtype* bottom_diff) { + CUDA_KERNEL_LOOP(index, nthreads) { + Dtype gradient = 0; + if (mask[index] == blob_idx) { + gradient += top_diff[index]; + } + bottom_diff[index] = gradient; + } +} + +template +void EltwiseLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const int* mask = NULL; + const int count = top[0]->count(); + const Dtype* top_data = top[0]->gpu_data(); + const Dtype* top_diff = top[0]->gpu_diff(); + for (int i = 0; i < bottom.size(); ++i) { + if (propagate_down[i]) { + const Dtype* bottom_data = bottom[i]->gpu_data(); + Dtype* bottom_diff = bottom[i]->mutable_gpu_diff(); + switch (op_) { + case EltwiseParameter_EltwiseOp_PROD: + if (stable_prod_grad_) { + bool initialized = false; + for (int j = 0; j < bottom.size(); ++j) { + if (i == j) { continue; } + if (!initialized) { + caffe_copy(count, bottom[j]->gpu_data(), bottom_diff); + initialized = true; + } else { + 
caffe_gpu_mul(count, bottom[j]->gpu_data(), bottom_diff, + bottom_diff); + } + } + } else { + caffe_gpu_div(count, top_data, bottom_data, bottom_diff); + } + caffe_gpu_mul(count, bottom_diff, top_diff, bottom_diff); + break; + case EltwiseParameter_EltwiseOp_SUM: + if (coeffs_[i] == Dtype(1.)) { + caffe_copy(count, top_diff, bottom_diff); + } else { + caffe_gpu_scale(count, coeffs_[i], top_diff, bottom_diff); + } + break; + case EltwiseParameter_EltwiseOp_MAX: + mask = max_idx_.gpu_data(); + MaxBackward // NOLINT_NEXT_LINE(whitespace/operators) + <<>>( + count, top_diff, i, mask, bottom_diff); + break; + default: + LOG(FATAL) << "Unknown elementwise operation."; + } + } + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(EltwiseLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/euclidean_loss_layer.cpp b/caffe-crfrnn/src/caffe/layers/euclidean_loss_layer.cpp new file mode 100644 index 00000000..d965027f --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/euclidean_loss_layer.cpp @@ -0,0 +1,58 @@ +#include + +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void EuclideanLossLayer::Reshape( + const vector*>& bottom, const vector*>& top) { + LossLayer::Reshape(bottom, top); + CHECK_EQ(bottom[0]->channels(), bottom[1]->channels()); + CHECK_EQ(bottom[0]->height(), bottom[1]->height()); + CHECK_EQ(bottom[0]->width(), bottom[1]->width()); + diff_.Reshape(bottom[0]->num(), bottom[0]->channels(), + bottom[0]->height(), bottom[0]->width()); +} + +template +void EuclideanLossLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + int count = bottom[0]->count(); + caffe_sub( + count, + bottom[0]->cpu_data(), + bottom[1]->cpu_data(), + diff_.mutable_cpu_data()); + Dtype dot = caffe_cpu_dot(count, diff_.cpu_data(), diff_.cpu_data()); + Dtype loss = dot / bottom[0]->num() / Dtype(2); + top[0]->mutable_cpu_data()[0] = loss; +} + +template 
+void EuclideanLossLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + for (int i = 0; i < 2; ++i) { + if (propagate_down[i]) { + const Dtype sign = (i == 0) ? 1 : -1; + const Dtype alpha = sign * top[0]->cpu_diff()[0] / bottom[i]->num(); + caffe_cpu_axpby( + bottom[i]->count(), // count + alpha, // alpha + diff_.cpu_data(), // a + Dtype(0), // beta + bottom[i]->mutable_cpu_diff()); // b + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(EuclideanLossLayer); +#endif + +INSTANTIATE_CLASS(EuclideanLossLayer); +REGISTER_LAYER_CLASS(EUCLIDEAN_LOSS, EuclideanLossLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/euclidean_loss_layer.cu b/caffe-crfrnn/src/caffe/layers/euclidean_loss_layer.cu new file mode 100644 index 00000000..5b1de3ad --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/euclidean_loss_layer.cu @@ -0,0 +1,44 @@ +#include + +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void EuclideanLossLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + int count = bottom[0]->count(); + caffe_gpu_sub( + count, + bottom[0]->gpu_data(), + bottom[1]->gpu_data(), + diff_.mutable_gpu_data()); + Dtype dot; + caffe_gpu_dot(count, diff_.gpu_data(), diff_.gpu_data(), &dot); + Dtype loss = dot / bottom[0]->num() / Dtype(2); + top[0]->mutable_cpu_data()[0] = loss; +} + +template +void EuclideanLossLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + for (int i = 0; i < 2; ++i) { + if (propagate_down[i]) { + const Dtype sign = (i == 0) ? 
1 : -1; + const Dtype alpha = sign * top[0]->cpu_diff()[0] / bottom[i]->num(); + caffe_gpu_axpby( + bottom[i]->count(), // count + alpha, // alpha + diff_.gpu_data(), // a + Dtype(0), // beta + bottom[i]->mutable_gpu_diff()); // b + } + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(EuclideanLossLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/exp_layer.cpp b/caffe-crfrnn/src/caffe/layers/exp_layer.cpp new file mode 100644 index 00000000..92cb5deb --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/exp_layer.cpp @@ -0,0 +1,68 @@ +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void ExpLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + NeuronLayer::LayerSetUp(bottom, top); + const Dtype base = this->layer_param_.exp_param().base(); + if (base != Dtype(-1)) { + CHECK_GT(base, 0) << "base must be strictly positive."; + } + // If base == -1, interpret the base as e and set log_base = 1 exactly. + // Otherwise, calculate its log explicitly. + const Dtype log_base = (base == Dtype(-1)) ? Dtype(1) : log(base); + CHECK(!isnan(log_base)) + << "NaN result: log(base) = log(" << base << ") = " << log_base; + CHECK(!isinf(log_base)) + << "Inf result: log(base) = log(" << base << ") = " << log_base; + const Dtype input_scale = this->layer_param_.exp_param().scale(); + const Dtype input_shift = this->layer_param_.exp_param().shift(); + inner_scale_ = log_base * input_scale; + outer_scale_ = (input_shift == Dtype(0)) ? 
Dtype(1) : pow(base, input_shift); +} + +template +void ExpLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const int count = bottom[0]->count(); + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + if (inner_scale_ == Dtype(1)) { + caffe_exp(count, bottom_data, top_data); + } else { + caffe_cpu_scale(count, inner_scale_, bottom_data, top_data); + caffe_exp(count, top_data, top_data); + } + if (outer_scale_ != Dtype(1)) { + caffe_scal(count, outer_scale_, top_data); + } +} + +template +void ExpLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (!propagate_down[0]) { return; } + const int count = bottom[0]->count(); + const Dtype* top_data = top[0]->cpu_data(); + const Dtype* top_diff = top[0]->cpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + caffe_mul(count, top_data, top_diff, bottom_diff); + if (inner_scale_ != Dtype(1)) { + caffe_scal(count, inner_scale_, bottom_diff); + } +} + +#ifdef CPU_ONLY +STUB_GPU(ExpLayer); +#endif + +INSTANTIATE_CLASS(ExpLayer); +REGISTER_LAYER_CLASS(EXP, ExpLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/exp_layer.cu b/caffe-crfrnn/src/caffe/layers/exp_layer.cu new file mode 100644 index 00000000..2d75d8dd --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/exp_layer.cu @@ -0,0 +1,44 @@ +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void ExpLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const int count = bottom[0]->count(); + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + if (inner_scale_ == Dtype(1)) { + caffe_gpu_exp(count, bottom_data, top_data); + } else { + caffe_gpu_scale(count, inner_scale_, bottom_data, top_data); + caffe_gpu_exp(count, top_data, top_data); + } + if 
(outer_scale_ != Dtype(1)) { + caffe_gpu_scal(count, outer_scale_, top_data); + } +} + +template +void ExpLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (!propagate_down[0]) { return; } + const int count = bottom[0]->count(); + const Dtype* top_data = top[0]->gpu_data(); + const Dtype* top_diff = top[0]->gpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + caffe_gpu_mul(count, top_data, top_diff, bottom_diff); + if (inner_scale_ != Dtype(1)) { + caffe_gpu_scal(count, inner_scale_, bottom_diff); + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(ExpLayer); + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/flatten_layer.cpp b/caffe-crfrnn/src/caffe/layers/flatten_layer.cpp new file mode 100644 index 00000000..ec43caba --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/flatten_layer.cpp @@ -0,0 +1,38 @@ +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void FlattenLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + int channels_out = bottom[0]->channels() * bottom[0]->height() + * bottom[0]->width(); + top[0]->Reshape(bottom[0]->num(), channels_out, 1, 1); + count_ = bottom[0]->num() * channels_out; + CHECK_EQ(count_, bottom[0]->count()); + CHECK_EQ(count_, top[0]->count()); +} + +template +void FlattenLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + top[0]->ShareData(*bottom[0]); +} + +template +void FlattenLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + bottom[0]->ShareDiff(*top[0]); +} + +#ifdef CPU_ONLY +STUB_GPU(FlattenLayer); +#endif + +INSTANTIATE_CLASS(FlattenLayer); +REGISTER_LAYER_CLASS(FLATTEN, FlattenLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/flatten_layer.cu b/caffe-crfrnn/src/caffe/layers/flatten_layer.cu new file mode 100644 index 00000000..42abdad4 --- 
/dev/null +++ b/caffe-crfrnn/src/caffe/layers/flatten_layer.cu @@ -0,0 +1,23 @@ +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void FlattenLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + top[0]->ShareData(*bottom[0]); +} + +template +void FlattenLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + bottom[0]->ShareDiff(*top[0]); +} + +INSTANTIATE_LAYER_GPU_FUNCS(FlattenLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/hdf5_data_layer.cpp b/caffe-crfrnn/src/caffe/layers/hdf5_data_layer.cpp new file mode 100644 index 00000000..706c4a18 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/hdf5_data_layer.cpp @@ -0,0 +1,125 @@ +/* +TODO: +- load file in a separate thread ("prefetch") +- can be smarter about the memcpy call instead of doing it row-by-row + :: use util functions caffe_copy, and Blob->offset() + :: don't forget to update hdf5_data_layer.cu accordingly +- add ability to shuffle filenames if flag is set +*/ +#include // NOLINT(readability/streams) +#include +#include + +#include "hdf5.h" +#include "hdf5_hl.h" +#include "stdint.h" + +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +HDF5DataLayer::~HDF5DataLayer() { } + +// Load data and label from HDF5 filename into the class property blobs. 
+template +void HDF5DataLayer::LoadHDF5FileData(const char* filename) { + DLOG(INFO) << "Loading HDF5 file: " << filename; + hid_t file_id = H5Fopen(filename, H5F_ACC_RDONLY, H5P_DEFAULT); + if (file_id < 0) { + LOG(FATAL) << "Failed opening HDF5 file: " << filename; + } + + int top_size = this->layer_param_.top_size(); + hdf_blobs_.resize(top_size); + + const int MIN_DATA_DIM = 1; + const int MAX_DATA_DIM = 4; + + for (int i = 0; i < top_size; ++i) { + hdf_blobs_[i] = shared_ptr >(new Blob()); + hdf5_load_nd_dataset(file_id, this->layer_param_.top(i).c_str(), + MIN_DATA_DIM, MAX_DATA_DIM, hdf_blobs_[i].get()); + } + + herr_t status = H5Fclose(file_id); + CHECK_GE(status, 0) << "Failed to close HDF5 file: " << filename; + + // MinTopBlobs==1 guarantees at least one top blob + int num = hdf_blobs_[0]->num(); + for (int i = 1; i < top_size; ++i) { + CHECK_EQ(hdf_blobs_[i]->num(), num); + } + DLOG(INFO) << "Successfully loaded " << hdf_blobs_[0]->num() << " rows"; +} + +template +void HDF5DataLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + // Read the source to parse the filenames. + const string& source = this->layer_param_.hdf5_data_param().source(); + LOG(INFO) << "Loading list of HDF5 filenames from: " << source; + hdf_filenames_.clear(); + std::ifstream source_file(source.c_str()); + if (source_file.is_open()) { + std::string line; + while (source_file >> line) { + hdf_filenames_.push_back(line); + } + } else { + LOG(FATAL) << "Failed to open source file: " << source; + } + source_file.close(); + num_files_ = hdf_filenames_.size(); + current_file_ = 0; + LOG(INFO) << "Number of HDF5 files: " << num_files_; + CHECK_GE(num_files_, 1) << "Must have at least 1 HDF5 filename listed in " + << source; + + // Load the first HDF5 file and initialize the line counter. + LoadHDF5FileData(hdf_filenames_[current_file_].c_str()); + current_row_ = 0; + + // Reshape blobs. 
+ const int batch_size = this->layer_param_.hdf5_data_param().batch_size(); + const int top_size = this->layer_param_.top_size(); + for (int i = 0; i < top_size; ++i) { + top[i]->Reshape(batch_size, hdf_blobs_[i]->channels(), + hdf_blobs_[i]->height(), hdf_blobs_[i]->width()); + } +} + +template +void HDF5DataLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const int batch_size = this->layer_param_.hdf5_data_param().batch_size(); + for (int i = 0; i < batch_size; ++i, ++current_row_) { + if (current_row_ == hdf_blobs_[0]->num()) { + if (num_files_ > 1) { + ++current_file_; + if (current_file_ == num_files_) { + current_file_ = 0; + DLOG(INFO) << "Looping around to first file."; + } + LoadHDF5FileData(hdf_filenames_[current_file_].c_str()); + } + current_row_ = 0; + } + for (int j = 0; j < this->layer_param_.top_size(); ++j) { + int data_dim = top[j]->count() / top[j]->num(); + caffe_copy(data_dim, + &hdf_blobs_[j]->cpu_data()[current_row_ * data_dim], + &top[j]->mutable_cpu_data()[i * data_dim]); + } + } +} + +#ifdef CPU_ONLY +STUB_GPU_FORWARD(HDF5DataLayer, Forward); +#endif + +INSTANTIATE_CLASS(HDF5DataLayer); +REGISTER_LAYER_CLASS(HDF5_DATA, HDF5DataLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/hdf5_data_layer.cu b/caffe-crfrnn/src/caffe/layers/hdf5_data_layer.cu new file mode 100644 index 00000000..02e3821d --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/hdf5_data_layer.cu @@ -0,0 +1,46 @@ +/* +TODO: +- only load parts of the file, in accordance with a prototxt param "max_mem" +*/ + +#include +#include +#include + +#include "hdf5.h" +#include "hdf5_hl.h" + +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void HDF5DataLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const int batch_size = this->layer_param_.hdf5_data_param().batch_size(); + for (int i = 0; i < batch_size; ++i, ++current_row_) { + if (current_row_ 
== hdf_blobs_[0]->num()) { + if (num_files_ > 1) { + current_file_ += 1; + if (current_file_ == num_files_) { + current_file_ = 0; + DLOG(INFO) << "Looping around to first file."; + } + LoadHDF5FileData(hdf_filenames_[current_file_].c_str()); + } + current_row_ = 0; + } + for (int j = 0; j < this->layer_param_.top_size(); ++j) { + int data_dim = top[j]->count() / top[j]->num(); + caffe_copy(data_dim, + &hdf_blobs_[j]->cpu_data()[current_row_ * data_dim], + &top[j]->mutable_gpu_data()[i * data_dim]); + } + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(HDF5DataLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/hdf5_output_layer.cpp b/caffe-crfrnn/src/caffe/layers/hdf5_output_layer.cpp new file mode 100644 index 00000000..4a72a18a --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/hdf5_output_layer.cpp @@ -0,0 +1,74 @@ +#include + +#include "hdf5.h" +#include "hdf5_hl.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +HDF5OutputLayer::HDF5OutputLayer(const LayerParameter& param) + : Layer(param), + file_name_(param.hdf5_output_param().file_name()) { + /* create an HDF5 file */ + file_id_ = H5Fcreate(file_name_.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, + H5P_DEFAULT); + CHECK_GE(file_id_, 0) << "Failed to open HDF5 file " << file_name_; +} + +template +HDF5OutputLayer::~HDF5OutputLayer() { + herr_t status = H5Fclose(file_id_); + CHECK_GE(status, 0) << "Failed to close HDF5 file " << file_name_; +} + +template +void HDF5OutputLayer::SaveBlobs() { + // TODO: no limit on the number of blobs + LOG(INFO) << "Saving HDF5 file " << file_name_; + CHECK_EQ(data_blob_.num(), label_blob_.num()) << + "data blob and label blob must have the same batch size"; + hdf5_save_nd_dataset(file_id_, HDF5_DATA_DATASET_NAME, data_blob_); + hdf5_save_nd_dataset(file_id_, HDF5_DATA_LABEL_NAME, label_blob_); + LOG(INFO) << "Successfully saved " << 
data_blob_.num() << " rows"; +} + +template +void HDF5OutputLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + CHECK_GE(bottom.size(), 2); + CHECK_EQ(bottom[0]->num(), bottom[1]->num()); + data_blob_.Reshape(bottom[0]->num(), bottom[0]->channels(), + bottom[0]->height(), bottom[0]->width()); + label_blob_.Reshape(bottom[1]->num(), bottom[1]->channels(), + bottom[1]->height(), bottom[1]->width()); + const int data_datum_dim = bottom[0]->count() / bottom[0]->num(); + const int label_datum_dim = bottom[1]->count() / bottom[1]->num(); + + for (int i = 0; i < bottom[0]->num(); ++i) { + caffe_copy(data_datum_dim, &bottom[0]->cpu_data()[i * data_datum_dim], + &data_blob_.mutable_cpu_data()[i * data_datum_dim]); + caffe_copy(label_datum_dim, &bottom[1]->cpu_data()[i * label_datum_dim], + &label_blob_.mutable_cpu_data()[i * label_datum_dim]); + } + SaveBlobs(); +} + +template +void HDF5OutputLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + return; +} + +#ifdef CPU_ONLY +STUB_GPU(HDF5OutputLayer); +#endif + +INSTANTIATE_CLASS(HDF5OutputLayer); +REGISTER_LAYER_CLASS(HDF5_OUTPUT, HDF5OutputLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/hdf5_output_layer.cu b/caffe-crfrnn/src/caffe/layers/hdf5_output_layer.cu new file mode 100644 index 00000000..ae497c34 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/hdf5_output_layer.cu @@ -0,0 +1,43 @@ +#include + +#include "hdf5.h" +#include "hdf5_hl.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void HDF5OutputLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + CHECK_GE(bottom.size(), 2); + CHECK_EQ(bottom[0]->num(), bottom[1]->num()); + data_blob_.Reshape(bottom[0]->num(), bottom[0]->channels(), + bottom[0]->height(), bottom[0]->width()); + label_blob_.Reshape(bottom[1]->num(), 
bottom[1]->channels(), + bottom[1]->height(), bottom[1]->width()); + const int data_datum_dim = bottom[0]->count() / bottom[0]->num(); + const int label_datum_dim = bottom[1]->count() / bottom[1]->num(); + + for (int i = 0; i < bottom[0]->num(); ++i) { + caffe_copy(data_datum_dim, &bottom[0]->gpu_data()[i * data_datum_dim], + &data_blob_.mutable_cpu_data()[i * data_datum_dim]); + caffe_copy(label_datum_dim, &bottom[1]->gpu_data()[i * label_datum_dim], + &label_blob_.mutable_cpu_data()[i * label_datum_dim]); + } + SaveBlobs(); +} + +template +void HDF5OutputLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + return; +} + +INSTANTIATE_LAYER_GPU_FUNCS(HDF5OutputLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/hinge_loss_layer.cpp b/caffe-crfrnn/src/caffe/layers/hinge_loss_layer.cpp new file mode 100644 index 00000000..4dfafcc8 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/hinge_loss_layer.cpp @@ -0,0 +1,81 @@ +#include +#include +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void HingeLossLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + const Dtype* label = bottom[1]->cpu_data(); + int num = bottom[0]->num(); + int count = bottom[0]->count(); + int dim = count / num; + + caffe_copy(count, bottom_data, bottom_diff); + for (int i = 0; i < num; ++i) { + bottom_diff[i * dim + static_cast(label[i])] *= -1; + } + for (int i = 0; i < num; ++i) { + for (int j = 0; j < dim; ++j) { + bottom_diff[i * dim + j] = std::max( + Dtype(0), 1 + bottom_diff[i * dim + j]); + } + } + Dtype* loss = top[0]->mutable_cpu_data(); + switch (this->layer_param_.hinge_loss_param().norm()) { + case HingeLossParameter_Norm_L1: + loss[0] = 
caffe_cpu_asum(count, bottom_diff) / num; + break; + case HingeLossParameter_Norm_L2: + loss[0] = caffe_cpu_dot(count, bottom_diff, bottom_diff) / num; + break; + default: + LOG(FATAL) << "Unknown Norm"; + } +} + +template +void HingeLossLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (propagate_down[1]) { + LOG(FATAL) << this->type_name() + << " Layer cannot backpropagate to label inputs."; + } + if (propagate_down[0]) { + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + const Dtype* label = bottom[1]->cpu_data(); + int num = bottom[0]->num(); + int count = bottom[0]->count(); + int dim = count / num; + + for (int i = 0; i < num; ++i) { + bottom_diff[i * dim + static_cast(label[i])] *= -1; + } + + const Dtype loss_weight = top[0]->cpu_diff()[0]; + switch (this->layer_param_.hinge_loss_param().norm()) { + case HingeLossParameter_Norm_L1: + caffe_cpu_sign(count, bottom_diff, bottom_diff); + caffe_scal(count, loss_weight / num, bottom_diff); + break; + case HingeLossParameter_Norm_L2: + caffe_scal(count, loss_weight * 2 / num, bottom_diff); + break; + default: + LOG(FATAL) << "Unknown Norm"; + } + } +} + +INSTANTIATE_CLASS(HingeLossLayer); +REGISTER_LAYER_CLASS(HINGE_LOSS, HingeLossLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/im2col_layer.cpp b/caffe-crfrnn/src/caffe/layers/im2col_layer.cpp new file mode 100644 index 00000000..2c4bb902 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/im2col_layer.cpp @@ -0,0 +1,92 @@ +#include + +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void Im2colLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + ConvolutionParameter conv_param = this->layer_param_.convolution_param(); + CHECK(!conv_param.has_kernel_size() != + !(conv_param.has_kernel_h() && conv_param.has_kernel_w())) + << "Filter size is kernel_size OR 
kernel_h and kernel_w; not both"; + CHECK(conv_param.has_kernel_size() || + (conv_param.has_kernel_h() && conv_param.has_kernel_w())) + << "For non-square filters both kernel_h and kernel_w are required."; + CHECK((!conv_param.has_pad() && conv_param.has_pad_h() + && conv_param.has_pad_w()) + || (!conv_param.has_pad_h() && !conv_param.has_pad_w())) + << "pad is pad OR pad_h and pad_w are required."; + CHECK((!conv_param.has_stride() && conv_param.has_stride_h() + && conv_param.has_stride_w()) + || (!conv_param.has_stride_h() && !conv_param.has_stride_w())) + << "Stride is stride OR stride_h and stride_w are required."; + if (conv_param.has_kernel_size()) { + kernel_h_ = kernel_w_ = conv_param.kernel_size(); + } else { + kernel_h_ = conv_param.kernel_h(); + kernel_w_ = conv_param.kernel_w(); + } + CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; + CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; + if (!conv_param.has_pad_h()) { + pad_h_ = pad_w_ = conv_param.pad(); + } else { + pad_h_ = conv_param.pad_h(); + pad_w_ = conv_param.pad_w(); + } + if (!conv_param.has_stride_h()) { + stride_h_ = stride_w_ = conv_param.stride(); + } else { + stride_h_ = conv_param.stride_h(); + stride_w_ = conv_param.stride_w(); + } +} + +template +void Im2colLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + channels_ = bottom[0]->channels(); + height_ = bottom[0]->height(); + width_ = bottom[0]->width(); + top[0]->Reshape( + bottom[0]->num(), channels_ * kernel_h_ * kernel_w_, + (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1, + (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1); +} + +template +void Im2colLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + for (int n = 0; n < bottom[0]->num(); ++n) { + im2col_cpu(bottom_data + bottom[0]->offset(n), channels_, height_, + width_, kernel_h_, kernel_w_, pad_h_, pad_w_, + 
stride_h_, stride_w_, top_data + top[0]->offset(n)); + } +} + +template +void Im2colLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* top_diff = top[0]->cpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + for (int n = 0; n < top[0]->num(); ++n) { + col2im_cpu(top_diff + top[0]->offset(n), channels_, height_, width_, + kernel_h_, kernel_w_, pad_h_, pad_w_, + stride_h_, stride_w_, bottom_diff + bottom[0]->offset(n)); + } +} + +#ifdef CPU_ONLY +STUB_GPU(Im2colLayer); +#endif + +INSTANTIATE_CLASS(Im2colLayer); +REGISTER_LAYER_CLASS(IM2COL, Im2colLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/im2col_layer.cu b/caffe-crfrnn/src/caffe/layers/im2col_layer.cu new file mode 100644 index 00000000..9c338b14 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/im2col_layer.cu @@ -0,0 +1,37 @@ +#include + +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void Im2colLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + for (int n = 0; n < bottom[0]->num(); ++n) { + im2col_gpu(bottom_data + bottom[0]->offset(n), channels_, height_, + width_, kernel_h_, kernel_w_, pad_h_, pad_w_, + stride_h_, stride_w_, top_data + top[0]->offset(n)); + } +} + +template +void Im2colLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* top_diff = top[0]->gpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + for (int n = 0; n < top[0]->num(); ++n) { + col2im_gpu(top_diff + top[0]->offset(n), channels_, height_, width_, + kernel_h_, kernel_w_, pad_h_, pad_w_, + stride_h_, stride_w_, bottom_diff + bottom[0]->offset(n)); + } +} + + +INSTANTIATE_LAYER_GPU_FUNCS(Im2colLayer); + +} // namespace caffe diff --git 
a/caffe-crfrnn/src/caffe/layers/image_data_layer.cpp b/caffe-crfrnn/src/caffe/layers/image_data_layer.cpp new file mode 100644 index 00000000..50997a23 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/image_data_layer.cpp @@ -0,0 +1,151 @@ +#include // NOLINT(readability/streams) +#include // NOLINT(readability/streams) +#include +#include +#include + +#include "caffe/data_layers.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/benchmark.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/util/rng.hpp" + +namespace caffe { + +template +ImageDataLayer::~ImageDataLayer() { + this->JoinPrefetchThread(); +} + +template +void ImageDataLayer::DataLayerSetUp(const vector*>& bottom, + const vector*>& top) { + const int new_height = this->layer_param_.image_data_param().new_height(); + const int new_width = this->layer_param_.image_data_param().new_width(); + const bool is_color = this->layer_param_.image_data_param().is_color(); + string root_folder = this->layer_param_.image_data_param().root_folder(); + + CHECK((new_height == 0 && new_width == 0) || + (new_height > 0 && new_width > 0)) << "Current implementation requires " + "new_height and new_width to be set at the same time."; + // Read the file with filenames and labels + const string& source = this->layer_param_.image_data_param().source(); + LOG(INFO) << "Opening file " << source; + std::ifstream infile(source.c_str()); + string filename; + int label; + while (infile >> filename >> label) { + lines_.push_back(std::make_pair(filename, label)); + } + + if (this->layer_param_.image_data_param().shuffle()) { + // randomly shuffle data + LOG(INFO) << "Shuffling data"; + const unsigned int prefetch_rng_seed = caffe_rng_rand(); + prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed)); + ShuffleImages(); + } + LOG(INFO) << "A total of " << lines_.size() << " images."; + + lines_id_ = 0; + // Check if we would need to randomly skip a few data points + if 
(this->layer_param_.image_data_param().rand_skip()) { + unsigned int skip = caffe_rng_rand() % + this->layer_param_.image_data_param().rand_skip(); + LOG(INFO) << "Skipping first " << skip << " data points."; + CHECK_GT(lines_.size(), skip) << "Not enough points to skip"; + lines_id_ = skip; + } + // Read an image, and use it to initialize the top blob. + cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first, + new_height, new_width, is_color); + const int channels = cv_img.channels(); + const int height = cv_img.rows; + const int width = cv_img.cols; + // image + const int crop_size = this->layer_param_.transform_param().crop_size(); + const int batch_size = this->layer_param_.image_data_param().batch_size(); + if (crop_size > 0) { + top[0]->Reshape(batch_size, channels, crop_size, crop_size); + this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size); + this->transformed_data_.Reshape(1, channels, crop_size, crop_size); + } else { + top[0]->Reshape(batch_size, channels, height, width); + this->prefetch_data_.Reshape(batch_size, channels, height, width); + this->transformed_data_.Reshape(1, channels, height, width); + } + LOG(INFO) << "output data size: " << top[0]->num() << "," + << top[0]->channels() << "," << top[0]->height() << "," + << top[0]->width(); + // label + top[1]->Reshape(batch_size, 1, 1, 1); + this->prefetch_label_.Reshape(batch_size, 1, 1, 1); +} + +template +void ImageDataLayer::ShuffleImages() { + caffe::rng_t* prefetch_rng = + static_cast(prefetch_rng_->generator()); + shuffle(lines_.begin(), lines_.end(), prefetch_rng); +} + +// This function is used to create a thread that prefetches the data. 
+template +void ImageDataLayer::InternalThreadEntry() { + CPUTimer batch_timer; + batch_timer.Start(); + double read_time = 0; + double trans_time = 0; + CPUTimer timer; + CHECK(this->prefetch_data_.count()); + CHECK(this->transformed_data_.count()); + Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); + Dtype* top_label = this->prefetch_label_.mutable_cpu_data(); + ImageDataParameter image_data_param = this->layer_param_.image_data_param(); + const int batch_size = image_data_param.batch_size(); + const int new_height = image_data_param.new_height(); + const int new_width = image_data_param.new_width(); + const bool is_color = image_data_param.is_color(); + string root_folder = image_data_param.root_folder(); + + // datum scales + const int lines_size = lines_.size(); + for (int item_id = 0; item_id < batch_size; ++item_id) { + // get a blob + timer.Start(); + CHECK_GT(lines_size, lines_id_); + cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first, + new_height, new_width, is_color); + if (!cv_img.data) { + continue; + } + read_time += timer.MicroSeconds(); + timer.Start(); + // Apply transformations (mirror, crop...) to the image + int offset = this->prefetch_data_.offset(item_id); + this->transformed_data_.set_cpu_data(top_data + offset); + this->data_transformer_.Transform(cv_img, &(this->transformed_data_)); + trans_time += timer.MicroSeconds(); + + top_label[item_id] = lines_[lines_id_].second; + // go to the next iter + lines_id_++; + if (lines_id_ >= lines_size) { + // We have reached the end. Restart from the first. 
+ DLOG(INFO) << "Restarting data prefetching from start."; + lines_id_ = 0; + if (this->layer_param_.image_data_param().shuffle()) { + ShuffleImages(); + } + } + } + batch_timer.Stop(); + DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms."; + DLOG(INFO) << " Read time: " << read_time / 1000 << " ms."; + DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms."; +} + +INSTANTIATE_CLASS(ImageDataLayer); +REGISTER_LAYER_CLASS(IMAGE_DATA, ImageDataLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/infogain_loss_layer.cpp b/caffe-crfrnn/src/caffe/layers/infogain_loss_layer.cpp new file mode 100644 index 00000000..8910431d --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/infogain_loss_layer.cpp @@ -0,0 +1,110 @@ +#include +#include +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void InfogainLossLayer::LayerSetUp( + const vector*>& bottom, const vector*>& top) { + LossLayer::LayerSetUp(bottom, top); + if (bottom.size() < 3) { + CHECK(this->layer_param_.infogain_loss_param().has_source()) + << "Infogain matrix source must be specified."; + BlobProto blob_proto; + ReadProtoFromBinaryFile( + this->layer_param_.infogain_loss_param().source(), &blob_proto); + infogain_.FromProto(blob_proto); + } +} + +template +void InfogainLossLayer::Reshape( + const vector*>& bottom, const vector*>& top) { + LossLayer::Reshape(bottom, top); + Blob* infogain = NULL; + if (bottom.size() < 3) { + infogain = &infogain_; + } else { + infogain = bottom[2]; + } + CHECK_EQ(bottom[1]->channels(), 1); + CHECK_EQ(bottom[1]->height(), 1); + CHECK_EQ(bottom[1]->width(), 1); + const int num = bottom[0]->num(); + const int dim = bottom[0]->count() / num; + CHECK_EQ(infogain->num(), 1); + CHECK_EQ(infogain->channels(), 1); + CHECK_EQ(infogain->height(), dim); + CHECK_EQ(infogain->width(), dim); +} + + +template 
+void InfogainLossLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + const Dtype* bottom_label = bottom[1]->cpu_data(); + const Dtype* infogain_mat = NULL; + if (bottom.size() < 3) { + infogain_mat = infogain_.cpu_data(); + } else { + infogain_mat = bottom[2]->cpu_data(); + } + int num = bottom[0]->num(); + int dim = bottom[0]->count() / bottom[0]->num(); + Dtype loss = 0; + for (int i = 0; i < num; ++i) { + int label = static_cast(bottom_label[i]); + for (int j = 0; j < dim; ++j) { + Dtype prob = std::max(bottom_data[i * dim + j], Dtype(kLOG_THRESHOLD)); + loss -= infogain_mat[label * dim + j] * log(prob); + } + } + top[0]->mutable_cpu_data()[0] = loss / num; +} + +template +void InfogainLossLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + if (propagate_down[1]) { + LOG(FATAL) << this->type_name() + << " Layer cannot backpropagate to label inputs."; + } + if (propagate_down.size() > 2 && propagate_down[2]) { + LOG(FATAL) << this->type_name() + << " Layer cannot backpropagate to infogain inputs."; + } + if (propagate_down[0]) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + const Dtype* bottom_label = bottom[1]->cpu_data(); + const Dtype* infogain_mat = NULL; + if (bottom.size() < 3) { + infogain_mat = infogain_.cpu_data(); + } else { + infogain_mat = bottom[2]->cpu_data(); + } + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + int num = bottom[0]->num(); + int dim = bottom[0]->count() / bottom[0]->num(); + const Dtype scale = - top[0]->cpu_diff()[0] / num; + for (int i = 0; i < num; ++i) { + const int label = static_cast(bottom_label[i]); + for (int j = 0; j < dim; ++j) { + Dtype prob = std::max(bottom_data[i * dim + j], Dtype(kLOG_THRESHOLD)); + bottom_diff[i * dim + j] = scale * infogain_mat[label * dim + j] / prob; + } + } + } +} + +INSTANTIATE_CLASS(InfogainLossLayer); +REGISTER_LAYER_CLASS(INFOGAIN_LOSS, InfogainLossLayer); 
+} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/inner_product_layer.cpp b/caffe-crfrnn/src/caffe/layers/inner_product_layer.cpp new file mode 100644 index 00000000..2ba4d662 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/inner_product_layer.cpp @@ -0,0 +1,108 @@ +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void InnerProductLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + const int num_output = this->layer_param_.inner_product_param().num_output(); + bias_term_ = this->layer_param_.inner_product_param().bias_term(); + N_ = num_output; + K_ = bottom[0]->count() / bottom[0]->num(); + // Check if we need to set up the weights + if (this->blobs_.size() > 0) { + LOG(INFO) << "Skipping parameter initialization"; + } else { + if (bias_term_) { + this->blobs_.resize(2); + } else { + this->blobs_.resize(1); + } + // Initialize the weight + this->blobs_[0].reset(new Blob(1, 1, N_, K_)); + // fill the weights + shared_ptr > weight_filler(GetFiller( + this->layer_param_.inner_product_param().weight_filler())); + weight_filler->Fill(this->blobs_[0].get()); + // If necessary, initialize and fill the bias term + if (bias_term_) { + this->blobs_[1].reset(new Blob(1, 1, 1, N_)); + shared_ptr > bias_filler(GetFiller( + this->layer_param_.inner_product_param().bias_filler())); + bias_filler->Fill(this->blobs_[1].get()); + } + } // parameter initialization + this->param_propagate_down_.resize(this->blobs_.size(), true); +} + +template +void InnerProductLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + // Figure out the dimensions + M_ = bottom[0]->num(); + CHECK_EQ(bottom[0]->count() / bottom[0]->num(), K_) << "Input size " + "incompatible with inner product parameters."; + top[0]->Reshape(bottom[0]->num(), N_, 1, 1); + // Set up the bias 
multiplier + if (bias_term_) { + bias_multiplier_.Reshape(1, 1, 1, M_); + caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data()); + } +} + +template +void InnerProductLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + const Dtype* weight = this->blobs_[0]->cpu_data(); + caffe_cpu_gemm(CblasNoTrans, CblasTrans, M_, N_, K_, (Dtype)1., + bottom_data, weight, (Dtype)0., top_data); + if (bias_term_) { + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1., + bias_multiplier_.cpu_data(), + this->blobs_[1]->cpu_data(), (Dtype)1., top_data); + } +} + +template +void InnerProductLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + if (this->param_propagate_down_[0]) { + const Dtype* top_diff = top[0]->cpu_diff(); + const Dtype* bottom_data = bottom[0]->cpu_data(); + // Gradient with respect to weight + caffe_cpu_gemm(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1., + top_diff, bottom_data, (Dtype)1., this->blobs_[0]->mutable_cpu_diff()); + } + if (bias_term_ && this->param_propagate_down_[1]) { + const Dtype* top_diff = top[0]->cpu_diff(); + // Gradient with respect to bias + caffe_cpu_gemv(CblasTrans, M_, N_, (Dtype)1., top_diff, + bias_multiplier_.cpu_data(), (Dtype)1., + this->blobs_[1]->mutable_cpu_diff()); + } + if (propagate_down[0]) { + const Dtype* top_diff = top[0]->cpu_diff(); + // Gradient with respect to bottom data + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1., + top_diff, this->blobs_[0]->cpu_data(), (Dtype)0., + bottom[0]->mutable_cpu_diff()); + } +} + +#ifdef CPU_ONLY +STUB_GPU(InnerProductLayer); +#endif + +INSTANTIATE_CLASS(InnerProductLayer); +REGISTER_LAYER_CLASS(INNER_PRODUCT, InnerProductLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/inner_product_layer.cu b/caffe-crfrnn/src/caffe/layers/inner_product_layer.cu new file mode 
100644 index 00000000..dd90cac1 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/inner_product_layer.cu @@ -0,0 +1,56 @@ +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void InnerProductLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + const Dtype* weight = this->blobs_[0]->gpu_data(); + caffe_gpu_gemm(CblasNoTrans, CblasTrans, M_, N_, K_, (Dtype)1., + bottom_data, weight, (Dtype)0., top_data); + if (bias_term_) { + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1., + bias_multiplier_.gpu_data(), + this->blobs_[1]->gpu_data(), (Dtype)1., top_data); + } +} + +template +void InnerProductLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + if (this->param_propagate_down_[0]) { + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* bottom_data = bottom[0]->gpu_data(); + // Gradient with respect to weight + caffe_gpu_gemm(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1., + top_diff, bottom_data, (Dtype)1., this->blobs_[0]->mutable_gpu_diff()); + } + if (bias_term_ && this->param_propagate_down_[1]) { + const Dtype* top_diff = top[0]->gpu_diff(); + // Gradient with respect to bias + caffe_gpu_gemv(CblasTrans, M_, N_, (Dtype)1., top_diff, + bias_multiplier_.gpu_data(), (Dtype)1., + this->blobs_[1]->mutable_gpu_diff()); + } + if (propagate_down[0]) { + const Dtype* top_diff = top[0]->gpu_diff(); + // Gradient with respect to bottom data + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1., + top_diff, this->blobs_[0]->gpu_data(), (Dtype)0., + bottom[0]->mutable_gpu_diff()); + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(InnerProductLayer); + +} // namespace caffe diff --git 
a/caffe-crfrnn/src/caffe/layers/loss_layer.cpp b/caffe-crfrnn/src/caffe/layers/loss_layer.cpp new file mode 100644 index 00000000..a5b6d11b --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/loss_layer.cpp @@ -0,0 +1,32 @@ +#include +#include +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void LossLayer::LayerSetUp( + const vector*>& bottom, const vector*>& top) { + // LossLayers have a non-zero (1) loss by default. + if (this->layer_param_.loss_weight_size() == 0) { + this->layer_param_.add_loss_weight(Dtype(1)); + } +} + +template +void LossLayer::Reshape( + const vector*>& bottom, const vector*>& top) { + CHECK_EQ(bottom[0]->num(), bottom[1]->num()) + << "The data and label should have the same number."; + top[0]->Reshape(1, 1, 1, 1); +} + +INSTANTIATE_CLASS(LossLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/lrn_layer.cpp b/caffe-crfrnn/src/caffe/layers/lrn_layer.cpp new file mode 100644 index 00000000..a878cf84 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/lrn_layer.cpp @@ -0,0 +1,256 @@ +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void LRNLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + size_ = this->layer_param_.lrn_param().local_size(); + CHECK_EQ(size_ % 2, 1) << "LRN only supports odd values for local_size"; + pre_pad_ = (size_ - 1) / 2; + alpha_ = this->layer_param_.lrn_param().alpha(); + beta_ = this->layer_param_.lrn_param().beta(); + k_ = this->layer_param_.lrn_param().k(); + if (this->layer_param_.lrn_param().norm_region() == + LRNParameter_NormRegion_WITHIN_CHANNEL) { + // Set up split_layer_ to use inputs in the numerator and denominator. 
+ split_top_vec_.clear(); + split_top_vec_.push_back(&product_input_); + split_top_vec_.push_back(&square_input_); + LayerParameter split_param; + split_layer_.reset(new SplitLayer(split_param)); + split_layer_->SetUp(bottom, split_top_vec_); + // Set up square_layer_ to square the inputs. + square_bottom_vec_.clear(); + square_top_vec_.clear(); + square_bottom_vec_.push_back(&square_input_); + square_top_vec_.push_back(&square_output_); + LayerParameter square_param; + square_param.mutable_power_param()->set_power(Dtype(2)); + square_layer_.reset(new PowerLayer(square_param)); + square_layer_->SetUp(square_bottom_vec_, square_top_vec_); + // Set up pool_layer_ to sum over square neighborhoods of the input. + pool_top_vec_.clear(); + pool_top_vec_.push_back(&pool_output_); + LayerParameter pool_param; + pool_param.mutable_pooling_param()->set_pool( + PoolingParameter_PoolMethod_AVE); + pool_param.mutable_pooling_param()->set_pad(pre_pad_); + pool_param.mutable_pooling_param()->set_kernel_size(size_); + pool_layer_.reset(new PoolingLayer(pool_param)); + pool_layer_->SetUp(square_top_vec_, pool_top_vec_); + // Set up power_layer_ to compute (1 + alpha_/N^2 s)^-beta_, where s is + // the sum of a squared neighborhood (the output of pool_layer_). + power_top_vec_.clear(); + power_top_vec_.push_back(&power_output_); + LayerParameter power_param; + power_param.mutable_power_param()->set_power(-beta_); + power_param.mutable_power_param()->set_scale(alpha_); + power_param.mutable_power_param()->set_shift(Dtype(1)); + power_layer_.reset(new PowerLayer(power_param)); + power_layer_->SetUp(pool_top_vec_, power_top_vec_); + // Set up a product_layer_ to compute outputs by multiplying inputs by the + // inverse denominator computed by the power layer. 
+ product_bottom_vec_.clear(); + product_bottom_vec_.push_back(&product_input_); + product_bottom_vec_.push_back(&power_output_); + LayerParameter product_param; + EltwiseParameter* eltwise_param = product_param.mutable_eltwise_param(); + eltwise_param->set_operation(EltwiseParameter_EltwiseOp_PROD); + product_layer_.reset(new EltwiseLayer(product_param)); + product_layer_->SetUp(product_bottom_vec_, top); + } +} + +template +void LRNLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + num_ = bottom[0]->num(); + channels_ = bottom[0]->channels(); + height_ = bottom[0]->height(); + width_ = bottom[0]->width(); + switch (this->layer_param_.lrn_param().norm_region()) { + case LRNParameter_NormRegion_ACROSS_CHANNELS: + top[0]->Reshape(num_, channels_, height_, width_); + scale_.Reshape(num_, channels_, height_, width_); + break; + case LRNParameter_NormRegion_WITHIN_CHANNEL: + split_layer_->Reshape(bottom, split_top_vec_); + square_layer_->Reshape(square_bottom_vec_, square_top_vec_); + pool_layer_->Reshape(square_top_vec_, pool_top_vec_); + power_layer_->Reshape(pool_top_vec_, power_top_vec_); + product_layer_->Reshape(product_bottom_vec_, top); + break; + } +} + +template +void LRNLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + switch (this->layer_param_.lrn_param().norm_region()) { + case LRNParameter_NormRegion_ACROSS_CHANNELS: + CrossChannelForward_cpu(bottom, top); + break; + case LRNParameter_NormRegion_WITHIN_CHANNEL: + WithinChannelForward(bottom, top); + break; + default: + LOG(FATAL) << "Unknown normalization region."; + } +} + +template +void LRNLayer::CrossChannelForward_cpu( + const vector*>& bottom, const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + Dtype* scale_data = scale_.mutable_cpu_data(); + // start with the constant value + for (int i = 0; i < scale_.count(); ++i) { + scale_data[i] = k_; + } + Blob padded_square(1, channels_ + size_ - 1, 
height_, width_); + Dtype* padded_square_data = padded_square.mutable_cpu_data(); + caffe_set(padded_square.count(), Dtype(0), padded_square_data); + Dtype alpha_over_size = alpha_ / size_; + // go through the images + for (int n = 0; n < num_; ++n) { + // compute the padded square + caffe_sqr(channels_ * height_ * width_, + bottom_data + bottom[0]->offset(n), + padded_square_data + padded_square.offset(0, pre_pad_)); + // Create the first channel scale + for (int c = 0; c < size_; ++c) { + caffe_axpy(height_ * width_, alpha_over_size, + padded_square_data + padded_square.offset(0, c), + scale_data + scale_.offset(n, 0)); + } + for (int c = 1; c < channels_; ++c) { + // copy previous scale + caffe_copy(height_ * width_, + scale_data + scale_.offset(n, c - 1), + scale_data + scale_.offset(n, c)); + // add head + caffe_axpy(height_ * width_, alpha_over_size, + padded_square_data + padded_square.offset(0, c + size_ - 1), + scale_data + scale_.offset(n, c)); + // subtract tail + caffe_axpy(height_ * width_, -alpha_over_size, + padded_square_data + padded_square.offset(0, c - 1), + scale_data + scale_.offset(n, c)); + } + } + + // In the end, compute output + caffe_powx(scale_.count(), scale_data, -beta_, top_data); + caffe_mul(scale_.count(), top_data, bottom_data, top_data); +} + +template +void LRNLayer::WithinChannelForward( + const vector*>& bottom, const vector*>& top) { + split_layer_->Forward(bottom, split_top_vec_); + square_layer_->Forward(square_bottom_vec_, square_top_vec_); + pool_layer_->Forward(square_top_vec_, pool_top_vec_); + power_layer_->Forward(pool_top_vec_, power_top_vec_); + product_layer_->Forward(product_bottom_vec_, top); +} + +template +void LRNLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + switch (this->layer_param_.lrn_param().norm_region()) { + case LRNParameter_NormRegion_ACROSS_CHANNELS: + CrossChannelBackward_cpu(top, propagate_down, bottom); + break; + case 
LRNParameter_NormRegion_WITHIN_CHANNEL: + WithinChannelBackward(top, propagate_down, bottom); + break; + default: + LOG(FATAL) << "Unknown normalization region."; + } +} + +template +void LRNLayer::CrossChannelBackward_cpu( + const vector*>& top, const vector& propagate_down, + const vector*>& bottom) { + const Dtype* top_diff = top[0]->cpu_diff(); + const Dtype* top_data = top[0]->cpu_data(); + const Dtype* bottom_data = bottom[0]->cpu_data(); + const Dtype* scale_data = scale_.cpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + Blob padded_ratio(1, channels_ + size_ - 1, height_, width_); + Blob accum_ratio(1, 1, height_, width_); + Dtype* padded_ratio_data = padded_ratio.mutable_cpu_data(); + Dtype* accum_ratio_data = accum_ratio.mutable_cpu_data(); + // We hack a little bit by using the diff() to store an additional result + Dtype* accum_ratio_times_bottom = accum_ratio.mutable_cpu_diff(); + caffe_set(padded_ratio.count(), Dtype(0), padded_ratio_data); + Dtype cache_ratio_value = 2. 
* alpha_ * beta_ / size_; + + caffe_powx(scale_.count(), scale_data, -beta_, bottom_diff); + caffe_mul(scale_.count(), top_diff, bottom_diff, bottom_diff); + + // go through individual data + int inverse_pre_pad = size_ - (size_ + 1) / 2; + for (int n = 0; n < num_; ++n) { + int block_offset = scale_.offset(n); + // first, compute diff_i * y_i / s_i + caffe_mul(channels_ * height_ * width_, + top_diff + block_offset, top_data + block_offset, + padded_ratio_data + padded_ratio.offset(0, inverse_pre_pad)); + caffe_div(channels_ * height_ * width_, + padded_ratio_data + padded_ratio.offset(0, inverse_pre_pad), + scale_data + block_offset, + padded_ratio_data + padded_ratio.offset(0, inverse_pre_pad)); + // Now, compute the accumulated ratios and the bottom diff + caffe_set(accum_ratio.count(), Dtype(0), accum_ratio_data); + for (int c = 0; c < size_ - 1; ++c) { + caffe_axpy(height_ * width_, 1., + padded_ratio_data + padded_ratio.offset(0, c), accum_ratio_data); + } + for (int c = 0; c < channels_; ++c) { + caffe_axpy(height_ * width_, 1., + padded_ratio_data + padded_ratio.offset(0, c + size_ - 1), + accum_ratio_data); + // compute bottom diff + caffe_mul(height_ * width_, + bottom_data + top[0]->offset(n, c), + accum_ratio_data, accum_ratio_times_bottom); + caffe_axpy(height_ * width_, -cache_ratio_value, + accum_ratio_times_bottom, bottom_diff + top[0]->offset(n, c)); + caffe_axpy(height_ * width_, -1., + padded_ratio_data + padded_ratio.offset(0, c), accum_ratio_data); + } + } +} + +template +void LRNLayer::WithinChannelBackward( + const vector*>& top, const vector& propagate_down, + const vector*>& bottom) { + if (propagate_down[0]) { + vector product_propagate_down(2, true); + product_layer_->Backward(top, product_propagate_down, product_bottom_vec_); + power_layer_->Backward(power_top_vec_, propagate_down, pool_top_vec_); + pool_layer_->Backward(pool_top_vec_, propagate_down, square_top_vec_); + square_layer_->Backward(square_top_vec_, propagate_down, + 
square_bottom_vec_); + split_layer_->Backward(split_top_vec_, propagate_down, bottom); + } +} + +#ifdef CPU_ONLY +STUB_GPU(LRNLayer); +STUB_GPU_FORWARD(LRNLayer, CrossChannelForward); +STUB_GPU_BACKWARD(LRNLayer, CrossChannelBackward); +#endif + +INSTANTIATE_CLASS(LRNLayer); +REGISTER_LAYER_CLASS(LRN, LRNLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/lrn_layer.cu b/caffe-crfrnn/src/caffe/layers/lrn_layer.cu new file mode 100644 index 00000000..58c39926 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/lrn_layer.cu @@ -0,0 +1,206 @@ +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +__global__ void LRNFillScale(const int nthreads, const Dtype* in, + const int num, const int channels, const int height, + const int width, const int size, const Dtype alpha_over_size, + const Dtype k, Dtype* scale) { + CUDA_KERNEL_LOOP(index, nthreads) { + // find out the local offset + int w = index % width; + int h = (index / width) % height; + int n = index / width / height; + int offset = (n * channels * height + h) * width + w; + int step = height * width; + in += offset; + scale += offset; + int head = 0; + int pre_pad = (size - 1) / 2; + int post_pad = size - pre_pad - 1; + Dtype accum_scale = 0; + // fill the scale at [n, :, h, w] + // accumulate values + while (head < post_pad) { + accum_scale += in[head * step] * in[head * step]; + ++head; + } + // until we reach size, nothing needs to be subtracted + while (head < size) { + accum_scale += in[head * step] * in[head * step]; + scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size; + ++head; + } + // both add and subtract + while (head < channels) { + accum_scale += in[head * step] * in[head * step]; + accum_scale -= in[(head - size) * step] * in[(head - size) * step]; + scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size; + ++head; + } + // subtract only + while (head < 
channels + post_pad) { + accum_scale -= in[(head - size) * step] * in[(head - size) * step]; + scale[(head - post_pad) * step] = k + accum_scale * alpha_over_size; + ++head; + } + } +} + + +template +void LRNLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + switch (this->layer_param_.lrn_param().norm_region()) { + case LRNParameter_NormRegion_ACROSS_CHANNELS: + CrossChannelForward_gpu(bottom, top); + break; + case LRNParameter_NormRegion_WITHIN_CHANNEL: + WithinChannelForward(bottom, top); + break; + default: + LOG(FATAL) << "Unknown normalization region."; + } +} + +// TODO: check if it would be faster to just put it into the previous kernel. +template +__global__ void LRNComputeOutput(const int nthreads, const Dtype* in, + const Dtype* scale, const Dtype negative_beta, Dtype* out) { + CUDA_KERNEL_LOOP(index, nthreads) { + out[index] = in[index] * pow(scale[index], negative_beta); + } +} + +template +void LRNLayer::CrossChannelForward_gpu( + const vector*>& bottom, const vector*>& top) { + // First, compute scale + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + Dtype* scale_data = scale_.mutable_gpu_data(); + // We will launch one kernel for each pixel location, and have the kernel + // go through all the channels. 
+ int n_threads = num_ * height_ * width_; + // NOLINT_NEXT_LINE(whitespace/operators) + LRNFillScale<<>>( + n_threads, bottom_data, num_, channels_, height_, width_, size_, + alpha_ / size_, k_, scale_data); + CUDA_POST_KERNEL_CHECK; + n_threads = bottom[0]->count(); + // NOLINT_NEXT_LINE(whitespace/operators) + LRNComputeOutput<<>>( + n_threads, bottom_data, scale_data, -beta_, top_data); + CUDA_POST_KERNEL_CHECK; +} +template void LRNLayer::CrossChannelForward_gpu( + const vector*>& bottom, const vector*>& top); +template void LRNLayer::CrossChannelForward_gpu( + const vector*>& bottom, const vector*>& top); + + +template +void LRNLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + switch (this->layer_param_.lrn_param().norm_region()) { + case LRNParameter_NormRegion_ACROSS_CHANNELS: + CrossChannelBackward_gpu(top, propagate_down, bottom); + break; + case LRNParameter_NormRegion_WITHIN_CHANNEL: + WithinChannelBackward(top, propagate_down, bottom); + break; + default: + LOG(FATAL) << "Unknown normalization region."; + } +} + +template +__global__ void LRNComputeDiff(const int nthreads, const Dtype* bottom_data, + const Dtype* top_data, const Dtype* scale, const Dtype* top_diff, + const int num, const int channels, const int height, + const int width, const int size, const Dtype negative_beta, + const Dtype cache_ratio, + Dtype* bottom_diff) { + CUDA_KERNEL_LOOP(index, nthreads) { + // find out the local offset + int w = index % width; + int h = (index / width) % height; + int n = index / width / height; + int offset = (n * channels * height + h) * width + w; + int step = height * width; + bottom_data += offset; + top_data += offset; + scale += offset; + top_diff += offset; + bottom_diff += offset; + int head = 0; + int pre_pad = size - (size + 1) / 2; + int post_pad = size - pre_pad - 1; + Dtype accum_ratio = 0; + // accumulate values + while (head < post_pad) { + accum_ratio += top_diff[head * step] * 
top_data[head * step] / + scale[head * step]; + ++head; + } + // until we reach size, nothing needs to be subtracted + while (head < size) { + accum_ratio += top_diff[head * step] * top_data[head * step] / + scale[head * step]; + bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step] + * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * + bottom_data[(head - post_pad) * step] * accum_ratio; + ++head; + } + // both add and subtract + while (head < channels) { + accum_ratio += top_diff[head * step] * top_data[head * step] / + scale[head * step]; + accum_ratio -= top_diff[(head - size) * step] * + top_data[(head - size) * step] / scale[(head - size) * step]; + bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step] + * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * + bottom_data[(head - post_pad) * step] * accum_ratio; + ++head; + } + // subtract only + while (head < channels + post_pad) { + accum_ratio -= top_diff[(head - size) * step] * + top_data[(head - size) * step] / scale[(head - size) * step]; + bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step] + * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * + bottom_data[(head - post_pad) * step] * accum_ratio; + ++head; + } + } +} + +template +void LRNLayer::CrossChannelBackward_gpu( + const vector*>& top, const vector& propagate_down, + const vector*>& bottom) { + int n_threads = num_ * height_ * width_; + // NOLINT_NEXT_LINE(whitespace/operators) + LRNComputeDiff<<>>( + n_threads, bottom[0]->gpu_data(), top[0]->gpu_data(), + scale_.gpu_data(), top[0]->gpu_diff(), num_, channels_, height_, width_, + size_, -beta_, Dtype(2. 
* alpha_ * beta_ / size_), + bottom[0]->mutable_gpu_diff()); +} +template void LRNLayer::CrossChannelBackward_gpu( + const vector*>& top, const vector& propagate_down, + const vector*>& bottom); +template void LRNLayer::CrossChannelBackward_gpu( + const vector*>& top, const vector& propagate_down, + const vector*>& bottom); + + + +INSTANTIATE_LAYER_GPU_FUNCS(LRNLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/meanfield_iteration.cpp b/caffe-crfrnn/src/caffe/layers/meanfield_iteration.cpp new file mode 100755 index 00000000..d6efc3ac --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/meanfield_iteration.cpp @@ -0,0 +1,284 @@ +/*! + * \brief A helper class for {@link MultiStageMeanfieldLayer} class, which is the Caffe layer that implements the + * CRF-RNN described in the paper: Conditional Random Fields as Recurrent Neural Networks. IEEE ICCV 2015. + * + * This class itself is not a proper Caffe layer although it behaves like one to some degree. + * + * \authors Sadeep Jayasumana, Bernardino Romera-Paredes, Shuai Zheng, Zhizhong Su. + * \version 1.0 + * \date 2015 + * \copyright Torr Vision Group, University of Oxford. + * \details If you use this code, please consider citing the paper: + * Shuai Zheng, Sadeep Jayasumana, Bernardino Romera-Paredes, Vibhav Vineet, Zhizhong Su, Dalong Du, + * Chang Huang, Philip H. S. Torr. Conditional Random Fields as Recurrent Neural Networks. IEEE ICCV 2015. + * + * For more information about CRF-RNN, please visit the project website http://crfasrnn.torr.vision. + */ +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +/** + * To be invoked once only immediately after construction. 
+ */ +template +void MeanfieldIteration::OneTimeSetUp( + Blob* const unary_terms, + Blob* const softmax_input, + Blob* const output_blob, + const shared_ptr spatial_lattice, + const Blob* const spatial_norm) { + + spatial_lattice_ = spatial_lattice; + spatial_norm_ = spatial_norm; + + count_ = unary_terms->count(); + num_ = unary_terms->num(); + channels_ = unary_terms->channels(); + height_ = unary_terms->height(); + width_ = unary_terms->width(); + num_pixels_ = height_ * width_; + + if (this->blobs_.size() > 0) { + LOG(INFO) << "Meanfield iteration skipping parameter initialization."; + } else { + blobs_.resize(3); + blobs_[0].reset(new Blob(1, 1, channels_, channels_)); // spatial kernel weight + blobs_[1].reset(new Blob(1, 1, channels_, channels_)); // bilateral kernel weight + blobs_[2].reset(new Blob(1, 1, channels_, channels_)); // compatibility transform matrix + } + + pairwise_.Reshape(num_, channels_, height_, width_); + spatial_out_blob_.Reshape(num_, channels_, height_, width_); + bilateral_out_blob_.Reshape(num_, channels_, height_, width_); + message_passing_.Reshape(num_, channels_, height_, width_); + + // Softmax layer configuration + softmax_bottom_vec_.clear(); + softmax_bottom_vec_.push_back(softmax_input); + + softmax_top_vec_.clear(); + softmax_top_vec_.push_back(&prob_); + + LayerParameter softmax_param; + softmax_layer_.reset(new SoftmaxLayer(softmax_param)); + softmax_layer_->SetUp(softmax_bottom_vec_, softmax_top_vec_); + + // Sum layer configuration + sum_bottom_vec_.clear(); + sum_bottom_vec_.push_back(unary_terms); + sum_bottom_vec_.push_back(&pairwise_); + + sum_top_vec_.clear(); + sum_top_vec_.push_back(output_blob); + + LayerParameter sum_param; + sum_param.mutable_eltwise_param()->add_coeff(Dtype(1.)); + sum_param.mutable_eltwise_param()->add_coeff(Dtype(-1.)); + sum_param.mutable_eltwise_param()->set_operation(EltwiseParameter_EltwiseOp_SUM); + sum_layer_.reset(new EltwiseLayer(sum_param)); + sum_layer_->SetUp(sum_bottom_vec_, 
sum_top_vec_); +} + +/** + * To be invoked before every call to the Forward_cpu() method. + */ +template +void MeanfieldIteration::PrePass( + const vector > >& parameters_to_copy_from, + const vector >* const bilateral_lattices, + const Blob* const bilateral_norms) { + + bilateral_lattices_ = bilateral_lattices; + bilateral_norms_ = bilateral_norms; + + // Get copies of the up-to-date parameters. + for (int i = 0; i < parameters_to_copy_from.size(); ++i) { + blobs_[i]->CopyFrom(*(parameters_to_copy_from[i].get())); + } +} + +/** + * Forward pass during the inference. + */ +template +void MeanfieldIteration::Forward_cpu() { + + + //------------------------------- Softmax normalization-------------------- + softmax_layer_->Forward(softmax_bottom_vec_, softmax_top_vec_); + + //-----------------------------------Message passing----------------------- + for (int n = 0; n < num_; ++n) { + + Dtype* spatial_out_data = spatial_out_blob_.mutable_cpu_data() + spatial_out_blob_.offset(n); + const Dtype* prob_input_data = prob_.cpu_data() + prob_.offset(n); + + spatial_lattice_->compute(spatial_out_data, prob_input_data, channels_, false); + + // Pixel-wise normalization. + for (int channel_id = 0; channel_id < channels_; ++channel_id) { + caffe_mul(num_pixels_, spatial_norm_->cpu_data(), + spatial_out_data + channel_id * num_pixels_, + spatial_out_data + channel_id * num_pixels_); + } + + Dtype* bilateral_out_data = bilateral_out_blob_.mutable_cpu_data() + bilateral_out_blob_.offset(n); + + (*bilateral_lattices_)[n]->compute(bilateral_out_data, prob_input_data, channels_, false); + // Pixel-wise normalization. 
+ for (int channel_id = 0; channel_id < channels_; ++channel_id) { + caffe_mul(num_pixels_, bilateral_norms_->cpu_data() + bilateral_norms_->offset(n), + bilateral_out_data + channel_id * num_pixels_, + bilateral_out_data + channel_id * num_pixels_); + } + } + + caffe_set(count_, Dtype(0.), message_passing_.mutable_cpu_data()); + + for (int n = 0; n < num_; ++n) { + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, channels_, num_pixels_, channels_, (Dtype) 1., + this->blobs_[0]->cpu_data(), spatial_out_blob_.cpu_data() + spatial_out_blob_.offset(n), (Dtype) 0., + message_passing_.mutable_cpu_data() + message_passing_.offset(n)); + } + + for (int n = 0; n < num_; ++n) { + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, channels_, num_pixels_, channels_, (Dtype) 1., + this->blobs_[1]->cpu_data(), bilateral_out_blob_.cpu_data() + bilateral_out_blob_.offset(n), (Dtype) 1., + message_passing_.mutable_cpu_data() + message_passing_.offset(n)); + } + + //--------------------------- Compatibility multiplication ---------------- + //Result from message passing needs to be multiplied with compatibility values. 
+ for (int n = 0; n < num_; ++n) { + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, channels_, num_pixels_, + channels_, (Dtype) 1., this->blobs_[2]->cpu_data(), + message_passing_.cpu_data() + message_passing_.offset(n), (Dtype) 0., + pairwise_.mutable_cpu_data() + pairwise_.offset(n)); + } + + //------------------------- Adding unaries, normalization is left to the next iteration -------------- + // Add unary + sum_layer_->Forward(sum_bottom_vec_, sum_top_vec_); +} + + +template +void MeanfieldIteration::Backward_cpu() { + + + //---------------------------- Add unary gradient -------------------------- + vector eltwise_propagate_down(2, true); + sum_layer_->Backward(sum_top_vec_, eltwise_propagate_down, sum_bottom_vec_); + + //---------------------------- Update compatibility diffs ------------------ + caffe_set(this->blobs_[2]->count(), Dtype(0.), this->blobs_[2]->mutable_cpu_diff()); + + for (int n = 0; n < num_; ++n) { + caffe_cpu_gemm(CblasNoTrans, CblasTrans, channels_, channels_, + num_pixels_, (Dtype) 1., pairwise_.cpu_diff() + pairwise_.offset(n), + message_passing_.cpu_data() + message_passing_.offset(n), (Dtype) 1., + this->blobs_[2]->mutable_cpu_diff()); + } + + //-------------------------- Gradient after compatibility transform--- ----- + for (int n = 0; n < num_; ++n) { + caffe_cpu_gemm(CblasTrans, CblasNoTrans, channels_, num_pixels_, + channels_, (Dtype) 1., this->blobs_[2]->cpu_data(), + pairwise_.cpu_diff() + pairwise_.offset(n), (Dtype) 0., + message_passing_.mutable_cpu_diff() + message_passing_.offset(n)); + } + + // ------------------------- Gradient w.r.t. 
kernel weights ------------ + caffe_set(this->blobs_[0]->count(), Dtype(0.), this->blobs_[0]->mutable_cpu_diff()); + caffe_set(this->blobs_[1]->count(), Dtype(0.), this->blobs_[1]->mutable_cpu_diff()); + + for (int n = 0; n < num_; ++n) { + caffe_cpu_gemm(CblasNoTrans, CblasTrans, channels_, channels_, + num_pixels_, (Dtype) 1., message_passing_.cpu_diff() + message_passing_.offset(n), + spatial_out_blob_.cpu_data() + spatial_out_blob_.offset(n), (Dtype) 1., + this->blobs_[0]->mutable_cpu_diff()); + } + + for (int n = 0; n < num_; ++n) { + caffe_cpu_gemm(CblasNoTrans, CblasTrans, channels_, channels_, + num_pixels_, (Dtype) 1., message_passing_.cpu_diff() + message_passing_.offset(n), + bilateral_out_blob_.cpu_data() + bilateral_out_blob_.offset(n), (Dtype) 1., + this->blobs_[1]->mutable_cpu_diff()); + } + + /*Dtype* tmp = new Dtype[count_]; + caffe_mul(count_, message_passing_.cpu_diff(), spatial_out_blob_.cpu_data(), tmp); + + for (int c = 0; c < count_; ++c) { + (this->blobs_[0]->mutable_cpu_diff())[0] += tmp[c]; + } + + caffe_mul(count_, message_passing_.cpu_diff(), bilateral_out_blob_.cpu_data(), tmp); + for (int c = 0; c < count_; ++c) { + (this->blobs_[1]->mutable_cpu_diff())[0] += tmp[c]; + } + + delete[] tmp;*/ + + // TODO: Check whether there's a way to improve the accuracy of this calculation. 
+ for (int n = 0; n < num_; ++n) { + caffe_cpu_gemm(CblasTrans, CblasNoTrans, channels_, num_pixels_, channels_, (Dtype) 1., + this->blobs_[0]->cpu_data(), message_passing_.cpu_diff() + message_passing_.offset(n), + (Dtype) 0., + spatial_out_blob_.mutable_cpu_diff() + spatial_out_blob_.offset(n)); + } + //caffe_cpu_scale(count_, (this->blobs_[0]->cpu_data())[0], + // message_passing_.cpu_diff(), spatial_out_blob_.mutable_cpu_diff()); + + for (int n = 0; n < num_; ++n) { + caffe_cpu_gemm(CblasTrans, CblasNoTrans, channels_, num_pixels_, channels_, (Dtype) 1., + this->blobs_[1]->cpu_data(), message_passing_.cpu_diff() + message_passing_.offset(n), + (Dtype) 0., + bilateral_out_blob_.mutable_cpu_diff() + bilateral_out_blob_.offset(n)); + } + //caffe_cpu_scale(count_, (this->blobs_[1]->cpu_data())[0], + // message_passing_.cpu_diff(), bilateral_out_blob_.mutable_cpu_diff()); + + + //---------------------------- BP thru normalization -------------------------- + for (int n = 0; n < num_; ++n) { + + Dtype *spatial_out_diff = spatial_out_blob_.mutable_cpu_diff() + spatial_out_blob_.offset(n); + for (int channel_id = 0; channel_id < channels_; ++channel_id) { + caffe_mul(num_pixels_, spatial_norm_->cpu_data(), + spatial_out_diff + channel_id * num_pixels_, + spatial_out_diff + channel_id * num_pixels_); + } + + Dtype *bilateral_out_diff = bilateral_out_blob_.mutable_cpu_diff() + bilateral_out_blob_.offset(n); + for (int channel_id = 0; channel_id < channels_; ++channel_id) { + caffe_mul(num_pixels_, bilateral_norms_->cpu_data() + bilateral_norms_->offset(n), + bilateral_out_diff + channel_id * num_pixels_, + bilateral_out_diff + channel_id * num_pixels_); + } + } + + //--------------------------- Gradient for message passing --------------- + for (int n = 0; n < num_; ++n) { + + spatial_lattice_->compute(prob_.mutable_cpu_diff() + prob_.offset(n), + spatial_out_blob_.cpu_diff() + spatial_out_blob_.offset(n), channels_, + true, false); + + 
(*bilateral_lattices_)[n]->compute(prob_.mutable_cpu_diff() + prob_.offset(n), + bilateral_out_blob_.cpu_diff() + bilateral_out_blob_.offset(n), + channels_, true, true); + } + + //-------------------------------------------------------------------------------- + vector propagate_down(2, true); + softmax_layer_->Backward(softmax_top_vec_, propagate_down, softmax_bottom_vec_); +} + +INSTANTIATE_CLASS(MeanfieldIteration); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/memory_data_layer.cpp b/caffe-crfrnn/src/caffe/layers/memory_data_layer.cpp new file mode 100644 index 00000000..613ca2d4 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/memory_data_layer.cpp @@ -0,0 +1,76 @@ +#include + +#include "caffe/data_layers.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" + +namespace caffe { + +template +void MemoryDataLayer::DataLayerSetUp(const vector*>& bottom, + const vector*>& top) { + batch_size_ = this->layer_param_.memory_data_param().batch_size(); + channels_ = this->layer_param_.memory_data_param().channels(); + height_ = this->layer_param_.memory_data_param().height(); + width_ = this->layer_param_.memory_data_param().width(); + size_ = channels_ * height_ * width_; + CHECK_GT(batch_size_ * size_, 0) << + "batch_size, channels, height, and width must be specified and" + " positive in memory_data_param"; + top[0]->Reshape(batch_size_, channels_, height_, width_); + top[1]->Reshape(batch_size_, 1, 1, 1); + added_data_.Reshape(batch_size_, channels_, height_, width_); + added_label_.Reshape(batch_size_, 1, 1, 1); + data_ = NULL; + labels_ = NULL; + added_data_.cpu_data(); + added_label_.cpu_data(); +} + +template +void MemoryDataLayer::AddDatumVector(const vector& datum_vector) { + CHECK(!has_new_data_) << + "Can't add Datum when earlier ones haven't been consumed" + << " by the upper layers"; + size_t num = datum_vector.size(); + CHECK_GT(num, 0) << "There is no datum to add"; + CHECK_LE(num, batch_size_) << + "The number of added 
datum must be no greater than the batch size"; + + // Apply data transformations (mirror, scale, crop...) + this->data_transformer_.Transform(datum_vector, &added_data_); + // Copy Labels + Dtype* top_label = added_label_.mutable_cpu_data(); + for (int item_id = 0; item_id < num; ++item_id) { + top_label[item_id] = datum_vector[item_id].label(); + } + // num_images == batch_size_ + Dtype* top_data = added_data_.mutable_cpu_data(); + Reset(top_data, top_label, batch_size_); + has_new_data_ = true; +} + +template +void MemoryDataLayer::Reset(Dtype* data, Dtype* labels, int n) { + CHECK(data); + CHECK(labels); + CHECK_EQ(n % batch_size_, 0) << "n must be a multiple of batch size"; + data_ = data; + labels_ = labels; + n_ = n; + pos_ = 0; +} + +template +void MemoryDataLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + CHECK(data_) << "MemoryDataLayer needs to be initialized by calling Reset"; + top[0]->set_cpu_data(data_ + pos_ * size_); + top[1]->set_cpu_data(labels_ + pos_); + pos_ = (pos_ + batch_size_) % n_; + has_new_data_ = false; +} + +INSTANTIATE_CLASS(MemoryDataLayer); +REGISTER_LAYER_CLASS(MEMORY_DATA, MemoryDataLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/multi_stage_meanfield.cpp b/caffe-crfrnn/src/caffe/layers/multi_stage_meanfield.cpp new file mode 100755 index 00000000..ae59656f --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/multi_stage_meanfield.cpp @@ -0,0 +1,256 @@ +/*! + * \brief The Caffe layer that implements the CRF-RNN described in the paper: + * Conditional Random Fields as Recurrent Neural Networks. IEEE ICCV 2015. + * + * \authors Sadeep Jayasumana, Bernardino Romera-Paredes, Shuai Zheng, Zhizhong Su. + * \version 1.0 + * \date 2015 + * \copyright Torr Vision Group, University of Oxford. + * \details If you use this code, please consider citing the paper: + * Shuai Zheng, Sadeep Jayasumana, Bernardino Romera-Paredes, Vibhav Vineet, Zhizhong Su, Dalong Du, + * Chang Huang, Philip H. S. Torr. 
Conditional Random Fields as Recurrent Neural Networks. IEEE ICCV 2015. + * + * For more information about CRF-RNN, please visit the project website http://crfasrnn.torr.vision. + */ +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void MultiStageMeanfieldLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + + const caffe::MultiStageMeanfieldParameter meanfield_param = this->layer_param_.multi_stage_meanfield_param(); + + num_iterations_ = meanfield_param.num_iterations(); + + CHECK_GT(num_iterations_, 1) << "Number of iterations must be greater than 1."; + + theta_alpha_ = meanfield_param.theta_alpha(); + theta_beta_ = meanfield_param.theta_beta(); + theta_gamma_ = meanfield_param.theta_gamma(); + + count_ = bottom[0]->count(); + num_ = bottom[0]->num(); + channels_ = bottom[0]->channels(); + height_ = bottom[0]->height(); + width_ = bottom[0]->width(); + num_pixels_ = height_ * width_; + + LOG(INFO) << "This implementation has not been tested with batch size > 1."; + + top[0]->Reshape(num_, channels_, height_, width_); + + // Initialize the parameters that will be updated by backpropagation. + if (this->blobs_.size() > 0) { + LOG(INFO) << "Multimeanfield layer skipping parameter initialization."; + } else { + + this->blobs_.resize(3); // blobs_[0] - spatial kernel weights, blobs_[1] - bilateral kernel weights, blobs_[2] - compatibility matrix + + // Allocate space for kernel weights. + this->blobs_[0].reset(new Blob(1, 1, channels_, channels_)); + this->blobs_[1].reset(new Blob(1, 1, channels_, channels_)); + + caffe_set(channels_ * channels_, Dtype(0.), this->blobs_[0]->mutable_cpu_data()); + caffe_set(channels_ * channels_, Dtype(0.), this->blobs_[1]->mutable_cpu_data()); + + // Initialize the kernel weights. The two files spatial.par and bilateral.par should be available. 
+ FILE * pFile; + pFile = fopen("spatial.par", "r"); + CHECK(pFile) << "The file 'spatial.par' is not found. Please create it with initial spatial kernel weights."; + for (int i = 0; i < channels_; i++) { + fscanf(pFile, "%lf", &this->blobs_[0]->mutable_cpu_data()[i * channels_ + i]); + } + fclose(pFile); + + pFile = fopen("bilateral.par", "r"); + CHECK(pFile) << "The file 'bilateral.par' is not found. Please create it with initial bilateral kernel weights."; + for (int i = 0; i < channels_; i++) { + fscanf(pFile, "%lf", &this->blobs_[1]->mutable_cpu_data()[i * channels_ + i]); + } + fclose(pFile); + + // Initialize the compatibility matrix. + this->blobs_[2].reset(new Blob(1, 1, channels_, channels_)); + caffe_set(channels_ * channels_, Dtype(0.), this->blobs_[2]->mutable_cpu_data()); + + // Initialize it to have the Potts model. + for (int c = 0; c < channels_; ++c) { + (this->blobs_[2]->mutable_cpu_data())[c * channels_ + c] = Dtype(-1.); + } + } + + // Initialize the spatial lattice. This does not need to be computed for every image because we use a fixed size. + float spatial_kernel[2 * num_pixels_]; + compute_spatial_kernel(spatial_kernel); + spatial_lattice_.reset(new ModifiedPermutohedral()); + spatial_lattice_->init(spatial_kernel, 2, num_pixels_); + + // Calculate spatial filter normalization factors. + norm_feed_.reset(new Dtype[num_pixels_]); + caffe_set(num_pixels_, Dtype(1.0), norm_feed_.get()); + spatial_norm_.Reshape(1, 1, height_, width_); + Dtype* norm_data = spatial_norm_.mutable_cpu_data(); + spatial_lattice_->compute(norm_data, norm_feed_.get(), 1); + for (int i = 0; i < num_pixels_; ++i) { + norm_data[i] = 1.0f / (norm_data[i] + 1e-20f); + } + + // Allocate space for bilateral kernels. This is a temporary buffer used to compute bilateral lattices later. + // Also allocate space for holding bilateral filter normalization values. 
+ bilateral_kernel_buffer_.reset(new float[5 * num_pixels_]); + bilateral_norms_.Reshape(num_, 1, height_, width_); + + // Configure the split layer that is used to make copies of the unary term. One copy for each iteration. + // It may be possible to optimize this calculation later. + split_layer_bottom_vec_.clear(); + split_layer_bottom_vec_.push_back(bottom[0]); + + split_layer_top_vec_.clear(); + + split_layer_out_blobs_.resize(num_iterations_); + for (int i = 0; i < num_iterations_; i++) { + split_layer_out_blobs_[i].reset(new Blob()); + split_layer_top_vec_.push_back(split_layer_out_blobs_[i].get()); + } + + LayerParameter split_layer_param; + split_layer_.reset(new SplitLayer(split_layer_param)); + split_layer_->SetUp(split_layer_bottom_vec_, split_layer_top_vec_); + + // Make blobs to store outputs of each meanfield iteration. Output of the last iteration is stored in top[0]. + // So we need only (num_iterations_ - 1) blobs. + iteration_output_blobs_.resize(num_iterations_ - 1); + for (int i = 0; i < num_iterations_ - 1; ++i) { + iteration_output_blobs_[i].reset(new Blob(num_, channels_, height_, width_)); + } + + // Make instances of MeanfieldIteration and initialize them. + meanfield_iterations_.resize(num_iterations_); + for (int i = 0; i < num_iterations_; ++i) { + meanfield_iterations_[i].reset(new MeanfieldIteration()); + meanfield_iterations_[i]->OneTimeSetUp( + split_layer_out_blobs_[i].get(), // unary terms + (i == 0) ? bottom[1] : iteration_output_blobs_[i - 1].get(), // softmax input + (i == num_iterations_ - 1) ? top[0] : iteration_output_blobs_[i].get(), // output blob + spatial_lattice_, // spatial lattice + &spatial_norm_); // spatial normalization factors. + } + + this->param_propagate_down_.resize(this->blobs_.size(), true); + + LOG(INFO) << ("MultiStageMeanfieldLayer initialized."); +} + +template +void MultiStageMeanfieldLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + // Do nothing. 
+} + + +/** + * Performs filter-based mean field inference given the image and unaries. + * + * bottom[0] - Unary terms + * bottom[1] - Softmax input/Output from the previous iteration (a copy of the unary terms if this is the first stage). + * bottom[2] - RGB images + * + * top[0] - Output of the mean field inference (not normalized). + */ +template +void MultiStageMeanfieldLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + + split_layer_bottom_vec_[0] = bottom[0]; + split_layer_->Forward(split_layer_bottom_vec_, split_layer_top_vec_); + + // Initialize the bilateral lattices. + bilateral_lattices_.resize(num_); + for (int n = 0; n < num_; ++n) { + + compute_bilateral_kernel(bottom[2], n, bilateral_kernel_buffer_.get()); + bilateral_lattices_[n].reset(new ModifiedPermutohedral()); + bilateral_lattices_[n]->init(bilateral_kernel_buffer_.get(), 5, num_pixels_); + + // Calculate bilateral filter normalization factors. + Dtype* norm_output_data = bilateral_norms_.mutable_cpu_data() + bilateral_norms_.offset(n); + bilateral_lattices_[n]->compute(norm_output_data, norm_feed_.get(), 1); + for (int i = 0; i < num_pixels_; ++i) { + norm_output_data[i] = 1.f / (norm_output_data[i] + 1e-20f); + } + } + + for (int i = 0; i < num_iterations_; ++i) { + + meanfield_iterations_[i]->PrePass(this->blobs_, &bilateral_lattices_, &bilateral_norms_); + + meanfield_iterations_[i]->Forward_cpu(); + } +} + +/** + * Backprop through filter-based mean field inference. + */ +template +void MultiStageMeanfieldLayer::Backward_cpu( + const vector*>& top, const vector& propagate_down, + const vector*>& bottom) { + + for (int i = (num_iterations_ - 1); i >= 0; --i) { + meanfield_iterations_[i]->Backward_cpu(); + } + + vector split_layer_propagate_down(1, true); + split_layer_->Backward(split_layer_top_vec_, split_layer_propagate_down, split_layer_bottom_vec_); + + // Accumulate diffs from mean field iterations. 
+ for (int blob_id = 0; blob_id < this->blobs_.size(); ++blob_id) { + + Blob* cur_blob = this->blobs_[blob_id].get(); + + if (this->param_propagate_down_[blob_id]) { + + caffe_set(cur_blob->count(), Dtype(0), cur_blob->mutable_cpu_diff()); + + for (int i = 0; i < num_iterations_; ++i) { + const Dtype* diffs_to_add = meanfield_iterations_[i]->blobs()[blob_id]->cpu_diff(); + caffe_axpy(cur_blob->count(), Dtype(1.), diffs_to_add, cur_blob->mutable_cpu_diff()); + } + } + } +} + +template +void MultiStageMeanfieldLayer::compute_bilateral_kernel(const Blob* const rgb_blob, const int n, + float* const output_kernel) { + + for (int p = 0; p < num_pixels_; ++p) { + output_kernel[5 * p] = static_cast(p % width_) / theta_alpha_; + output_kernel[5 * p + 1] = static_cast(p / width_) / theta_alpha_; + + const Dtype * const rgb_data_start = rgb_blob->cpu_data() + rgb_blob->offset(n); + output_kernel[5 * p + 2] = static_cast(rgb_data_start[p] / theta_beta_); + output_kernel[5 * p + 3] = static_cast((rgb_data_start + num_pixels_)[p] / theta_beta_); + output_kernel[5 * p + 4] = static_cast((rgb_data_start + num_pixels_ * 2)[p] / theta_beta_); + } +} + +template +void MultiStageMeanfieldLayer::compute_spatial_kernel(float* const output_kernel) { + + for (int p = 0; p < num_pixels_; ++p) { + output_kernel[2*p] = static_cast(p % width_) / theta_gamma_; + output_kernel[2*p + 1] = static_cast(p / width_) / theta_gamma_; + } +} + +INSTANTIATE_CLASS(MultiStageMeanfieldLayer); +REGISTER_LAYER_CLASS(MULTI_STAGE_MEANFIELD, MultiStageMeanfieldLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/multinomial_logistic_loss_layer.cpp b/caffe-crfrnn/src/caffe/layers/multinomial_logistic_loss_layer.cpp new file mode 100644 index 00000000..78a1f60f --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/multinomial_logistic_loss_layer.cpp @@ -0,0 +1,66 @@ +#include +#include +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/io.hpp" +#include 
"caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void MultinomialLogisticLossLayer::Reshape( + const vector*>& bottom, const vector*>& top) { + LossLayer::Reshape(bottom, top); + CHECK_EQ(bottom[1]->channels(), 1); + CHECK_EQ(bottom[1]->height(), 1); + CHECK_EQ(bottom[1]->width(), 1); +} + +template +void MultinomialLogisticLossLayer::Forward_cpu( + const vector*>& bottom, const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + const Dtype* bottom_label = bottom[1]->cpu_data(); + int num = bottom[0]->num(); + int dim = bottom[0]->count() / bottom[0]->num(); + Dtype loss = 0; + for (int i = 0; i < num; ++i) { + int label = static_cast(bottom_label[i]); + Dtype prob = std::max( + bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD)); + loss -= log(prob); + } + top[0]->mutable_cpu_data()[0] = loss / num; +} + +template +void MultinomialLogisticLossLayer::Backward_cpu( + const vector*>& top, const vector& propagate_down, + const vector*>& bottom) { + if (propagate_down[1]) { + LOG(FATAL) << this->type_name() + << " Layer cannot backpropagate to label inputs."; + } + if (propagate_down[0]) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + const Dtype* bottom_label = bottom[1]->cpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + int num = bottom[0]->num(); + int dim = bottom[0]->count() / bottom[0]->num(); + caffe_set(bottom[0]->count(), Dtype(0), bottom_diff); + const Dtype scale = - top[0]->cpu_diff()[0] / num; + for (int i = 0; i < num; ++i) { + int label = static_cast(bottom_label[i]); + Dtype prob = std::max( + bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD)); + bottom_diff[i * dim + label] = scale / prob; + } + } +} + +INSTANTIATE_CLASS(MultinomialLogisticLossLayer); +REGISTER_LAYER_CLASS(MULTINOMIAL_LOGISTIC_LOSS, MultinomialLogisticLossLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/mvn_layer.cpp b/caffe-crfrnn/src/caffe/layers/mvn_layer.cpp 
new file mode 100644 index 00000000..104ad95c --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/mvn_layer.cpp @@ -0,0 +1,163 @@ +#include +#include + +#include "caffe/common_layers.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void MVNLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(), + bottom[0]->height(), bottom[0]->width()); + mean_.Reshape(bottom[0]->num(), bottom[0]->channels(), + 1, 1); + variance_.Reshape(bottom[0]->num(), bottom[0]->channels(), + 1, 1); + temp_.Reshape(bottom[0]->num(), bottom[0]->channels(), + bottom[0]->height(), bottom[0]->width()); + sum_multiplier_.Reshape(1, 1, + bottom[0]->height(), bottom[0]->width()); + Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data(); + caffe_set(sum_multiplier_.count(), Dtype(1), multiplier_data); +} + +template +void MVNLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + int num; + if (this->layer_param_.mvn_param().across_channels()) + num = bottom[0]->num(); + else + num = bottom[0]->num() * bottom[0]->channels(); + + int dim = bottom[0]->count() / num; + Dtype eps = 1e-10; + + if (this->layer_param_.mvn_param().normalize_variance()) { + // put the squares of bottom into temp_ + caffe_powx(bottom[0]->count(), bottom_data, Dtype(2), + temp_.mutable_cpu_data()); + + // computes variance using var(X) = E(X^2) - (EX)^2 + caffe_cpu_gemv(CblasNoTrans, num, dim, 1. / dim, bottom_data, + sum_multiplier_.cpu_data(), 0., mean_.mutable_cpu_data()); // EX + caffe_cpu_gemv(CblasNoTrans, num, dim, 1. 
/ dim, temp_.cpu_data(), + sum_multiplier_.cpu_data(), 0., + variance_.mutable_cpu_data()); // E(X^2) + caffe_powx(mean_.count(), mean_.cpu_data(), Dtype(2), + temp_.mutable_cpu_data()); // (EX)^2 + caffe_sub(mean_.count(), variance_.cpu_data(), temp_.cpu_data(), + variance_.mutable_cpu_data()); // variance + + // do mean and variance normalization + // subtract mean + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, dim, 1, -1., + mean_.cpu_data(), sum_multiplier_.cpu_data(), 0., + temp_.mutable_cpu_data()); + + caffe_add(temp_.count(), bottom_data, temp_.cpu_data(), top_data); + + // normalize variance + caffe_powx(variance_.count(), variance_.cpu_data(), Dtype(0.5), + variance_.mutable_cpu_data()); + + caffe_add_scalar(variance_.count(), eps, variance_.mutable_cpu_data()); + + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, dim, 1, 1., + variance_.cpu_data(), sum_multiplier_.cpu_data(), 0., + temp_.mutable_cpu_data()); + + caffe_div(temp_.count(), top_data, temp_.cpu_data(), top_data); + } else { + caffe_cpu_gemv(CblasNoTrans, num, dim, 1. 
/ dim, bottom_data, + sum_multiplier_.cpu_data(), 0., mean_.mutable_cpu_data()); // EX + + // subtract mean + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, dim, 1, -1., + mean_.cpu_data(), sum_multiplier_.cpu_data(), 0., + temp_.mutable_cpu_data()); + + caffe_add(temp_.count(), bottom_data, temp_.cpu_data(), top_data); + } +} + +template +void MVNLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + const Dtype* top_diff = top[0]->cpu_diff(); + const Dtype* top_data = top[0]->cpu_data(); + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + + int num; + if (this->layer_param_.mvn_param().across_channels()) + num = bottom[0]->num(); + else + num = bottom[0]->num() * bottom[0]->channels(); + + int dim = bottom[0]->count() / num; + Dtype eps = 1e-10; + + if (this->layer_param_.mvn_param().normalize_variance()) { + caffe_mul(temp_.count(), top_data, top_diff, bottom_diff); + caffe_cpu_gemv(CblasNoTrans, num, dim, 1., bottom_diff, + sum_multiplier_.cpu_data(), 0., mean_.mutable_cpu_data()); + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, dim, 1, 1., + mean_.cpu_data(), sum_multiplier_.cpu_data(), 0., + bottom_diff); + caffe_mul(temp_.count(), top_data, bottom_diff, bottom_diff); + + caffe_cpu_gemv(CblasNoTrans, num, dim, 1., top_diff, + sum_multiplier_.cpu_data(), 0., mean_.mutable_cpu_data()); + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, dim, 1, 1., + mean_.cpu_data(), sum_multiplier_.cpu_data(), 1., + bottom_diff); + + caffe_cpu_axpby(temp_.count(), Dtype(1), top_diff, Dtype(-1. / dim), + bottom_diff); + + // put the squares of bottom into temp_ + caffe_powx(temp_.count(), bottom_data, Dtype(2), + temp_.mutable_cpu_data()); + + // computes variance using var(X) = E(X^2) - (EX)^2 + caffe_cpu_gemv(CblasNoTrans, num, dim, 1. / dim, bottom_data, + sum_multiplier_.cpu_data(), 0., mean_.mutable_cpu_data()); // EX + caffe_cpu_gemv(CblasNoTrans, num, dim, 1. 
/ dim, temp_.cpu_data(), + sum_multiplier_.cpu_data(), 0., + variance_.mutable_cpu_data()); // E(X^2) + caffe_powx(mean_.count(), mean_.cpu_data(), Dtype(2), + temp_.mutable_cpu_data()); // (EX)^2 + caffe_sub(mean_.count(), variance_.cpu_data(), temp_.cpu_data(), + variance_.mutable_cpu_data()); // variance + + // normalize variance + caffe_powx(variance_.count(), variance_.cpu_data(), Dtype(0.5), + variance_.mutable_cpu_data()); + + caffe_add_scalar(variance_.count(), eps, variance_.mutable_cpu_data()); + + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, num, dim, 1, 1., + variance_.cpu_data(), sum_multiplier_.cpu_data(), 0., + temp_.mutable_cpu_data()); + + caffe_div(temp_.count(), bottom_diff, temp_.cpu_data(), bottom_diff); + } else { + caffe_copy(temp_.count(), top_diff, bottom_diff); + } +} + + +#ifdef CPU_ONLY +STUB_GPU(MVNLayer); +#endif + +INSTANTIATE_CLASS(MVNLayer); +REGISTER_LAYER_CLASS(MVN, MVNLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/mvn_layer.cu b/caffe-crfrnn/src/caffe/layers/mvn_layer.cu new file mode 100644 index 00000000..0667f503 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/mvn_layer.cu @@ -0,0 +1,145 @@ +#include +#include + +#include "caffe/common_layers.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template +void MVNLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + int num; + if (this->layer_param_.mvn_param().across_channels()) + num = bottom[0]->num(); + else + num = bottom[0]->num() * bottom[0]->channels(); + + int dim = bottom[0]->count() / num; + + if (this->layer_param_.mvn_param().normalize_variance()) { + // put the squares of bottom into temp_ + caffe_gpu_powx(bottom[0]->count(), bottom_data, Dtype(2), + temp_.mutable_gpu_data()); + + // computes variance using var(X) = E(X^2) - (EX)^2 + caffe_gpu_gemv(CblasNoTrans, num, dim, 1. 
/ dim, bottom_data, + sum_multiplier_.gpu_data(), 0., mean_.mutable_gpu_data()); // EX + caffe_gpu_gemv(CblasNoTrans, num, dim, 1. / dim, temp_.gpu_data(), + sum_multiplier_.gpu_data(), 0., + variance_.mutable_gpu_data()); // E(X^2) + caffe_gpu_powx(mean_.count(), mean_.gpu_data(), Dtype(2), + temp_.mutable_gpu_data()); // (EX)^2 + caffe_gpu_sub(mean_.count(), variance_.gpu_data(), temp_.gpu_data(), + variance_.mutable_gpu_data()); // variance + + Dtype eps = 1e-10; + + // do mean and variance normalization + // subtract mean + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, dim, 1, -1., + mean_.gpu_data(), sum_multiplier_.gpu_data(), 0., + temp_.mutable_gpu_data()); + + caffe_gpu_add(temp_.count(), bottom_data, temp_.gpu_data(), top_data); + + // normalize variance + caffe_gpu_powx(variance_.count(), variance_.gpu_data(), Dtype(0.5), + variance_.mutable_gpu_data()); + + caffe_gpu_add_scalar(variance_.count(), eps, variance_.mutable_gpu_data()); + + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, dim, 1, 1., + variance_.gpu_data(), sum_multiplier_.gpu_data(), 0., + temp_.mutable_gpu_data()); + + caffe_gpu_div(temp_.count(), top_data, temp_.gpu_data(), top_data); + } else { + caffe_gpu_gemv(CblasNoTrans, num, dim, 1. 
/ dim, bottom_data, + sum_multiplier_.gpu_data(), 0., mean_.mutable_gpu_data()); // EX + + // subtract mean + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, dim, 1, -1., + mean_.gpu_data(), sum_multiplier_.gpu_data(), 0., + temp_.mutable_gpu_data()); + + caffe_gpu_add(temp_.count(), bottom_data, temp_.gpu_data(), top_data); + } +} + +template +void MVNLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* top_data = top[0]->gpu_data(); + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + + int num; + if (this->layer_param_.mvn_param().across_channels()) + num = bottom[0]->num(); + else + num = bottom[0]->num() * bottom[0]->channels(); + + int dim = bottom[0]->count() / num; + + Dtype eps = 1e-10; + + if (this->layer_param_.mvn_param().normalize_variance()) { + caffe_gpu_mul(temp_.count(), top_data, top_diff, bottom_diff); + caffe_gpu_gemv(CblasNoTrans, num, dim, 1., bottom_diff, + sum_multiplier_.gpu_data(), 0., mean_.mutable_gpu_data()); + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, dim, 1, 1., + mean_.gpu_data(), sum_multiplier_.gpu_data(), 0., + bottom_diff); + caffe_gpu_mul(temp_.count(), top_data, bottom_diff, bottom_diff); + + caffe_gpu_gemv(CblasNoTrans, num, dim, 1., top_diff, + sum_multiplier_.gpu_data(), 0., mean_.mutable_gpu_data()); + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, dim, 1, 1., + mean_.gpu_data(), sum_multiplier_.gpu_data(), 1., + bottom_diff); + + caffe_gpu_axpby(temp_.count(), Dtype(1), top_diff, Dtype(-1. / dim), + bottom_diff); + + // put the squares of bottom into temp_ + caffe_gpu_powx(temp_.count(), bottom_data, Dtype(2), + temp_.mutable_gpu_data()); + + // computes variance using var(X) = E(X^2) - (EX)^2 + caffe_gpu_gemv(CblasNoTrans, num, dim, 1. 
/ dim, bottom_data, + sum_multiplier_.gpu_data(), 0., mean_.mutable_gpu_data()); // EX + caffe_gpu_gemv(CblasNoTrans, num, dim, 1. / dim, temp_.gpu_data(), + sum_multiplier_.gpu_data(), 0., + variance_.mutable_gpu_data()); // E(X^2) + caffe_gpu_powx(mean_.count(), mean_.gpu_data(), Dtype(2), + temp_.mutable_gpu_data()); // (EX)^2 + caffe_gpu_sub(mean_.count(), variance_.gpu_data(), temp_.gpu_data(), + variance_.mutable_gpu_data()); // variance + + // normalize variance + caffe_gpu_powx(variance_.count(), variance_.gpu_data(), Dtype(0.5), + variance_.mutable_gpu_data()); + + caffe_gpu_add_scalar(variance_.count(), eps, variance_.mutable_gpu_data()); + + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, num, dim, 1, 1., + variance_.gpu_data(), sum_multiplier_.gpu_data(), 0., + temp_.mutable_gpu_data()); + + caffe_gpu_div(temp_.count(), bottom_diff, temp_.gpu_data(), bottom_diff); + } else { + caffe_copy(temp_.count(), top_diff, bottom_diff); + } +} + + +INSTANTIATE_LAYER_GPU_FUNCS(MVNLayer); + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/neuron_layer.cpp b/caffe-crfrnn/src/caffe/layers/neuron_layer.cpp new file mode 100644 index 00000000..ba67b438 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/neuron_layer.cpp @@ -0,0 +1,16 @@ +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void NeuronLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + top[0]->ReshapeLike(*bottom[0]); +} + +INSTANTIATE_CLASS(NeuronLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/pooling_layer.cpp b/caffe-crfrnn/src/caffe/layers/pooling_layer.cpp new file mode 100644 index 00000000..2bfbb01f --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/pooling_layer.cpp @@ -0,0 +1,318 @@ +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/syncedmem.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace 
caffe { + +using std::min; +using std::max; + +template +void PoolingLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + PoolingParameter pool_param = this->layer_param_.pooling_param(); + if (pool_param.global_pooling()) { + CHECK(!(pool_param.has_kernel_size() || + pool_param.has_kernel_h() || pool_param.has_kernel_w())) + << "With Global_pooling: true Filter size cannot specified"; + } else { + CHECK(!pool_param.has_kernel_size() != + !(pool_param.has_kernel_h() && pool_param.has_kernel_w())) + << "Filter size is kernel_size OR kernel_h and kernel_w; not both"; + CHECK(pool_param.has_kernel_size() || + (pool_param.has_kernel_h() && pool_param.has_kernel_w())) + << "For non-square filters both kernel_h and kernel_w are required."; + } + CHECK((!pool_param.has_pad() && pool_param.has_pad_h() + && pool_param.has_pad_w()) + || (!pool_param.has_pad_h() && !pool_param.has_pad_w())) + << "pad is pad OR pad_h and pad_w are required."; + CHECK((!pool_param.has_stride() && pool_param.has_stride_h() + && pool_param.has_stride_w()) + || (!pool_param.has_stride_h() && !pool_param.has_stride_w())) + << "Stride is stride OR stride_h and stride_w are required."; + global_pooling_ = pool_param.global_pooling(); + if (global_pooling_) { + kernel_h_ = bottom[0]->height(); + kernel_w_ = bottom[0]->width(); + } else { + if (pool_param.has_kernel_size()) { + kernel_h_ = kernel_w_ = pool_param.kernel_size(); + } else { + kernel_h_ = pool_param.kernel_h(); + kernel_w_ = pool_param.kernel_w(); + } + } + CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero."; + CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero."; + if (!pool_param.has_pad_h()) { + pad_h_ = pad_w_ = pool_param.pad(); + } else { + pad_h_ = pool_param.pad_h(); + pad_w_ = pool_param.pad_w(); + } + if (!pool_param.has_stride_h()) { + stride_h_ = stride_w_ = pool_param.stride(); + } else { + stride_h_ = pool_param.stride_h(); + stride_w_ = pool_param.stride_w(); + } + if (global_pooling_) { + 
CHECK(pad_h_ == 0 && pad_w_ == 0 && stride_h_ == 1 && stride_w_ == 1) + << "With Global_pooling: true; only pad = 0 and stride = 1"; + } + if (pad_h_ != 0 || pad_w_ != 0) { + CHECK(this->layer_param_.pooling_param().pool() + == PoolingParameter_PoolMethod_AVE + || this->layer_param_.pooling_param().pool() + == PoolingParameter_PoolMethod_MAX) + << "Padding implemented only for average and max pooling."; + CHECK_LT(pad_h_, kernel_h_); + CHECK_LT(pad_w_, kernel_w_); + } +} + +template <typename Dtype> +void PoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top) { + channels_ = bottom[0]->channels(); + height_ = bottom[0]->height(); + width_ = bottom[0]->width(); + if (global_pooling_) { + kernel_h_ = bottom[0]->height(); + kernel_w_ = bottom[0]->width(); + } + pooled_height_ = static_cast<int>(ceil(static_cast<float>( + height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1; + pooled_width_ = static_cast<int>(ceil(static_cast<float>( + width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1; + if (pad_h_ || pad_w_) { + // If we have padding, ensure that the last pooling starts strictly + // inside the image (instead of at the padding); otherwise clip the last. + if ((pooled_height_ - 1) * stride_h_ >= height_ + pad_h_) { + --pooled_height_; + } + if ((pooled_width_ - 1) * stride_w_ >= width_ + pad_w_) { + --pooled_width_; + } + CHECK_LT((pooled_height_ - 1) * stride_h_, height_ + pad_h_); + CHECK_LT((pooled_width_ - 1) * stride_w_, width_ + pad_w_); + } + top[0]->Reshape(bottom[0]->num(), channels_, pooled_height_, + pooled_width_); + if (top.size() > 1) { + top[1]->ReshapeLike(*top[0]); + } + // If max pooling, we will initialize the vector index part. + if (this->layer_param_.pooling_param().pool() == + PoolingParameter_PoolMethod_MAX && top.size() == 1) { + max_idx_.Reshape(bottom[0]->num(), channels_, pooled_height_, + pooled_width_); + } + // If stochastic pooling, we will initialize the random index part.
+ if (this->layer_param_.pooling_param().pool() == + PoolingParameter_PoolMethod_STOCHASTIC) { + rand_idx_.Reshape(bottom[0]->num(), channels_, pooled_height_, + pooled_width_); + } +} + +// TODO(Yangqing): Is there a faster way to do pooling in the channel-first +// case? +template <typename Dtype> +void PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + const int top_count = top[0]->count(); + // We'll output the mask to top[1] if it's of size >1. + const bool use_top_mask = top.size() > 1; + int* mask = NULL; // suppress warnings about uninitialized variables + Dtype* top_mask = NULL; + // Different pooling methods. We explicitly do the switch outside the for + // loop to save time, although this results in more code. + switch (this->layer_param_.pooling_param().pool()) { + case PoolingParameter_PoolMethod_MAX: + // Initialize + if (use_top_mask) { + top_mask = top[1]->mutable_cpu_data(); + caffe_set(top_count, Dtype(-1), top_mask); + } else { + mask = max_idx_.mutable_cpu_data(); + caffe_set(top_count, -1, mask); + } + caffe_set(top_count, Dtype(-FLT_MAX), top_data); + // The main loop + for (int n = 0; n < bottom[0]->num(); ++n) { + for (int c = 0; c < channels_; ++c) { + for (int ph = 0; ph < pooled_height_; ++ph) { + for (int pw = 0; pw < pooled_width_; ++pw) { + int hstart = ph * stride_h_ - pad_h_; + int wstart = pw * stride_w_ - pad_w_; + int hend = min(hstart + kernel_h_, height_); + int wend = min(wstart + kernel_w_, width_); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + const int pool_index = ph * pooled_width_ + pw; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int index = h * width_ + w; + if (bottom_data[index] > top_data[pool_index]) { + top_data[pool_index] = bottom_data[index]; + if (use_top_mask) { + top_mask[pool_index] = static_cast<Dtype>(index); + } else { + mask[pool_index] = index; + } + }
+ } + } + } + } + // compute offset + bottom_data += bottom[0]->offset(0, 1); + top_data += top[0]->offset(0, 1); + if (use_top_mask) { + top_mask += top[0]->offset(0, 1); + } else { + mask += top[0]->offset(0, 1); + } + } + } + break; + case PoolingParameter_PoolMethod_AVE: + for (int i = 0; i < top_count; ++i) { + top_data[i] = 0; + } + // The main loop + for (int n = 0; n < bottom[0]->num(); ++n) { + for (int c = 0; c < channels_; ++c) { + for (int ph = 0; ph < pooled_height_; ++ph) { + for (int pw = 0; pw < pooled_width_; ++pw) { + int hstart = ph * stride_h_ - pad_h_; + int wstart = pw * stride_w_ - pad_w_; + int hend = min(hstart + kernel_h_, height_ + pad_h_); + int wend = min(wstart + kernel_w_, width_ + pad_w_); + int pool_size = (hend - hstart) * (wend - wstart); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, height_); + wend = min(wend, width_); + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + top_data[ph * pooled_width_ + pw] += + bottom_data[h * width_ + w]; + } + } + top_data[ph * pooled_width_ + pw] /= pool_size; + } + } + // compute offset + bottom_data += bottom[0]->offset(0, 1); + top_data += top[0]->offset(0, 1); + } + } + break; + case PoolingParameter_PoolMethod_STOCHASTIC: + NOT_IMPLEMENTED; + break; + default: + LOG(FATAL) << "Unknown pooling method."; + } +} + +template <typename Dtype> +void PoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { + if (!propagate_down[0]) { + return; + } + const Dtype* top_diff = top[0]->cpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + // Different pooling methods. We explicitly do the switch outside the for + // loop to save time, although this results in more code. + caffe_set(bottom[0]->count(), Dtype(0), bottom_diff); + // We'll output the mask to top[1] if it's of size >1.
+ const bool use_top_mask = top.size() > 1; + const int* mask = NULL; // suppress warnings about uninitialized variables + const Dtype* top_mask = NULL; + switch (this->layer_param_.pooling_param().pool()) { + case PoolingParameter_PoolMethod_MAX: + // The main loop + if (use_top_mask) { + top_mask = top[1]->cpu_data(); + } else { + mask = max_idx_.cpu_data(); + } + for (int n = 0; n < top[0]->num(); ++n) { + for (int c = 0; c < channels_; ++c) { + for (int ph = 0; ph < pooled_height_; ++ph) { + for (int pw = 0; pw < pooled_width_; ++pw) { + const int index = ph * pooled_width_ + pw; + const int bottom_index = + use_top_mask ? top_mask[index] : mask[index]; + bottom_diff[bottom_index] += top_diff[index]; + } + } + bottom_diff += bottom[0]->offset(0, 1); + top_diff += top[0]->offset(0, 1); + if (use_top_mask) { + top_mask += top[0]->offset(0, 1); + } else { + mask += top[0]->offset(0, 1); + } + } + } + break; + case PoolingParameter_PoolMethod_AVE: + // The main loop + for (int n = 0; n < top[0]->num(); ++n) { + for (int c = 0; c < channels_; ++c) { + for (int ph = 0; ph < pooled_height_; ++ph) { + for (int pw = 0; pw < pooled_width_; ++pw) { + int hstart = ph * stride_h_ - pad_h_; + int wstart = pw * stride_w_ - pad_w_; + int hend = min(hstart + kernel_h_, height_ + pad_h_); + int wend = min(wstart + kernel_w_, width_ + pad_w_); + int pool_size = (hend - hstart) * (wend - wstart); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, height_); + wend = min(wend, width_); + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + bottom_diff[h * width_ + w] += + top_diff[ph * pooled_width_ + pw] / pool_size; + } + } + } + } + // offset + bottom_diff += bottom[0]->offset(0, 1); + top_diff += top[0]->offset(0, 1); + } + } + break; + case PoolingParameter_PoolMethod_STOCHASTIC: + NOT_IMPLEMENTED; + break; + default: + LOG(FATAL) << "Unknown pooling method."; + } +} + + +#ifdef CPU_ONLY +STUB_GPU(PoolingLayer); +#endif + 
+INSTANTIATE_CLASS(PoolingLayer); + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/pooling_layer.cu b/caffe-crfrnn/src/caffe/layers/pooling_layer.cu new file mode 100644 index 00000000..0d3f2183 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/pooling_layer.cu @@ -0,0 +1,379 @@ +#include +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +__global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data, + const int num, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, const int pad_w, Dtype* top_data, + int* mask, Dtype* top_mask) { + CUDA_KERNEL_LOOP(index, nthreads) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + Dtype maxval = -FLT_MAX; + int maxidx = -1; + bottom_data += (n * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (bottom_data[h * width + w] > maxval) { + maxidx = h * width + w; + maxval = bottom_data[maxidx]; + } + } + } + top_data[index] = maxval; + if (mask) { + mask[index] = maxidx; + } else { + top_mask[index] = maxidx; + } + } +} + +template +__global__ void AvePoolForward(const int nthreads, const Dtype* bottom_data, + const int num, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int 
stride_w, const int pad_h, const int pad_w, Dtype* top_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = min(hstart + kernel_h, height + pad_h); + int wend = min(wstart + kernel_w, width + pad_w); + int pool_size = (hend - hstart) * (wend - wstart); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, height); + wend = min(wend, width); + Dtype aveval = 0; + bottom_data += (n * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + aveval += bottom_data[h * width + w]; + } + } + top_data[index] = aveval / pool_size; + } +} + +template +__global__ void StoPoolForwardTrain(const int nthreads, + const Dtype* bottom_data, + const int num, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, Dtype* rand_idx, Dtype* top_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + int hstart = ph * stride_h; + int hend = min(hstart + kernel_h, height); + int wstart = pw * stride_w; + int wend = min(wstart + kernel_w, width); + Dtype cumsum = 0.; + bottom_data += (n * channels + c) * height * width; + // First pass: get sum + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + cumsum += bottom_data[h * width + w]; + } + } + float thres = rand_idx[index] * cumsum; + // Second pass: get value, and set index. 
+ cumsum = 0; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + cumsum += bottom_data[h * width + w]; + if (cumsum >= thres) { + rand_idx[index] = ((n * channels + c) * height + h) * width + w; + top_data[index] = bottom_data[h * width + w]; + return; + } + } + } + } +} + + +template <typename Dtype> +__global__ void StoPoolForwardTest(const int nthreads, + const Dtype* bottom_data, + const int num, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, Dtype* top_data) { + CUDA_KERNEL_LOOP(index, nthreads) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + int hstart = ph * stride_h; + int hend = min(hstart + kernel_h, height); + int wstart = pw * stride_w; + int wend = min(wstart + kernel_w, width); + // We initialize cumsum to FLT_MIN (rather than 0) to avoid divide-by-zero problems + Dtype cumsum = FLT_MIN; + Dtype cumvalues = 0.; + bottom_data += (n * channels + c) * height * width; + // First pass: get sum + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + cumsum += bottom_data[h * width + w]; + cumvalues += bottom_data[h * width + w] * bottom_data[h * width + w]; + } + } + top_data[index] = cumvalues / cumsum; + } +} + + +template <typename Dtype> +void PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + int count = top[0]->count(); + // We'll output the mask to top[1] if it's of size >1.
+ const bool use_top_mask = top.size() > 1; + int* mask = NULL; + Dtype* top_mask = NULL; + switch (this->layer_param_.pooling_param().pool()) { + case PoolingParameter_PoolMethod_MAX: + if (use_top_mask) { + top_mask = top[1]->mutable_gpu_data(); + } else { + mask = max_idx_.mutable_gpu_data(); + } + // NOLINT_NEXT_LINE(whitespace/operators) + MaxPoolForward<<>>( + count, bottom_data, bottom[0]->num(), channels_, + height_, width_, pooled_height_, pooled_width_, kernel_h_, + kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data, + mask, top_mask); + break; + case PoolingParameter_PoolMethod_AVE: + // NOLINT_NEXT_LINE(whitespace/operators) + AvePoolForward<<>>( + count, bottom_data, bottom[0]->num(), channels_, + height_, width_, pooled_height_, pooled_width_, kernel_h_, + kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data); + break; + case PoolingParameter_PoolMethod_STOCHASTIC: + if (Caffe::phase() == Caffe::TRAIN) { + // We need to create the random index as well. + caffe_gpu_rng_uniform(count, Dtype(0), Dtype(1), + rand_idx_.mutable_gpu_data()); + // NOLINT_NEXT_LINE(whitespace/operators) + StoPoolForwardTrain<<>>( + count, bottom_data, bottom[0]->num(), channels_, + height_, width_, pooled_height_, pooled_width_, kernel_h_, + kernel_w_, stride_h_, stride_w_, + rand_idx_.mutable_gpu_data(), top_data); + } else { + // NOLINT_NEXT_LINE(whitespace/operators) + StoPoolForwardTest<<>>( + count, bottom_data, bottom[0]->num(), channels_, + height_, width_, pooled_height_, pooled_width_, kernel_h_, + kernel_w_, stride_h_, stride_w_, top_data); + } + break; + default: + LOG(FATAL) << "Unknown pooling method."; + } + CUDA_POST_KERNEL_CHECK; +} + + +template +__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff, + const int* mask, const Dtype* top_mask, const int num, const int channels, + const int height, const int width, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const 
int stride_w, const int pad_h, const int pad_w, + Dtype* bottom_diff) { + CUDA_KERNEL_LOOP(index, nthreads) { + // find out the local index + // find out the local offset + int w = index % width; + int h = (index / width) % height; + int c = (index / width / height) % channels; + int n = index / width / height / channels; + int phstart = + (h + pad_h < kernel_h) ? 0 : (h + pad_h - kernel_h) / stride_h + 1; + int phend = min((h + pad_h) / stride_h + 1, pooled_height); + int pwstart = + (w + pad_w < kernel_w) ? 0 : (w + pad_w - kernel_w) / stride_w + 1; + int pwend = min((w + pad_w) / stride_w + 1, pooled_width); + Dtype gradient = 0; + int offset = (n * channels + c) * pooled_height * pooled_width; + top_diff += offset; + if (mask) { + mask += offset; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + if (mask[ph * pooled_width + pw] == h * width + w) { + gradient += top_diff[ph * pooled_width + pw]; + } + } + } + } else { + top_mask += offset; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + if (top_mask[ph * pooled_width + pw] == h * width + w) { + gradient += top_diff[ph * pooled_width + pw]; + } + } + } + } + bottom_diff[index] = gradient; + } +} + +template +__global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff, + const int num, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, const int pad_w, + Dtype* bottom_diff) { + CUDA_KERNEL_LOOP(index, nthreads) { + // find out the local index + // find out the local offset + int w = index % width + pad_w; + int h = (index / width) % height + pad_h; + int c = (index / width / height) % channels; + int n = index / width / height / channels; + int phstart = (h < kernel_h) ? 
0 : (h - kernel_h) / stride_h + 1; + int phend = min(h / stride_h + 1, pooled_height); + int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; + int pwend = min(w / stride_w + 1, pooled_width); + Dtype gradient = 0; + top_diff += (n * channels + c) * pooled_height * pooled_width; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + // figure out the pooling size + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = min(hstart + kernel_h, height + pad_h); + int wend = min(wstart + kernel_w, width + pad_w); + int pool_size = (hend - hstart) * (wend - wstart); + gradient += top_diff[ph * pooled_width + pw] / pool_size; + } + } + bottom_diff[index] = gradient; + } +} + + +template +__global__ void StoPoolBackward(const int nthreads, + const Dtype* rand_idx, const Dtype* top_diff, + const int num, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, Dtype* bottom_diff) { + CUDA_KERNEL_LOOP(index, nthreads) { + // find out the local index + // find out the local offset + int w = index % width; + int h = (index / width) % height; + int c = (index / width / height) % channels; + int n = index / width / height / channels; + int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; + int phend = min(h / stride_h + 1, pooled_height); + int pwstart = (w < kernel_w) ? 
0 : (w - kernel_w) / stride_w + 1; + int pwend = min(w / stride_w + 1, pooled_width); + Dtype gradient = 0; + rand_idx += (n * channels + c) * pooled_height * pooled_width; + top_diff += (n * channels + c) * pooled_height * pooled_width; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + gradient += top_diff[ph * pooled_width + pw] * + (index == static_cast(rand_idx[ph * pooled_width + pw])); + } + } + bottom_diff[index] = gradient; + } +} + + +template +void PoolingLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (!propagate_down[0]) { + return; + } + const Dtype* top_diff = top[0]->gpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + const int count = bottom[0]->count(); + caffe_gpu_set(count, Dtype(0.), bottom_diff); + // We'll output the mask to top[1] if it's of size >1. + const bool use_top_mask = top.size() > 1; + const int* mask = NULL; + const Dtype* top_mask = NULL; + switch (this->layer_param_.pooling_param().pool()) { + case PoolingParameter_PoolMethod_MAX: + if (use_top_mask) { + top_mask = top[1]->gpu_data(); + } else { + mask = max_idx_.gpu_data(); + } + // NOLINT_NEXT_LINE(whitespace/operators) + MaxPoolBackward<<>>( + count, top_diff, mask, top_mask, top[0]->num(), channels_, + height_, width_, pooled_height_, pooled_width_, + kernel_h_, kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, + bottom_diff); + break; + case PoolingParameter_PoolMethod_AVE: + // NOLINT_NEXT_LINE(whitespace/operators) + AvePoolBackward<<>>( + count, top_diff, top[0]->num(), channels_, + height_, width_, pooled_height_, pooled_width_, kernel_h_, + kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, bottom_diff); + break; + case PoolingParameter_PoolMethod_STOCHASTIC: + // NOLINT_NEXT_LINE(whitespace/operators) + StoPoolBackward<<>>( + count, rand_idx_.gpu_data(), top_diff, + top[0]->num(), channels_, height_, width_, pooled_height_, + pooled_width_, kernel_h_, 
kernel_w_, stride_h_, stride_w_, + bottom_diff); + break; + default: + LOG(FATAL) << "Unknown pooling method."; + } + CUDA_POST_KERNEL_CHECK; +} + + +INSTANTIATE_LAYER_GPU_FUNCS(PoolingLayer); + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/power_layer.cpp b/caffe-crfrnn/src/caffe/layers/power_layer.cpp new file mode 100644 index 00000000..69bd120e --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/power_layer.cpp @@ -0,0 +1,103 @@ +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void PowerLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + NeuronLayer::LayerSetUp(bottom, top); + power_ = this->layer_param_.power_param().power(); + scale_ = this->layer_param_.power_param().scale(); + shift_ = this->layer_param_.power_param().shift(); + diff_scale_ = power_ * scale_; +} + +// Compute y = (shift + scale * x)^power +template +void PowerLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + Dtype* top_data = top[0]->mutable_cpu_data(); + const int count = bottom[0]->count(); + // Special case where we can ignore the input: scale or power is 0. + if (diff_scale_ == Dtype(0)) { + Dtype value = (power_ == 0) ? 
Dtype(1) : pow(shift_, power_); + caffe_set(count, value, top_data); + return; + } + const Dtype* bottom_data = bottom[0]->cpu_data(); + caffe_copy(count, bottom_data, top_data); + if (scale_ != Dtype(1)) { + caffe_scal(count, scale_, top_data); + } + if (shift_ != Dtype(0)) { + caffe_add_scalar(count, shift_, top_data); + } + if (power_ != Dtype(1)) { + caffe_powx(count, top_data, power_, top_data); + } +} + +template <typename Dtype> +void PowerLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, + const vector<Blob<Dtype>*>& bottom) { + if (propagate_down[0]) { + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + const int count = bottom[0]->count(); + const Dtype* top_diff = top[0]->cpu_diff(); + if (diff_scale_ == Dtype(0) || power_ == Dtype(1)) { + caffe_set(count, diff_scale_, bottom_diff); + } else { + const Dtype* bottom_data = bottom[0]->cpu_data(); + // Compute dy/dx = scale * power * (shift + scale * x)^(power - 1) + // = diff_scale * y / (shift + scale * x) + if (power_ == Dtype(2)) { + // Special case for y = (shift + scale * x)^2 + // -> dy/dx = 2 * scale * (shift + scale * x) + // = diff_scale * shift + diff_scale * scale * x + caffe_cpu_axpby(count, diff_scale_ * scale_, bottom_data, + Dtype(0), bottom_diff); + if (shift_ != Dtype(0)) { + caffe_add_scalar(count, diff_scale_ * shift_, bottom_diff); + } + } else if (shift_ == Dtype(0)) { + // Special case for y = (scale * x)^power + // -> dy/dx = scale * power * (scale * x)^(power - 1) + // = scale * power * (scale * x)^power * (scale * x)^(-1) + // = power * y / x + const Dtype* top_data = top[0]->cpu_data(); + caffe_div(count, top_data, bottom_data, bottom_diff); + caffe_scal(count, power_, bottom_diff); + } else { + caffe_copy(count, bottom_data, bottom_diff); + if (scale_ != Dtype(1)) { + caffe_scal(count, scale_, bottom_diff); + } + if (shift_ != Dtype(0)) { + caffe_add_scalar(count, shift_, bottom_diff); + } + const Dtype* top_data = top[0]->cpu_data(); + caffe_div(count, top_data, bottom_diff,
bottom_diff); + if (diff_scale_ != Dtype(1)) { + caffe_scal(count, diff_scale_, bottom_diff); + } + } + } + if (diff_scale_ != Dtype(0)) { + caffe_mul(count, top_diff, bottom_diff, bottom_diff); + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(PowerLayer); +#endif + +INSTANTIATE_CLASS(PowerLayer); +REGISTER_LAYER_CLASS(POWER, PowerLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/power_layer.cu b/caffe-crfrnn/src/caffe/layers/power_layer.cu new file mode 100644 index 00000000..90d94405 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/power_layer.cu @@ -0,0 +1,87 @@ +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void PowerLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + Dtype* top_data = top[0]->mutable_gpu_data(); + const int count = bottom[0]->count(); + // Special case where we can ignore the input: scale or power is 0. + if (diff_scale_ == Dtype(0)) { + Dtype value = (power_ == 0) ? 
Dtype(1) : pow(shift_, power_);
+    caffe_gpu_set(count, value, top_data);
+    return;
+  }
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  caffe_copy(count, bottom_data, top_data);
+  if (scale_ != Dtype(1)) {
+    caffe_gpu_scal(count, scale_, top_data);
+  }
+  if (shift_ != Dtype(0)) {
+    caffe_gpu_add_scalar(count, shift_, top_data);
+  }
+  if (power_ != Dtype(1)) {
+    caffe_gpu_powx(count, top_data, power_, top_data);
+  }
+}
+
+template <typename Dtype>
+void PowerLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  if (propagate_down[0]) {
+    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+    const int count = bottom[0]->count();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    if (diff_scale_ == Dtype(0) || power_ == Dtype(1)) {
+      caffe_gpu_set(count, diff_scale_, bottom_diff);
+    } else {
+      const Dtype* bottom_data = bottom[0]->gpu_data();
+      // Compute dy/dx = scale * power * (shift + scale * x)^(power - 1)
+      //               = diff_scale * y / (shift + scale * x)
+      if (power_ == Dtype(2)) {
+        // Special case for y = (shift + scale * x)^2
+        //     -> dy/dx = 2 * scale * (shift + scale * x)
+        //              = diff_scale * shift + diff_scale * scale * x
+        caffe_gpu_axpby(count, diff_scale_ * scale_, bottom_data,
+            Dtype(0), bottom_diff);
+        if (shift_ != Dtype(0)) {
+          caffe_gpu_add_scalar(count, diff_scale_ * shift_, bottom_diff);
+        }
+      } else if (shift_ == Dtype(0)) {
+        // Special case for y = (scale * x)^power
+        //     -> dy/dx = scale * power * (scale * x)^(power - 1)
+        //              = scale * power * (scale * x)^power * (scale * x)^(-1)
+        //              = power * y / x
+        const Dtype* top_data = top[0]->gpu_data();
+        caffe_gpu_div(count, top_data, bottom_data, bottom_diff);
+        caffe_gpu_scal(count, power_, bottom_diff);
+      } else {
+        caffe_copy(count, bottom_data, bottom_diff);
+        if (scale_ != Dtype(1)) {
+          caffe_gpu_scal(count, scale_, bottom_diff);
+        }
+        if (shift_ != Dtype(0)) {
+          caffe_gpu_add_scalar(count, shift_, bottom_diff);
+        }
+        const Dtype* top_data = top[0]->gpu_data();
+
caffe_gpu_div(count, top_data, bottom_diff, bottom_diff);
+        if (diff_scale_ != Dtype(1)) {
+          caffe_gpu_scal(count, diff_scale_, bottom_diff);
+        }
+      }
+    }
+    caffe_gpu_mul(count, top_diff, bottom_diff, bottom_diff);
+  }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(PowerLayer);
+
+
+}  // namespace caffe
diff --git a/caffe-crfrnn/src/caffe/layers/relu_layer.cpp b/caffe-crfrnn/src/caffe/layers/relu_layer.cpp
new file mode 100644
index 00000000..7d5e6034
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/layers/relu_layer.cpp
@@ -0,0 +1,47 @@
+#include
+#include
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  const int count = bottom[0]->count();
+  Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
+  for (int i = 0; i < count; ++i) {
+    top_data[i] = std::max(bottom_data[i], Dtype(0))
+        + negative_slope * std::min(bottom_data[i], Dtype(0));
+  }
+}
+
+template <typename Dtype>
+void ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  if (propagate_down[0]) {
+    const Dtype* bottom_data = bottom[0]->cpu_data();
+    const Dtype* top_diff = top[0]->cpu_diff();
+    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+    const int count = bottom[0]->count();
+    Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
+    for (int i = 0; i < count; ++i) {
+      bottom_diff[i] = top_diff[i] * ((bottom_data[i] > 0)
+          + negative_slope * (bottom_data[i] <= 0));
+    }
+  }
+}
+
+
+#ifdef CPU_ONLY
+STUB_GPU(ReLULayer);
+#endif
+
+INSTANTIATE_CLASS(ReLULayer);
+
+
+}  // namespace caffe
diff --git a/caffe-crfrnn/src/caffe/layers/relu_layer.cu b/caffe-crfrnn/src/caffe/layers/relu_layer.cu
new file mode 100644
index 00000000..b8924c85
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/layers/relu_layer.cu
@@ -0,0 +1,65 @@
+#include
+#include
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void ReLUForward(const int n, const Dtype* in, Dtype* out,
+    Dtype negative_slope) {
+  CUDA_KERNEL_LOOP(index, n) {
+    out[index] = in[index] > 0 ? in[index] : in[index] * negative_slope;
+  }
+}
+
+template <typename Dtype>
+void ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  const int count = bottom[0]->count();
+  Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  ReLUForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, top_data, negative_slope);
+  CUDA_POST_KERNEL_CHECK;
+  // << " count: " << count << " bottom_data: "
+  //     << (unsigned long)bottom_data
+  //     << " top_data: " << (unsigned long)top_data
+  //     << " blocks: " << CAFFE_GET_BLOCKS(count)
+  //     << " threads: " << CAFFE_CUDA_NUM_THREADS;
+}
+
+template <typename Dtype>
+__global__ void ReLUBackward(const int n, const Dtype* in_diff,
+    const Dtype* in_data, Dtype* out_diff, Dtype negative_slope) {
+  CUDA_KERNEL_LOOP(index, n) {
+    out_diff[index] = in_diff[index] * ((in_data[index] > 0)
+        + (in_data[index] <= 0) * negative_slope);
+  }
+}
+
+template <typename Dtype>
+void ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  if (propagate_down[0]) {
+    const Dtype* bottom_data = bottom[0]->gpu_data();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+    const int count = bottom[0]->count();
+    Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    ReLUBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, top_diff, bottom_data, bottom_diff, negative_slope);
+    CUDA_POST_KERNEL_CHECK;
+  }
+}
+
+
+INSTANTIATE_LAYER_GPU_FUNCS(ReLULayer);
+
+
+}  // namespace caffe
diff --git
a/caffe-crfrnn/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp b/caffe-crfrnn/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
new file mode 100644
index 00000000..d1e327a5
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
@@ -0,0 +1,79 @@
+#include
+#include
+#include
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SigmoidCrossEntropyLossLayer<Dtype>::LayerSetUp(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  LossLayer<Dtype>::LayerSetUp(bottom, top);
+  sigmoid_bottom_vec_.clear();
+  sigmoid_bottom_vec_.push_back(bottom[0]);
+  sigmoid_top_vec_.clear();
+  sigmoid_top_vec_.push_back(sigmoid_output_.get());
+  sigmoid_layer_->SetUp(sigmoid_bottom_vec_, sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+void SigmoidCrossEntropyLossLayer<Dtype>::Reshape(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  LossLayer<Dtype>::Reshape(bottom, top);
+  CHECK_EQ(bottom[0]->count(), bottom[1]->count()) <<
+      "SIGMOID_CROSS_ENTROPY_LOSS layer inputs must have the same count.";
+  sigmoid_layer_->Reshape(sigmoid_bottom_vec_, sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+void SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  // The forward pass computes the sigmoid outputs.
+  sigmoid_bottom_vec_[0] = bottom[0];
+  sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
+  // Compute the loss (negative log likelihood)
+  const int count = bottom[0]->count();
+  const int num = bottom[0]->num();
+  // Stable version of loss computation from input data
+  const Dtype* input_data = bottom[0]->cpu_data();
+  const Dtype* target = bottom[1]->cpu_data();
+  Dtype loss = 0;
+  for (int i = 0; i < count; ++i) {
+    loss -= input_data[i] * (target[i] - (input_data[i] >= 0)) -
+        log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
+  }
+  top[0]->mutable_cpu_data()[0] = loss / num;
+}
+
+template <typename Dtype>
+void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu(
+    const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  if (propagate_down[1]) {
+    LOG(FATAL) << this->type_name()
+               << " Layer cannot backpropagate to label inputs.";
+  }
+  if (propagate_down[0]) {
+    // First, compute the diff
+    const int count = bottom[0]->count();
+    const int num = bottom[0]->num();
+    const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
+    const Dtype* target = bottom[1]->cpu_data();
+    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+    caffe_sub(count, sigmoid_output_data, target, bottom_diff);
+    // Scale down gradient
+    const Dtype loss_weight = top[0]->cpu_diff()[0];
+    caffe_scal(count, loss_weight / num, bottom_diff);
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(SigmoidCrossEntropyLossLayer);
+#endif
+
+INSTANTIATE_CLASS(SigmoidCrossEntropyLossLayer);
+REGISTER_LAYER_CLASS(SIGMOID_CROSS_ENTROPY_LOSS, SigmoidCrossEntropyLossLayer);
+}  // namespace caffe
diff --git a/caffe-crfrnn/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu b/caffe-crfrnn/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
new file mode 100644
index 00000000..d9db4af6
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
@@ -0,0 +1,57 @@
+#include
+#include
+#include
+
+#include "caffe/layer.hpp"
+#include
"caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void SigmoidCrossEntropyLossLayer::Forward_gpu( + const vector*>& bottom, const vector*>& top) { + // The forward pass computes the sigmoid outputs. + sigmoid_bottom_vec_[0] = bottom[0]; + sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_); + // Compute the loss (negative log likelihood) + const int count = bottom[0]->count(); + const int num = bottom[0]->num(); + // Stable version of loss computation from input data + const Dtype* input_data = bottom[0]->cpu_data(); + const Dtype* target = bottom[1]->cpu_data(); + Dtype loss = 0; + for (int i = 0; i < count; ++i) { + loss -= input_data[i] * (target[i] - (input_data[i] >= 0)) - + log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0))); + } + top[0]->mutable_cpu_data()[0] = loss / num; +} + +template +void SigmoidCrossEntropyLossLayer::Backward_gpu( + const vector*>& top, const vector& propagate_down, + const vector*>& bottom) { + if (propagate_down[1]) { + LOG(FATAL) << this->type_name() + << " Layer cannot backpropagate to label inputs."; + } + if (propagate_down[0]) { + // First, compute the diff + const int count = bottom[0]->count(); + const int num = bottom[0]->num(); + const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data(); + const Dtype* target = bottom[1]->gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + caffe_copy(count, sigmoid_output_data, bottom_diff); + caffe_gpu_axpy(count, Dtype(-1), target, bottom_diff); + // Scale down gradient + const Dtype loss_weight = top[0]->cpu_diff()[0]; + caffe_gpu_scal(count, loss_weight / num, bottom_diff); + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(SigmoidCrossEntropyLossLayer); + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/sigmoid_layer.cpp b/caffe-crfrnn/src/caffe/layers/sigmoid_layer.cpp new file mode 100644 index 00000000..48c38490 --- /dev/null +++ 
b/caffe-crfrnn/src/caffe/layers/sigmoid_layer.cpp
@@ -0,0 +1,49 @@
+#include
+#include
+#include
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+inline Dtype sigmoid(Dtype x) {
+  return 1. / (1. + exp(-x));
+}
+
+template <typename Dtype>
+void SigmoidLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  const int count = bottom[0]->count();
+  for (int i = 0; i < count; ++i) {
+    top_data[i] = sigmoid(bottom_data[i]);
+  }
+}
+
+template <typename Dtype>
+void SigmoidLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  if (propagate_down[0]) {
+    const Dtype* top_data = top[0]->cpu_data();
+    const Dtype* top_diff = top[0]->cpu_diff();
+    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+    const int count = bottom[0]->count();
+    for (int i = 0; i < count; ++i) {
+      const Dtype sigmoid_x = top_data[i];
+      bottom_diff[i] = top_diff[i] * sigmoid_x * (1. - sigmoid_x);
+    }
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(SigmoidLayer);
+#endif
+
+INSTANTIATE_CLASS(SigmoidLayer);
+
+
+}  // namespace caffe
diff --git a/caffe-crfrnn/src/caffe/layers/sigmoid_layer.cu b/caffe-crfrnn/src/caffe/layers/sigmoid_layer.cu
new file mode 100644
index 00000000..e1af0657
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/layers/sigmoid_layer.cu
@@ -0,0 +1,62 @@
+#include
+#include
+#include
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void SigmoidForward(const int n, const Dtype* in, Dtype* out) {
+  CUDA_KERNEL_LOOP(index, n) {
+    out[index] = 1. / (1.
+ exp(-in[index]));
+  }
+}
+
+template <typename Dtype>
+void SigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  const int count = bottom[0]->count();
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  SigmoidForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, top_data);
+  CUDA_POST_KERNEL_CHECK;
+  // << " count: " << count << " bottom_data: "
+  //     << (unsigned long)bottom_data
+  //     << " top_data: " << (unsigned long)top_data
+  //     << " blocks: " << CAFFE_GET_BLOCKS(count)
+  //     << " threads: " << CAFFE_CUDA_NUM_THREADS;
+}
+
+template <typename Dtype>
+__global__ void SigmoidBackward(const int n, const Dtype* in_diff,
+    const Dtype* out_data, Dtype* out_diff) {
+  CUDA_KERNEL_LOOP(index, n) {
+    const Dtype sigmoid_x = out_data[index];
+    out_diff[index] = in_diff[index] * sigmoid_x * (1 - sigmoid_x);
+  }
+}
+
+template <typename Dtype>
+void SigmoidLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  if (propagate_down[0]) {
+    const Dtype* top_data = top[0]->gpu_data();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+    const int count = bottom[0]->count();
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    SigmoidBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, top_diff, top_data, bottom_diff);
+    CUDA_POST_KERNEL_CHECK;
+  }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(SigmoidLayer);
+
+
+}  // namespace caffe
diff --git a/caffe-crfrnn/src/caffe/layers/silence_layer.cpp b/caffe-crfrnn/src/caffe/layers/silence_layer.cpp
new file mode 100644
index 00000000..9bd20574
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/layers/silence_layer.cpp
@@ -0,0 +1,26 @@
+#include
+
+#include "caffe/common_layers.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SilenceLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  for (int i = 0; i < bottom.size(); ++i) {
+    if
(propagate_down[i]) {
+      caffe_set(bottom[i]->count(), Dtype(0),
+          bottom[i]->mutable_cpu_diff());
+    }
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(SilenceLayer);
+#endif
+
+INSTANTIATE_CLASS(SilenceLayer);
+REGISTER_LAYER_CLASS(SILENCE, SilenceLayer);
+}  // namespace caffe
diff --git a/caffe-crfrnn/src/caffe/layers/silence_layer.cu b/caffe-crfrnn/src/caffe/layers/silence_layer.cu
new file mode 100644
index 00000000..8d044ee7
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/layers/silence_layer.cu
@@ -0,0 +1,28 @@
+#include
+
+#include "caffe/common_layers.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SilenceLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  // Do nothing.
+}
+
+template <typename Dtype>
+void SilenceLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  for (int i = 0; i < bottom.size(); ++i) {
+    if (propagate_down[i]) {
+      caffe_gpu_set(bottom[i]->count(), Dtype(0),
+          bottom[i]->mutable_gpu_diff());
+    }
+  }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(SilenceLayer);
+
+}  // namespace caffe
diff --git a/caffe-crfrnn/src/caffe/layers/slice_layer.cpp b/caffe-crfrnn/src/caffe/layers/slice_layer.cpp
new file mode 100644
index 00000000..60a5ecfa
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/layers/slice_layer.cpp
@@ -0,0 +1,141 @@
+#include
+#include
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SliceLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const SliceParameter& slice_param = this->layer_param_.slice_param();
+  slice_dim_ = slice_param.slice_dim();
+  CHECK_GE(slice_dim_, 0);
+  CHECK_LE(slice_dim_, 1) << "Can only slice num and channels";
+  slice_point_.clear();
+  std::copy(slice_param.slice_point().begin(),
+      slice_param.slice_point().end(),
+      std::back_inserter(slice_point_));
+}
+
+template <typename Dtype>
+void SliceLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  count_ = 0;
+  num_ = bottom[0]->num();
+  channels_ = bottom[0]->channels();
+  height_ = bottom[0]->height();
+  width_ = bottom[0]->width();
+  if (slice_point_.size() != 0) {
+    CHECK_EQ(slice_point_.size(), top.size() - 1);
+    if (slice_dim_ == 0) {
+      CHECK_LE(top.size(), num_);
+    } else {
+      CHECK_LE(top.size(), channels_);
+    }
+    int prev = 0;
+    vector<int> slices;
+    for (int i = 0; i < slice_point_.size(); ++i) {
+      CHECK_GT(slice_point_[i], prev);
+      slices.push_back(slice_point_[i] - prev);
+      prev = slice_point_[i];
+    }
+    if (slice_dim_ == 0) {
+      slices.push_back(num_ - prev);
+      for (int i = 0; i < top.size(); ++i) {
+        top[i]->Reshape(slices[i], channels_, height_, width_);
+        count_ += top[i]->count();
+      }
+    } else {
+      slices.push_back(channels_ - prev);
+      for (int i = 0; i < top.size(); ++i) {
+        top[i]->Reshape(num_, slices[i], height_, width_);
+        count_ += top[i]->count();
+      }
+    }
+  } else {
+    if (slice_dim_ == 0) {
+      CHECK_EQ(num_ % top.size(), 0)
+          << "Number of top blobs (" << top.size() << ") "
+          << "should evenly divide input num ( " << num_ << ")";
+      num_ = num_ / top.size();
+    } else {
+      CHECK_EQ(channels_ % top.size(), 0)
+          << "Number of top blobs (" << top.size() << ") "
+          << "should evenly divide input channels ( " << channels_ << ")";
+      channels_ = channels_ / top.size();
+    }
+    for (int i = 0; i < top.size(); ++i) {
+      top[i]->Reshape(num_, channels_, height_, width_);
+      count_ += top[i]->count();
+    }
+  }
+  CHECK_EQ(count_, bottom[0]->count());
+}
+
+template <typename Dtype>
+void SliceLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  if (slice_dim_ == 0) {
+    int offset_num = 0;
+    for (int i = 0; i < top.size(); ++i) {
+      Blob<Dtype>* blob = top[i];
+      Dtype* top_data = blob->mutable_cpu_data();
+      caffe_copy(blob->count(), bottom_data + bottom[0]->offset(offset_num),
+          top_data);
+      offset_num += blob->num();
+    }
+  } else if (slice_dim_ == 1) {
+    int offset_channel = 0;
+    for
(int i = 0; i < top.size(); ++i) {
+      Blob<Dtype>* blob = top[i];
+      Dtype* top_data = blob->mutable_cpu_data();
+      const int num_elem = blob->channels() * blob->height() * blob->width();
+      for (int n = 0; n < num_; ++n) {
+        caffe_copy(num_elem, bottom_data + bottom[0]->offset(n, offset_channel),
+            top_data + blob->offset(n));
+      }
+      offset_channel += blob->channels();
+    }
+  }  // slice_dim_ is guaranteed to be 0 or 1 by SetUp.
+}
+
+template <typename Dtype>
+void SliceLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  if (!propagate_down[0]) { return; }
+  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+  if (slice_dim_ == 0) {
+    int offset_num = 0;
+    for (int i = 0; i < top.size(); ++i) {
+      Blob<Dtype>* blob = top[i];
+      const Dtype* top_diff = blob->cpu_diff();
+      caffe_copy(blob->count(), top_diff,
+          bottom_diff + bottom[0]->offset(offset_num));
+      offset_num += blob->num();
+    }
+  } else if (slice_dim_ == 1) {
+    int offset_channel = 0;
+    for (int i = 0; i < top.size(); ++i) {
+      Blob<Dtype>* blob = top[i];
+      const Dtype* top_diff = blob->cpu_diff();
+      const int num_elem = blob->channels() * blob->height() * blob->width();
+      for (int n = 0; n < num_; ++n) {
+        caffe_copy(num_elem, top_diff + blob->offset(n),
+            bottom_diff + bottom[0]->offset(n, offset_channel));
+      }
+      offset_channel += blob->channels();
+    }
+  }  // slice_dim_ is guaranteed to be 0 or 1 by SetUp.
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(SliceLayer);
+#endif
+
+INSTANTIATE_CLASS(SliceLayer);
+REGISTER_LAYER_CLASS(SLICE, SliceLayer);
+}  // namespace caffe
diff --git a/caffe-crfrnn/src/caffe/layers/slice_layer.cu b/caffe-crfrnn/src/caffe/layers/slice_layer.cu
new file mode 100644
index 00000000..b5c5e615
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/layers/slice_layer.cu
@@ -0,0 +1,68 @@
+#include
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SliceLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  if (slice_dim_ == 0) {
+    int offset_num = 0;
+    for (int i = 0; i < top.size(); ++i) {
+      Blob<Dtype>* blob = top[i];
+      Dtype* top_data = blob->mutable_gpu_data();
+      caffe_copy(blob->count(), bottom_data + bottom[0]->offset(offset_num),
+          top_data);
+      offset_num += blob->num();
+    }
+  } else if (slice_dim_ == 1) {
+    int offset_channel = 0;
+    for (int i = 0; i < top.size(); ++i) {
+      Blob<Dtype>* blob = top[i];
+      Dtype* top_data = blob->mutable_gpu_data();
+      const int num_elem = blob->channels() * blob->height() * blob->width();
+      for (int n = 0; n < num_; ++n) {
+        caffe_copy(num_elem, bottom_data + bottom[0]->offset(n, offset_channel),
+            top_data + blob->offset(n));
+      }
+      offset_channel += blob->channels();
+    }
+  }  // slice_dim_ is guaranteed to be 0 or 1 by SetUp.
+}
+
+template <typename Dtype>
+void SliceLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  if (!propagate_down[0]) { return; }
+  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+  if (slice_dim_ == 0) {
+    int offset_num = 0;
+    for (int i = 0; i < top.size(); ++i) {
+      Blob<Dtype>* blob = top[i];
+      const Dtype* top_diff = blob->gpu_diff();
+      caffe_copy(blob->count(), top_diff,
+          bottom_diff + bottom[0]->offset(offset_num));
+      offset_num += blob->num();
+    }
+  } else if (slice_dim_ == 1) {
+    int offset_channel = 0;
+    for (int i = 0; i < top.size(); ++i) {
+      Blob<Dtype>* blob = top[i];
+      const Dtype* top_diff = blob->gpu_diff();
+      const int num_elem = blob->channels() * blob->height() * blob->width();
+      for (int n = 0; n < num_; ++n) {
+        caffe_copy(num_elem, top_diff + blob->offset(n),
+            bottom_diff + bottom[0]->offset(n, offset_channel));
+      }
+      offset_channel += blob->channels();
+    }
+  }  // slice_dim_ is guaranteed to be 0 or 1 by SetUp.
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(SliceLayer);
+
+}  // namespace caffe
diff --git a/caffe-crfrnn/src/caffe/layers/softmax_layer.cpp b/caffe-crfrnn/src/caffe/layers/softmax_layer.cpp
new file mode 100644
index 00000000..c7b09fff
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/layers/softmax_layer.cpp
@@ -0,0 +1,95 @@
+#include
+#include
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+  sum_multiplier_.Reshape(1, bottom[0]->channels(), 1, 1);
+  Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
+  for (int i = 0; i < sum_multiplier_.count(); ++i) {
+    multiplier_data[i] = 1.;
+  }
+  scale_.Reshape(bottom[0]->num(), 1, bottom[0]->height(), bottom[0]->width());
+}
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const
vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  Dtype* scale_data = scale_.mutable_cpu_data();
+  int num = bottom[0]->num();
+  int channels = bottom[0]->channels();
+  int dim = bottom[0]->count() / bottom[0]->num();
+  int spatial_dim = bottom[0]->height() * bottom[0]->width();
+  caffe_copy(bottom[0]->count(), bottom_data, top_data);
+  // We need to subtract the max to avoid numerical issues, compute the exp,
+  // and then normalize.
+  for (int i = 0; i < num; ++i) {
+    // initialize scale_data to the first plane
+    caffe_copy(spatial_dim, bottom_data + i * dim, scale_data);
+    for (int j = 0; j < channels; j++) {
+      for (int k = 0; k < spatial_dim; k++) {
+        scale_data[k] = std::max(scale_data[k],
+            bottom_data[i * dim + j * spatial_dim + k]);
+      }
+    }
+    // subtraction
+    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim,
+        1, -1., sum_multiplier_.cpu_data(), scale_data, 1., top_data + i * dim);
+    // exponentiation
+    caffe_exp(dim, top_data + i * dim, top_data + i * dim);
+    // sum after exp
+    caffe_cpu_gemv<Dtype>(CblasTrans, channels, spatial_dim, 1.,
+        top_data + i * dim, sum_multiplier_.cpu_data(), 0., scale_data);
+    // division
+    for (int j = 0; j < channels; j++) {
+      caffe_div(spatial_dim, top_data + top[0]->offset(i, j), scale_data,
+          top_data + top[0]->offset(i, j));
+    }
+  }
+}
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  const Dtype* top_diff = top[0]->cpu_diff();
+  const Dtype* top_data = top[0]->cpu_data();
+  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+  Dtype* scale_data = scale_.mutable_cpu_data();
+  int num = top[0]->num();
+  int channels = top[0]->channels();
+  int dim = top[0]->count() / top[0]->num();
+  int spatial_dim = top[0]->height() * top[0]->width();
+  caffe_copy(top[0]->count(), top_diff, bottom_diff);
+  for (int i = 0; i < num; ++i) {
+    // compute dot(top_diff, top_data) and
subtract them from the bottom diff
+    for (int k = 0; k < spatial_dim; ++k) {
+      scale_data[k] = caffe_cpu_strided_dot(channels,
+          bottom_diff + i * dim + k, spatial_dim,
+          top_data + i * dim + k, spatial_dim);
+    }
+    // subtraction
+    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim, 1,
+        -1., sum_multiplier_.cpu_data(), scale_data, 1., bottom_diff + i * dim);
+  }
+  // elementwise multiplication
+  caffe_mul(top[0]->count(), bottom_diff, top_data, bottom_diff);
+}
+
+
+#ifdef CPU_ONLY
+STUB_GPU(SoftmaxLayer);
+#endif
+
+INSTANTIATE_CLASS(SoftmaxLayer);
+}  // namespace caffe
diff --git a/caffe-crfrnn/src/caffe/layers/softmax_layer.cu b/caffe-crfrnn/src/caffe/layers/softmax_layer.cu
new file mode 100644
index 00000000..292ad2b3
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/layers/softmax_layer.cu
@@ -0,0 +1,154 @@
+#include
+#include
+#include
+
+#include "thrust/device_vector.h"
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void kernel_channel_max(const int num, const int channels,
+    const int spatial_dim, const Dtype* data, Dtype* out) {
+  CUDA_KERNEL_LOOP(index, num * spatial_dim) {
+    int n = index / spatial_dim;
+    int s = index % spatial_dim;
+    Dtype maxval = -FLT_MAX;
+    for (int c = 0; c < channels; ++c) {
+      maxval = max(data[(n * channels + c) * spatial_dim + s], maxval);
+    }
+    out[index] = maxval;
+  }
+}
+
+template <typename Dtype>
+__global__ void kernel_channel_subtract(const int num, const int channels,
+    const int spatial_dim, Dtype* data, const Dtype* channel_max) {
+  CUDA_KERNEL_LOOP(index, num * spatial_dim) {
+    int n = index / spatial_dim;
+    int s = index % spatial_dim;
+    for (int c = 0; c < channels; ++c) {
+      data[(n * channels + c) * spatial_dim + s] -= channel_max[index];
+    }
+  }
+}
+
+template <typename Dtype>
+__global__ void kernel_exp(const int count, const Dtype* data, Dtype* out) {
+  CUDA_KERNEL_LOOP(index, count) {
+    out[index] = exp(data[index]);
+  }
+}
+
+template <typename Dtype>
+__global__ void kernel_channel_sum(const int num, const int channels,
+    const int spatial_dim, const Dtype* data, Dtype* channel_sum) {
+  CUDA_KERNEL_LOOP(index, num * spatial_dim) {
+    int n = index / spatial_dim;
+    int s = index % spatial_dim;
+    Dtype sum = 0;
+    for (int c = 0; c < channels; ++c) {
+      sum += data[(n * channels + c) * spatial_dim + s];
+    }
+    channel_sum[index] = sum;
+  }
+}
+
+template <typename Dtype>
+__global__ void kernel_channel_div(const int num, const int channels,
+    const int spatial_dim, Dtype* data, const Dtype* channel_sum) {
+  CUDA_KERNEL_LOOP(index, num * spatial_dim) {
+    int n = index / spatial_dim;
+    int s = index % spatial_dim;
+    for (int c = 0; c < channels; ++c) {
+      data[(n * channels + c) * spatial_dim + s] /= channel_sum[index];
+    }
+  }
+}
+
+template <typename Dtype>
+__global__ void kernel_channel_dot(const int num, const int channels,
+    const int spatial_dim, const Dtype* data_1, const Dtype* data_2,
+    Dtype* channel_dot) {
+  CUDA_KERNEL_LOOP(index, num * spatial_dim) {
+    int n = index / spatial_dim;
+    int s = index % spatial_dim;
+    Dtype dot = 0;
+    for (int c = 0; c < channels; ++c) {
+      dot += (data_1[(n * channels + c) * spatial_dim + s]
+          * data_2[(n * channels + c) * spatial_dim + s]);
+    }
+    channel_dot[index] = dot;
+  }
+}
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  Dtype* scale_data = scale_.mutable_gpu_data();
+  int num = bottom[0]->num();
+  int channels = bottom[0]->channels();
+  int spatial_dim = bottom[0]->height() * bottom[0]->width();
+  caffe_copy(bottom[0]->count(), bottom_data, top_data);
+  // We need to subtract the max to avoid numerical issues, compute the exp,
+  // and then normalize.
+  // compute max
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  kernel_channel_max<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
+      scale_data);
+  // subtract
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  kernel_channel_subtract<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
+      scale_data);
+  // exponentiate
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  kernel_exp<Dtype><<<CAFFE_GET_BLOCKS(num * channels * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num * channels * spatial_dim, top_data,
+      top_data);
+  // sum after exp
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  kernel_channel_sum<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
+      scale_data);
+  // divide
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  kernel_channel_div<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
+      scale_data);
+}
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const Dtype* top_data = top[0]->gpu_data();
+  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+  Dtype* scale_data = scale_.mutable_gpu_data();
+  int num = top[0]->num();
+  int channels = top[0]->channels();
+  int spatial_dim = top[0]->height() * top[0]->width();
+  caffe_copy(top[0]->count(), top_diff, bottom_diff);
+  // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff.
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  kernel_channel_dot<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_diff, top_data,
+      scale_data);
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  kernel_channel_subtract<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, bottom_diff,
+      scale_data);
+  // elementwise multiplication
+  caffe_gpu_mul(top[0]->count(), bottom_diff, top_data, bottom_diff);
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(SoftmaxLayer);
+
+
+}  // namespace caffe
diff --git a/caffe-crfrnn/src/caffe/layers/softmax_loss_layer.cpp b/caffe-crfrnn/src/caffe/layers/softmax_loss_layer.cpp
new file mode 100644
index 00000000..55eb0918
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/layers/softmax_loss_layer.cpp
@@ -0,0 +1,117 @@
+#include
+#include
+#include
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SoftmaxWithLossLayer<Dtype>::LayerSetUp(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  LossLayer<Dtype>::LayerSetUp(bottom, top);
+  softmax_bottom_vec_.clear();
+  softmax_bottom_vec_.push_back(bottom[0]);
+  softmax_top_vec_.clear();
+  softmax_top_vec_.push_back(&prob_);
+  softmax_layer_->SetUp(softmax_bottom_vec_, softmax_top_vec_);
+
+  has_ignore_label_ =
+    this->layer_param_.loss_param().has_ignore_label();
+  if (has_ignore_label_) {
+    ignore_label_ = this->layer_param_.loss_param().ignore_label();
+  }
+  normalize_ = this->layer_param_.loss_param().normalize();
+}
+
+template <typename Dtype>
+void SoftmaxWithLossLayer<Dtype>::Reshape(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  LossLayer<Dtype>::Reshape(bottom, top);
+  softmax_layer_->Reshape(softmax_bottom_vec_, softmax_top_vec_);
+  if (top.size() >= 2) {
+    // softmax output
+    top[1]->ReshapeLike(*bottom[0]);
+  }
+}
+
+template <typename Dtype>
+void SoftmaxWithLossLayer<Dtype>::Forward_cpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  // The forward pass computes the softmax prob values.
+ softmax_layer_->Forward(softmax_bottom_vec_, softmax_top_vec_); + const Dtype* prob_data = prob_.cpu_data(); + const Dtype* label = bottom[1]->cpu_data(); + int num = prob_.num(); + int dim = prob_.count() / num; + int spatial_dim = prob_.height() * prob_.width(); + int count = 0; + Dtype loss = 0; + for (int i = 0; i < num; ++i) { + for (int j = 0; j < spatial_dim; j++) { + const int label_value = static_cast(label[i * spatial_dim + j]); + if (has_ignore_label_ && label_value == ignore_label_) { + continue; + } + DCHECK_GE(label_value, 0); + DCHECK_LT(label_value, prob_.channels()); + loss -= log(std::max(prob_data[i * dim + label_value * spatial_dim + j], + Dtype(FLT_MIN))); + ++count; + } + } + if (normalize_) { + top[0]->mutable_cpu_data()[0] = loss / count; + } else { + top[0]->mutable_cpu_data()[0] = loss / num; + } + if (top.size() == 2) { + top[1]->ShareData(prob_); + } +} + +template +void SoftmaxWithLossLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (propagate_down[1]) { + LOG(FATAL) << this->type_name() + << " Layer cannot backpropagate to label inputs."; + } + if (propagate_down[0]) { + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + const Dtype* prob_data = prob_.cpu_data(); + caffe_copy(prob_.count(), prob_data, bottom_diff); + const Dtype* label = bottom[1]->cpu_data(); + int num = prob_.num(); + int dim = prob_.count() / num; + int spatial_dim = prob_.height() * prob_.width(); + int count = 0; + for (int i = 0; i < num; ++i) { + for (int j = 0; j < spatial_dim; ++j) { + const int label_value = static_cast(label[i * spatial_dim + j]); + if (has_ignore_label_ && label_value == ignore_label_) { + for (int c = 0; c < bottom[0]->channels(); ++c) { + bottom_diff[i * dim + c * spatial_dim + j] = 0; + } + } else { + bottom_diff[i * dim + label_value * spatial_dim + j] -= 1; + ++count; + } + } + } + // Scale gradient + const Dtype loss_weight = top[0]->cpu_diff()[0]; + if (normalize_) { + 
caffe_scal(prob_.count(), loss_weight / count, bottom_diff); + } else { + caffe_scal(prob_.count(), loss_weight / num, bottom_diff); + } + } +} + +INSTANTIATE_CLASS(SoftmaxWithLossLayer); +REGISTER_LAYER_CLASS(SOFTMAX_LOSS, SoftmaxWithLossLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/split_layer.cpp b/caffe-crfrnn/src/caffe/layers/split_layer.cpp new file mode 100644 index 00000000..51ac61f4 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/split_layer.cpp @@ -0,0 +1,60 @@ +#include <vector> + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template <typename Dtype> +void SplitLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top) { + count_ = bottom[0]->count(); + for (int i = 0; i < top.size(); ++i) { + // Do not allow in-place computation in the SplitLayer. Instead, share data + // by reference in the forward pass, and keep separate diff allocations in + // the backward pass. (Technically, it should be possible to share the diff + // blob of the first split output with the input, but this seems to cause + // some strange effects in practice...) + CHECK_NE(top[i], bottom[0]) << this->type_name() << " Layer does not " + "allow in-place computation."; + top[i]->Reshape(bottom[0]->num(), bottom[0]->channels(), + bottom[0]->height(), bottom[0]->width()); + CHECK_EQ(count_, top[i]->count()); + } +} + +template <typename Dtype> +void SplitLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top) { + for (int i = 0; i < top.size(); ++i) { + top[i]->ShareData(*bottom[0]); + } +} + +template <typename Dtype> +void SplitLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { + if (!propagate_down[0]) { return; } + if (top.size() == 1) { + caffe_copy(count_, top[0]->cpu_diff(), bottom[0]->mutable_cpu_diff()); + return; + } + caffe_add(count_, top[0]->cpu_diff(), top[1]->cpu_diff(), + bottom[0]->mutable_cpu_diff()); + // Add remaining top blob diffs.
+ for (int i = 2; i < top.size(); ++i) { + const Dtype* top_diff = top[i]->cpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + caffe_axpy(count_, Dtype(1.), top_diff, bottom_diff); + } +} + + +#ifdef CPU_ONLY +STUB_GPU(SplitLayer); +#endif + +INSTANTIATE_CLASS(SplitLayer); +REGISTER_LAYER_CLASS(SPLIT, SplitLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/split_layer.cu b/caffe-crfrnn/src/caffe/layers/split_layer.cu new file mode 100644 index 00000000..a4f5df26 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/split_layer.cu @@ -0,0 +1,38 @@ +#include <vector> + +#include "caffe/layer.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template <typename Dtype> +void SplitLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top) { + for (int i = 0; i < top.size(); ++i) { + top[i]->ShareData(*bottom[0]); + } +} + +template <typename Dtype> +void SplitLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { + if (!propagate_down[0]) { return; } + if (top.size() == 1) { + caffe_copy(count_, top[0]->gpu_diff(), bottom[0]->mutable_gpu_diff()); + return; + } + caffe_gpu_add(count_, top[0]->gpu_diff(), top[1]->gpu_diff(), + bottom[0]->mutable_gpu_diff()); + // Add remaining top blob diffs. + for (int i = 2; i < top.size(); ++i) { + const Dtype* top_diff = top[i]->gpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + caffe_gpu_axpy(count_, Dtype(1.), top_diff, bottom_diff); + } +} + + +INSTANTIATE_LAYER_GPU_FUNCS(SplitLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/tanh_layer.cpp b/caffe-crfrnn/src/caffe/layers/tanh_layer.cpp new file mode 100644 index 00000000..ee5ed773 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/tanh_layer.cpp @@ -0,0 +1,46 @@ +// TanH neuron activation function layer.
+// Adapted from ReLU layer code written by Yangqing Jia + +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void TanHLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + const int count = bottom[0]->count(); + for (int i = 0; i < count; ++i) { + top_data[i] = tanh(bottom_data[i]); + } +} + +template +void TanHLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + if (propagate_down[0]) { + const Dtype* top_data = top[0]->cpu_data(); + const Dtype* top_diff = top[0]->cpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + const int count = bottom[0]->count(); + Dtype tanhx; + for (int i = 0; i < count; ++i) { + tanhx = top_data[i]; + bottom_diff[i] = top_diff[i] * (1 - tanhx * tanhx); + } + } +} + +#ifdef CPU_ONLY +STUB_GPU(TanHLayer); +#endif + +INSTANTIATE_CLASS(TanHLayer); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/tanh_layer.cu b/caffe-crfrnn/src/caffe/layers/tanh_layer.cu new file mode 100644 index 00000000..ccd6e63e --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/tanh_layer.cu @@ -0,0 +1,59 @@ +// TanH neuron activation function layer. 
+// Adapted from ReLU layer code written by Yangqing Jia + +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +__global__ void TanHForward(const int n, const Dtype* in, Dtype* out) { + CUDA_KERNEL_LOOP(index, n) { + out[index] = tanh(in[index]); + } +} + +template +void TanHLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + const int count = bottom[0]->count(); + // NOLINT_NEXT_LINE(whitespace/operators) + TanHForward<<>>( + count, bottom_data, top_data); + CUDA_POST_KERNEL_CHECK; +} + +template +__global__ void TanHBackward(const int n, const Dtype* in_diff, + const Dtype* out_data, Dtype* out_diff) { + CUDA_KERNEL_LOOP(index, n) { + Dtype tanhx = out_data[index]; + out_diff[index] = in_diff[index] * (1 - tanhx * tanhx); + } +} + +template +void TanHLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, + const vector*>& bottom) { + if (propagate_down[0]) { + const Dtype* top_data = top[0]->gpu_data(); + const Dtype* top_diff = top[0]->gpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + const int count = bottom[0]->count(); + // NOLINT_NEXT_LINE(whitespace/operators) + TanHBackward<<>>( + count, top_diff, top_data, bottom_diff); + CUDA_POST_KERNEL_CHECK; + } +} + +INSTANTIATE_LAYER_GPU_FUNCS(TanHLayer); + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/threshold_layer.cpp b/caffe-crfrnn/src/caffe/layers/threshold_layer.cpp new file mode 100644 index 00000000..9e68c32d --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/threshold_layer.cpp @@ -0,0 +1,33 @@ +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + + +namespace caffe { + +template +void ThresholdLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + NeuronLayer::LayerSetUp(bottom, top); + threshold_ = 
this->layer_param_.threshold_param().threshold(); +} + +template +void ThresholdLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + const int count = bottom[0]->count(); + for (int i = 0; i < count; ++i) { + top_data[i] = (bottom_data[i] > threshold_) ? Dtype(1) : Dtype(0); + } +} + +#ifdef CPU_ONLY +STUB_GPU_FORWARD(ThresholdLayer, Forward); +#endif + +INSTANTIATE_CLASS(ThresholdLayer); +REGISTER_LAYER_CLASS(THRESHOLD, ThresholdLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/threshold_layer.cu b/caffe-crfrnn/src/caffe/layers/threshold_layer.cu new file mode 100644 index 00000000..bfa7f159 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/threshold_layer.cu @@ -0,0 +1,33 @@ +#include +#include + +#include "caffe/layer.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +__global__ void ThresholdForward(const int n, const Dtype threshold, + const Dtype* in, Dtype* out) { + CUDA_KERNEL_LOOP(index, n) { + out[index] = in[index] > threshold ? 
1 : 0; + } +} + +template +void ThresholdLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + const int count = bottom[0]->count(); + // NOLINT_NEXT_LINE(whitespace/operators) + ThresholdForward<<>>( + count, threshold_, bottom_data, top_data); + CUDA_POST_KERNEL_CHECK; +} + + +INSTANTIATE_LAYER_GPU_FORWARD(ThresholdLayer); + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/layers/window_data_layer.cpp b/caffe-crfrnn/src/caffe/layers/window_data_layer.cpp new file mode 100644 index 00000000..6287b385 --- /dev/null +++ b/caffe-crfrnn/src/caffe/layers/window_data_layer.cpp @@ -0,0 +1,467 @@ +#include + +#include +#include +#include +#include +#include + +#include "opencv2/core/core.hpp" +#include "opencv2/highgui/highgui.hpp" +#include "opencv2/imgproc/imgproc.hpp" + +#include "caffe/common.hpp" +#include "caffe/data_layers.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/benchmark.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/util/rng.hpp" + +// caffe.proto > LayerParameter > WindowDataParameter +// 'source' field specifies the window_file +// 'crop_size' indicates the desired warped size + +#if CV_VERSION_MAJOR == 3 +const int CV_LOAD_IMAGE_COLOR = cv::IMREAD_COLOR; +#endif + +namespace caffe { + +template +WindowDataLayer::~WindowDataLayer() { + this->JoinPrefetchThread(); +} + +template +void WindowDataLayer::DataLayerSetUp(const vector*>& bottom, + const vector*>& top) { + // LayerSetUp runs through the window_file and creates two structures + // that hold windows: one for foreground (object) windows and one + // for background (non-object) windows. We use an overlap threshold + // to decide which is which. 
+ + // window_file format + // repeated: + // # image_index + // img_path (abs path) + // channels + // height + // width + // num_windows + // class_index overlap x1 y1 x2 y2 + + LOG(INFO) << "Window data layer:" << std::endl + << " foreground (object) overlap threshold: " + << this->layer_param_.window_data_param().fg_threshold() << std::endl + << " background (non-object) overlap threshold: " + << this->layer_param_.window_data_param().bg_threshold() << std::endl + << " foreground sampling fraction: " + << this->layer_param_.window_data_param().fg_fraction() << std::endl + << " cache_images: " + << this->layer_param_.window_data_param().cache_images() << std::endl + << " root_folder: " + << this->layer_param_.window_data_param().root_folder(); + + cache_images_ = this->layer_param_.window_data_param().cache_images(); + string root_folder = this->layer_param_.window_data_param().root_folder(); + + const bool prefetch_needs_rand = + this->transform_param_.mirror() || + this->transform_param_.crop_size(); + if (prefetch_needs_rand) { + const unsigned int prefetch_rng_seed = caffe_rng_rand(); + prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed)); + } else { + prefetch_rng_.reset(); + } + + std::ifstream infile(this->layer_param_.window_data_param().source().c_str()); + CHECK(infile.good()) << "Failed to open window file " + << this->layer_param_.window_data_param().source() << std::endl; + + map label_hist; + label_hist.insert(std::make_pair(0, 0)); + + string hashtag; + int image_index, channels; + if (!(infile >> hashtag >> image_index)) { + LOG(FATAL) << "Window file is empty"; + } + do { + CHECK_EQ(hashtag, "#"); + // read image path + string image_path; + infile >> image_path; + image_path = root_folder + image_path; + // read image dimensions + vector image_size(3); + infile >> image_size[0] >> image_size[1] >> image_size[2]; + channels = image_size[0]; + image_database_.push_back(std::make_pair(image_path, image_size)); + + if (cache_images_) { + Datum 
datum; + if (!ReadFileToDatum(image_path, &datum)) { + LOG(ERROR) << "Could not open or find file " << image_path; + return; + } + image_database_cache_.push_back(std::make_pair(image_path, datum)); + } + // read each box + int num_windows; + infile >> num_windows; + const float fg_threshold = + this->layer_param_.window_data_param().fg_threshold(); + const float bg_threshold = + this->layer_param_.window_data_param().bg_threshold(); + for (int i = 0; i < num_windows; ++i) { + int label, x1, y1, x2, y2; + float overlap; + infile >> label >> overlap >> x1 >> y1 >> x2 >> y2; + + vector window(WindowDataLayer::NUM); + window[WindowDataLayer::IMAGE_INDEX] = image_index; + window[WindowDataLayer::LABEL] = label; + window[WindowDataLayer::OVERLAP] = overlap; + window[WindowDataLayer::X1] = x1; + window[WindowDataLayer::Y1] = y1; + window[WindowDataLayer::X2] = x2; + window[WindowDataLayer::Y2] = y2; + + // add window to foreground list or background list + if (overlap >= fg_threshold) { + int label = window[WindowDataLayer::LABEL]; + CHECK_GT(label, 0); + fg_windows_.push_back(window); + label_hist.insert(std::make_pair(label, 0)); + label_hist[label]++; + } else if (overlap < bg_threshold) { + // background window, force label and overlap to 0 + window[WindowDataLayer::LABEL] = 0; + window[WindowDataLayer::OVERLAP] = 0; + bg_windows_.push_back(window); + label_hist[0]++; + } + } + + if (image_index % 100 == 0) { + LOG(INFO) << "num: " << image_index << " " + << image_path << " " + << image_size[0] << " " + << image_size[1] << " " + << image_size[2] << " " + << "windows to process: " << num_windows; + } + } while (infile >> hashtag >> image_index); + + LOG(INFO) << "Number of images: " << image_index+1; + + for (map::iterator it = label_hist.begin(); + it != label_hist.end(); ++it) { + LOG(INFO) << "class " << it->first << " has " << label_hist[it->first] + << " samples"; + } + + LOG(INFO) << "Amount of context padding: " + << 
this->layer_param_.window_data_param().context_pad(); + + LOG(INFO) << "Crop mode: " + << this->layer_param_.window_data_param().crop_mode(); + + // image + const int crop_size = this->transform_param_.crop_size(); + CHECK_GT(crop_size, 0); + const int batch_size = this->layer_param_.window_data_param().batch_size(); + top[0]->Reshape(batch_size, channels, crop_size, crop_size); + this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size); + + LOG(INFO) << "output data size: " << top[0]->num() << "," + << top[0]->channels() << "," << top[0]->height() << "," + << top[0]->width(); + // label + top[1]->Reshape(batch_size, 1, 1, 1); + this->prefetch_label_.Reshape(batch_size, 1, 1, 1); + + // data mean + has_mean_file_ = this->transform_param_.has_mean_file(); + has_mean_values_ = this->transform_param_.mean_value_size() > 0; + if (has_mean_file_) { + const string& mean_file = + this->transform_param_.mean_file(); + LOG(INFO) << "Loading mean file from " << mean_file; + BlobProto blob_proto; + ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto); + data_mean_.FromProto(blob_proto); + } + if (has_mean_values_) { + CHECK(has_mean_file_ == false) << + "Cannot specify mean_file and mean_value at the same time"; + for (int c = 0; c < this->transform_param_.mean_value_size(); ++c) { + mean_values_.push_back(this->transform_param_.mean_value(c)); + } + CHECK(mean_values_.size() == 1 || mean_values_.size() == channels) << + "Specify either 1 mean_value or as many as channels: " << channels; + if (channels > 1 && mean_values_.size() == 1) { + // Replicate the mean_value for simplicity + for (int c = 1; c < channels; ++c) { + mean_values_.push_back(mean_values_[0]); + } + } + } +} + +template <typename Dtype> +unsigned int WindowDataLayer<Dtype>::PrefetchRand() { + CHECK(prefetch_rng_); + caffe::rng_t* prefetch_rng = + static_cast<caffe::rng_t*>(prefetch_rng_->generator()); + return (*prefetch_rng)(); +} + +// Thread fetching the data +template <typename Dtype> +void WindowDataLayer<Dtype>::InternalThreadEntry() { + //
At each iteration, sample N windows where N*p are foreground (object) + // windows and N*(1-p) are background (non-object) windows + CPUTimer batch_timer; + batch_timer.Start(); + double read_time = 0; + double trans_time = 0; + CPUTimer timer; + Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); + Dtype* top_label = this->prefetch_label_.mutable_cpu_data(); + const Dtype scale = this->layer_param_.window_data_param().scale(); + const int batch_size = this->layer_param_.window_data_param().batch_size(); + const int context_pad = this->layer_param_.window_data_param().context_pad(); + const int crop_size = this->transform_param_.crop_size(); + const bool mirror = this->transform_param_.mirror(); + const float fg_fraction = + this->layer_param_.window_data_param().fg_fraction(); + Dtype* mean = NULL; + int mean_off = 0; + int mean_width = 0; + int mean_height = 0; + if (this->has_mean_file_) { + mean = this->data_mean_.mutable_cpu_data(); + mean_off = (this->data_mean_.width() - crop_size) / 2; + mean_width = this->data_mean_.width(); + mean_height = this->data_mean_.height(); + } + cv::Size cv_crop_size(crop_size, crop_size); + const string& crop_mode = this->layer_param_.window_data_param().crop_mode(); + + bool use_square = (crop_mode == "square") ? true : false; + + // zero out batch + caffe_set(this->prefetch_data_.count(), Dtype(0), top_data); + + const int num_fg = static_cast(static_cast(batch_size) + * fg_fraction); + const int num_samples[2] = { batch_size - num_fg, num_fg }; + + int item_id = 0; + // sample from bg set then fg set + for (int is_fg = 0; is_fg < 2; ++is_fg) { + for (int dummy = 0; dummy < num_samples[is_fg]; ++dummy) { + // sample a window + timer.Start(); + const unsigned int rand_index = PrefetchRand(); + vector window = (is_fg) ? 
+ fg_windows_[rand_index % fg_windows_.size()] : + bg_windows_[rand_index % bg_windows_.size()]; + + bool do_mirror = mirror && PrefetchRand() % 2; + + // load the image containing the window + pair > image = + image_database_[window[WindowDataLayer::IMAGE_INDEX]]; + + cv::Mat cv_img; + if (this->cache_images_) { + pair image_cached = + image_database_cache_[window[WindowDataLayer::IMAGE_INDEX]]; + cv_img = DecodeDatumToCVMat(image_cached.second); + } else { + cv_img = cv::imread(image.first, CV_LOAD_IMAGE_COLOR); + if (!cv_img.data) { + LOG(ERROR) << "Could not open or find file " << image.first; + return; + } + } + read_time += timer.MicroSeconds(); + timer.Start(); + const int channels = cv_img.channels(); + + // crop window out of image and warp it + int x1 = window[WindowDataLayer::X1]; + int y1 = window[WindowDataLayer::Y1]; + int x2 = window[WindowDataLayer::X2]; + int y2 = window[WindowDataLayer::Y2]; + + int pad_w = 0; + int pad_h = 0; + if (context_pad > 0 || use_square) { + // scale factor by which to expand the original region + // such that after warping the expanded region to crop_size x crop_size + // there's exactly context_pad amount of padding on each side + Dtype context_scale = static_cast(crop_size) / + static_cast(crop_size - 2*context_pad); + + // compute the expanded region + Dtype half_height = static_cast(y2-y1+1)/2.0; + Dtype half_width = static_cast(x2-x1+1)/2.0; + Dtype center_x = static_cast(x1) + half_width; + Dtype center_y = static_cast(y1) + half_height; + if (use_square) { + if (half_height > half_width) { + half_width = half_height; + } else { + half_height = half_width; + } + } + x1 = static_cast(round(center_x - half_width*context_scale)); + x2 = static_cast(round(center_x + half_width*context_scale)); + y1 = static_cast(round(center_y - half_height*context_scale)); + y2 = static_cast(round(center_y + half_height*context_scale)); + + // the expanded region may go outside of the image + // so we compute the clipped (expanded) 
region and keep track of + // the extent beyond the image + int unclipped_height = y2-y1+1; + int unclipped_width = x2-x1+1; + int pad_x1 = std::max(0, -x1); + int pad_y1 = std::max(0, -y1); + int pad_x2 = std::max(0, x2 - cv_img.cols + 1); + int pad_y2 = std::max(0, y2 - cv_img.rows + 1); + // clip bounds + x1 = x1 + pad_x1; + x2 = x2 - pad_x2; + y1 = y1 + pad_y1; + y2 = y2 - pad_y2; + CHECK_GT(x1, -1); + CHECK_GT(y1, -1); + CHECK_LT(x2, cv_img.cols); + CHECK_LT(y2, cv_img.rows); + + int clipped_height = y2-y1+1; + int clipped_width = x2-x1+1; + + // scale factors that would be used to warp the unclipped + // expanded region + Dtype scale_x = + static_cast(crop_size)/static_cast(unclipped_width); + Dtype scale_y = + static_cast(crop_size)/static_cast(unclipped_height); + + // size to warp the clipped expanded region to + cv_crop_size.width = + static_cast(round(static_cast(clipped_width)*scale_x)); + cv_crop_size.height = + static_cast(round(static_cast(clipped_height)*scale_y)); + pad_x1 = static_cast(round(static_cast(pad_x1)*scale_x)); + pad_x2 = static_cast(round(static_cast(pad_x2)*scale_x)); + pad_y1 = static_cast(round(static_cast(pad_y1)*scale_y)); + pad_y2 = static_cast(round(static_cast(pad_y2)*scale_y)); + + pad_h = pad_y1; + // if we're mirroring, we mirror the padding too (to be pedantic) + if (do_mirror) { + pad_w = pad_x2; + } else { + pad_w = pad_x1; + } + + // ensure that the warped, clipped region plus the padding fits in the + // crop_size x crop_size image (it might not due to rounding) + if (pad_h + cv_crop_size.height > crop_size) { + cv_crop_size.height = crop_size - pad_h; + } + if (pad_w + cv_crop_size.width > crop_size) { + cv_crop_size.width = crop_size - pad_w; + } + } + + cv::Rect roi(x1, y1, x2-x1+1, y2-y1+1); + cv::Mat cv_cropped_img = cv_img(roi); + cv::resize(cv_cropped_img, cv_cropped_img, + cv_crop_size, 0, 0, cv::INTER_LINEAR); + + // horizontal flip at random + if (do_mirror) { + cv::flip(cv_cropped_img, cv_cropped_img, 1); + } 
+ + // copy the warped window into top_data + for (int h = 0; h < cv_cropped_img.rows; ++h) { + const uchar* ptr = cv_cropped_img.ptr(h); + int img_index = 0; + for (int w = 0; w < cv_cropped_img.cols; ++w) { + for (int c = 0; c < channels; ++c) { + int top_index = ((item_id * channels + c) * crop_size + h + pad_h) + * crop_size + w + pad_w; + // int top_index = (c * height + h) * width + w; + Dtype pixel = static_cast(ptr[img_index++]); + if (this->has_mean_file_) { + int mean_index = (c * mean_height + h + mean_off + pad_h) + * mean_width + w + mean_off + pad_w; + top_data[top_index] = (pixel - mean[mean_index]) * scale; + } else { + if (this->has_mean_values_) { + top_data[top_index] = (pixel - this->mean_values_[c]) * scale; + } else { + top_data[top_index] = pixel * scale; + } + } + } + } + } + trans_time += timer.MicroSeconds(); + // get window label + top_label[item_id] = window[WindowDataLayer::LABEL]; + + #if 0 + // useful debugging code for dumping transformed windows to disk + string file_id; + std::stringstream ss; + ss << PrefetchRand(); + ss >> file_id; + std::ofstream inf((string("dump/") + file_id + + string("_info.txt")).c_str(), std::ofstream::out); + inf << image.first << std::endl + << window[WindowDataLayer::X1]+1 << std::endl + << window[WindowDataLayer::Y1]+1 << std::endl + << window[WindowDataLayer::X2]+1 << std::endl + << window[WindowDataLayer::Y2]+1 << std::endl + << do_mirror << std::endl + << top_label[item_id] << std::endl + << is_fg << std::endl; + inf.close(); + std::ofstream top_data_file((string("dump/") + file_id + + string("_data.txt")).c_str(), + std::ofstream::out | std::ofstream::binary); + for (int c = 0; c < channels; ++c) { + for (int h = 0; h < crop_size; ++h) { + for (int w = 0; w < crop_size; ++w) { + top_data_file.write(reinterpret_cast( + &top_data[((item_id * channels + c) * crop_size + h) + * crop_size + w]), + sizeof(Dtype)); + } + } + } + top_data_file.close(); + #endif + + item_id++; + } + } + batch_timer.Stop(); 
+ DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms."; + DLOG(INFO) << " Read time: " << read_time / 1000 << " ms."; + DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms."; +} + +INSTANTIATE_CLASS(WindowDataLayer); +REGISTER_LAYER_CLASS(WINDOW_DATA, WindowDataLayer); +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/leveldb_dataset.cpp b/caffe-crfrnn/src/caffe/leveldb_dataset.cpp new file mode 100644 index 00000000..53df9857 --- /dev/null +++ b/caffe-crfrnn/src/caffe/leveldb_dataset.cpp @@ -0,0 +1,265 @@ +#include +#include +#include + +#include "caffe/caffe.hpp" +#include "caffe/leveldb_dataset.hpp" + +namespace caffe { + +template +bool LeveldbDataset::open(const string& filename, + Mode mode) { + DLOG(INFO) << "LevelDB: Open " << filename; + + leveldb::Options options; + switch (mode) { + case Base::New: + DLOG(INFO) << " mode NEW"; + options.error_if_exists = true; + options.create_if_missing = true; + read_only_ = false; + break; + case Base::ReadWrite: + DLOG(INFO) << " mode RW"; + options.error_if_exists = false; + options.create_if_missing = true; + read_only_ = false; + break; + case Base::ReadOnly: + DLOG(INFO) << " mode RO"; + options.error_if_exists = false; + options.create_if_missing = false; + read_only_ = true; + break; + default: + DLOG(FATAL) << "unknown mode " << mode; + } + options.write_buffer_size = 268435456; + options.max_open_files = 100; + + leveldb::DB* db; + + LOG(INFO) << "Opening leveldb " << filename; + leveldb::Status status = leveldb::DB::Open( + options, filename, &db); + db_.reset(db); + + if (!status.ok()) { + LOG(ERROR) << "Failed to open leveldb " << filename + << ". 
Does it already exist?"; + return false; + } + + batch_.reset(new leveldb::WriteBatch()); + return true; +} + +template <typename K, typename V, typename KCoder, typename VCoder> +bool LeveldbDataset<K, V, KCoder, VCoder>::put(const K& key, const V& value) { + DLOG(INFO) << "LevelDB: Put"; + + if (read_only_) { + LOG(ERROR) << "put cannot be used on a dataset in ReadOnly mode"; + return false; + } + + CHECK_NOTNULL(batch_.get()); + + string serialized_key; + if (!KCoder::serialize(key, &serialized_key)) { + return false; + } + + string serialized_value; + if (!VCoder::serialize(value, &serialized_value)) { + return false; + } + + batch_->Put(serialized_key, serialized_value); + + return true; +} + +template <typename K, typename V, typename KCoder, typename VCoder> +bool LeveldbDataset<K, V, KCoder, VCoder>::get(const K& key, V* value) { + DLOG(INFO) << "LevelDB: Get"; + + string serialized_key; + if (!KCoder::serialize(key, &serialized_key)) { + return false; + } + + string serialized_value; + leveldb::Status status = + db_->Get(leveldb::ReadOptions(), serialized_key, &serialized_value); + + if (!status.ok()) { + LOG(ERROR) << "leveldb get failed"; + return false; + } + + if (!VCoder::deserialize(serialized_value, value)) { + return false; + } + + return true; +} + +template <typename K, typename V, typename KCoder, typename VCoder> +bool LeveldbDataset<K, V, KCoder, VCoder>::first_key(K* key) { + DLOG(INFO) << "LevelDB: First key"; + + CHECK_NOTNULL(db_.get()); + shared_ptr<leveldb::Iterator> iter(db_->NewIterator(leveldb::ReadOptions())); + iter->SeekToFirst(); + CHECK(iter->Valid()); + const leveldb::Slice& key_slice = iter->key(); + return KCoder::deserialize(key_slice.data(), key_slice.size(), key); +} + +template <typename K, typename V, typename KCoder, typename VCoder> +bool LeveldbDataset<K, V, KCoder, VCoder>::last_key(K* key) { + DLOG(INFO) << "LevelDB: Last key"; + + CHECK_NOTNULL(db_.get()); + shared_ptr<leveldb::Iterator> iter(db_->NewIterator(leveldb::ReadOptions())); + iter->SeekToLast(); + CHECK(iter->Valid()); + const leveldb::Slice& key_slice = iter->key(); + return KCoder::deserialize(key_slice.data(), key_slice.size(), key); +} + +template <typename K, typename V, typename KCoder, typename VCoder> +bool LeveldbDataset<K, V, KCoder, VCoder>::commit() { + DLOG(INFO) << "LevelDB: Commit"; + + if (read_only_) { + LOG(ERROR) << "commit cannot be used on a dataset in ReadOnly
mode"; + return false; + } + + CHECK_NOTNULL(db_.get()); + CHECK_NOTNULL(batch_.get()); + + leveldb::Status status = db_->Write(leveldb::WriteOptions(), batch_.get()); + + batch_.reset(new leveldb::WriteBatch()); + + return status.ok(); +} + +template +void LeveldbDataset::close() { + DLOG(INFO) << "LevelDB: Close"; + + batch_.reset(); + db_.reset(); +} + +template +void LeveldbDataset::keys(vector* keys) { + DLOG(INFO) << "LevelDB: Keys"; + + keys->clear(); + for (const_iterator iter = begin(); iter != end(); ++iter) { + keys->push_back(iter->key); + } +} + +template +typename LeveldbDataset::const_iterator + LeveldbDataset::begin() const { + CHECK_NOTNULL(db_.get()); + shared_ptr iter(db_->NewIterator(leveldb::ReadOptions())); + iter->SeekToFirst(); + if (!iter->Valid()) { + iter.reset(); + } + + shared_ptr state; + if (iter) { + state.reset(new LeveldbState(db_, iter)); + } + return const_iterator(this, state); +} + +template +typename LeveldbDataset::const_iterator + LeveldbDataset::end() const { + shared_ptr state; + return const_iterator(this, state); +} + +template +typename LeveldbDataset::const_iterator + LeveldbDataset::cbegin() const { + return begin(); +} + +template +typename LeveldbDataset::const_iterator + LeveldbDataset::cend() const { return end(); } + +template +bool LeveldbDataset::equal( + shared_ptr state1, shared_ptr state2) const { + shared_ptr leveldb_state1 = + boost::dynamic_pointer_cast(state1); + + shared_ptr leveldb_state2 = + boost::dynamic_pointer_cast(state2); + + // The KV store doesn't really have any sort of ordering, + // so while we can do a sequential scan over the collection, + // we can't really use subranges. 
+ return !leveldb_state1 && !leveldb_state2; +} + +template +void LeveldbDataset::increment( + shared_ptr* state) const { + shared_ptr leveldb_state = + boost::dynamic_pointer_cast(*state); + + CHECK_NOTNULL(leveldb_state.get()); + + shared_ptr& iter = leveldb_state->iter_; + + CHECK_NOTNULL(iter.get()); + CHECK(iter->Valid()); + + iter->Next(); + if (!iter->Valid()) { + state->reset(); + } +} + +template +typename Dataset::KV& + LeveldbDataset::dereference( + shared_ptr state) const { + shared_ptr leveldb_state = + boost::dynamic_pointer_cast(state); + + CHECK_NOTNULL(leveldb_state.get()); + + shared_ptr& iter = leveldb_state->iter_; + + CHECK_NOTNULL(iter.get()); + + CHECK(iter->Valid()); + + const leveldb::Slice& key = iter->key(); + const leveldb::Slice& value = iter->value(); + CHECK(KCoder::deserialize(key.data(), key.size(), + &leveldb_state->kv_pair_.key)); + CHECK(VCoder::deserialize(value.data(), value.size(), + &leveldb_state->kv_pair_.value)); + + return leveldb_state->kv_pair_; +} + +INSTANTIATE_DATASET(LeveldbDataset); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/lmdb_dataset.cpp b/caffe-crfrnn/src/caffe/lmdb_dataset.cpp new file mode 100644 index 00000000..8f8e68e9 --- /dev/null +++ b/caffe-crfrnn/src/caffe/lmdb_dataset.cpp @@ -0,0 +1,366 @@ +#include + +#include +#include +#include + +#include "caffe/caffe.hpp" +#include "caffe/lmdb_dataset.hpp" + +namespace caffe { + +template +bool LmdbDataset::open(const string& filename, + Mode mode) { + DLOG(INFO) << "LMDB: Open " << filename; + + CHECK(NULL == env_); + CHECK(NULL == write_txn_); + CHECK(NULL == read_txn_); + CHECK_EQ(0, dbi_); + + int retval; + if (mode != Base::ReadOnly) { + retval = mkdir(filename.c_str(), 0744); + switch (mode) { + case Base::New: + if (0 != retval) { + LOG(ERROR) << "mkdir " << filename << " failed"; + return false; + } + break; + case Base::ReadWrite: + if (-1 == retval && EEXIST != errno) { + LOG(ERROR) << "mkdir " << filename << " failed (" + << 
strerror(errno) << ")"; + return false; + } + break; + default: + LOG(FATAL) << "Invalid mode " << mode; + } + } + + retval = mdb_env_create(&env_); + if (MDB_SUCCESS != retval) { + LOG(ERROR) << "mdb_env_create failed " + << mdb_strerror(retval); + return false; + } + + retval = mdb_env_set_mapsize(env_, 1099511627776); + if (MDB_SUCCESS != retval) { + LOG(ERROR) << "mdb_env_set_mapsize failed " << mdb_strerror(retval); + return false; + } + + int flag1 = 0; + int flag2 = 0; + if (mode == Base::ReadOnly) { + flag1 = MDB_RDONLY | MDB_NOTLS; + flag2 = MDB_RDONLY; + } + + retval = mdb_env_open(env_, filename.c_str(), flag1, 0664); + if (MDB_SUCCESS != retval) { + LOG(ERROR) << "mdb_env_open failed " << mdb_strerror(retval); + return false; + } + + retval = mdb_txn_begin(env_, NULL, MDB_RDONLY, &read_txn_); + if (MDB_SUCCESS != retval) { + LOG(ERROR) << "mdb_txn_begin failed " << mdb_strerror(retval); + return false; + } + + retval = mdb_txn_begin(env_, NULL, flag2, &write_txn_); + if (MDB_SUCCESS != retval) { + LOG(ERROR) << "mdb_txn_begin failed " << mdb_strerror(retval); + return false; + } + + retval = mdb_open(write_txn_, NULL, 0, &dbi_); + if (MDB_SUCCESS != retval) { + LOG(ERROR) << "mdb_open failed " << mdb_strerror(retval); + return false; + } + + return true; +} + +template +bool LmdbDataset::put(const K& key, const V& value) { + DLOG(INFO) << "LMDB: Put"; + + vector serialized_key; + if (!KCoder::serialize(key, &serialized_key)) { + LOG(ERROR) << "failed to serialize key"; + return false; + } + + vector serialized_value; + if (!VCoder::serialize(value, &serialized_value)) { + LOG(ERROR) << "failed to serialize value"; + return false; + } + + MDB_val mdbkey, mdbdata; + mdbdata.mv_size = serialized_value.size(); + mdbdata.mv_data = serialized_value.data(); + mdbkey.mv_size = serialized_key.size(); + mdbkey.mv_data = serialized_key.data(); + + CHECK_NOTNULL(write_txn_); + CHECK_NE(0, dbi_); + + int retval = mdb_put(write_txn_, dbi_, &mdbkey, &mdbdata, 0); + if 
(MDB_SUCCESS != retval) { + LOG(ERROR) << "mdb_put failed " << mdb_strerror(retval); + return false; + } + + return true; +} + +template +bool LmdbDataset::get(const K& key, V* value) { + DLOG(INFO) << "LMDB: Get"; + + vector serialized_key; + if (!KCoder::serialize(key, &serialized_key)) { + LOG(ERROR) << "failed to serialize key"; + return false; + } + + MDB_val mdbkey, mdbdata; + mdbkey.mv_data = serialized_key.data(); + mdbkey.mv_size = serialized_key.size(); + + int retval; + retval = mdb_get(read_txn_, dbi_, &mdbkey, &mdbdata); + if (MDB_SUCCESS != retval) { + LOG(ERROR) << "mdb_get failed " << mdb_strerror(retval); + return false; + } + + if (!VCoder::deserialize(reinterpret_cast(mdbdata.mv_data), + mdbdata.mv_size, value)) { + LOG(ERROR) << "failed to deserialize value"; + return false; + } + + return true; +} + +template +bool LmdbDataset::first_key(K* key) { + DLOG(INFO) << "LMDB: First key"; + + int retval; + + MDB_cursor* cursor; + retval = mdb_cursor_open(read_txn_, dbi_, &cursor); + CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); + MDB_val mdbkey; + MDB_val mdbval; + retval = mdb_cursor_get(cursor, &mdbkey, &mdbval, MDB_FIRST); + CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); + + mdb_cursor_close(cursor); + + if (!KCoder::deserialize(reinterpret_cast(mdbkey.mv_data), + mdbkey.mv_size, key)) { + LOG(ERROR) << "failed to deserialize key"; + return false; + } + + return true; +} + +template +bool LmdbDataset::last_key(K* key) { + DLOG(INFO) << "LMDB: Last key"; + + int retval; + + MDB_cursor* cursor; + retval = mdb_cursor_open(read_txn_, dbi_, &cursor); + CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); + MDB_val mdbkey; + MDB_val mdbval; + retval = mdb_cursor_get(cursor, &mdbkey, &mdbval, MDB_LAST); + CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); + + mdb_cursor_close(cursor); + + if (!KCoder::deserialize(reinterpret_cast(mdbkey.mv_data), + mdbkey.mv_size, key)) { + LOG(ERROR) << "failed to deserialize key"; + return false; 
+ } + + return true; +} + +template +bool LmdbDataset::commit() { + DLOG(INFO) << "LMDB: Commit"; + + CHECK_NOTNULL(write_txn_); + + int retval; + retval = mdb_txn_commit(write_txn_); + if (MDB_SUCCESS != retval) { + LOG(ERROR) << "mdb_txn_commit failed " << mdb_strerror(retval); + return false; + } + + mdb_txn_abort(read_txn_); + + retval = mdb_txn_begin(env_, NULL, 0, &write_txn_); + if (MDB_SUCCESS != retval) { + LOG(ERROR) << "mdb_txn_begin failed " << mdb_strerror(retval); + return false; + } + + retval = mdb_txn_begin(env_, NULL, MDB_RDONLY, &read_txn_); + if (MDB_SUCCESS != retval) { + LOG(ERROR) << "mdb_txn_begin failed " << mdb_strerror(retval); + return false; + } + + return true; +} + +template +void LmdbDataset::close() { + DLOG(INFO) << "LMDB: Close"; + + if (env_ && dbi_) { + mdb_txn_abort(write_txn_); + mdb_txn_abort(read_txn_); + mdb_close(env_, dbi_); + mdb_env_close(env_); + env_ = NULL; + dbi_ = 0; + write_txn_ = NULL; + read_txn_ = NULL; + } +} + +template +void LmdbDataset::keys(vector* keys) { + DLOG(INFO) << "LMDB: Keys"; + + keys->clear(); + for (const_iterator iter = begin(); iter != end(); ++iter) { + keys->push_back(iter->key); + } +} + +template +typename LmdbDataset::const_iterator + LmdbDataset::begin() const { + int retval; + + MDB_cursor* cursor; + retval = mdb_cursor_open(read_txn_, dbi_, &cursor); + CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); + MDB_val key; + MDB_val val; + retval = mdb_cursor_get(cursor, &key, &val, MDB_FIRST); + + CHECK(MDB_SUCCESS == retval || MDB_NOTFOUND == retval) + << mdb_strerror(retval); + + shared_ptr state; + if (MDB_SUCCESS == retval) { + state.reset(new LmdbState(cursor, read_txn_, &dbi_)); + } else { + mdb_cursor_close(cursor); + } + return const_iterator(this, state); +} + +template +typename LmdbDataset::const_iterator + LmdbDataset::end() const { + shared_ptr state; + return const_iterator(this, state); +} + +template +typename LmdbDataset::const_iterator + LmdbDataset::cbegin() const { 
return begin(); } + +template +typename LmdbDataset::const_iterator + LmdbDataset::cend() const { return end(); } + +template +bool LmdbDataset::equal(shared_ptr state1, + shared_ptr state2) const { + shared_ptr lmdb_state1 = + boost::dynamic_pointer_cast(state1); + + shared_ptr lmdb_state2 = + boost::dynamic_pointer_cast(state2); + + // The KV store doesn't really have any sort of ordering, + // so while we can do a sequential scan over the collection, + // we can't really use subranges. + return !lmdb_state1 && !lmdb_state2; +} + +template +void LmdbDataset::increment( + shared_ptr* state) const { + shared_ptr lmdb_state = + boost::dynamic_pointer_cast(*state); + + CHECK_NOTNULL(lmdb_state.get()); + + MDB_cursor*& cursor = lmdb_state->cursor_; + + CHECK_NOTNULL(cursor); + + MDB_val key; + MDB_val val; + int retval = mdb_cursor_get(cursor, &key, &val, MDB_NEXT); + if (MDB_NOTFOUND == retval) { + mdb_cursor_close(cursor); + state->reset(); + } else { + CHECK_EQ(MDB_SUCCESS, retval) << mdb_strerror(retval); + } +} + +template +typename Dataset::KV& + LmdbDataset::dereference( + shared_ptr state) const { + shared_ptr lmdb_state = + boost::dynamic_pointer_cast(state); + + CHECK_NOTNULL(lmdb_state.get()); + + MDB_cursor*& cursor = lmdb_state->cursor_; + + CHECK_NOTNULL(cursor); + + MDB_val mdb_key; + MDB_val mdb_val; + int retval = mdb_cursor_get(cursor, &mdb_key, &mdb_val, MDB_GET_CURRENT); + CHECK_EQ(retval, MDB_SUCCESS) << mdb_strerror(retval); + + CHECK(KCoder::deserialize(reinterpret_cast(mdb_key.mv_data), + mdb_key.mv_size, &lmdb_state->kv_pair_.key)); + CHECK(VCoder::deserialize(reinterpret_cast(mdb_val.mv_data), + mdb_val.mv_size, &lmdb_state->kv_pair_.value)); + + return lmdb_state->kv_pair_; +} + +INSTANTIATE_DATASET(LmdbDataset); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/net.cpp b/caffe-crfrnn/src/caffe/net.cpp new file mode 100644 index 00000000..67f8a22b --- /dev/null +++ b/caffe-crfrnn/src/caffe/net.cpp @@ -0,0 +1,826 @@ +#include 
+#include +#include +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/layer.hpp" +#include "caffe/net.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/insert_splits.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/util/upgrade_proto.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +Net::Net(const NetParameter& param) { + Init(param); +} + +template +Net::Net(const string& param_file) { + NetParameter param; + ReadNetParamsFromTextFileOrDie(param_file, &param); + Init(param); +} + +template +void Net::Init(const NetParameter& in_param) { + // Filter layers based on their include/exclude rules and + // the current NetState. + NetParameter filtered_param; + FilterNet(in_param, &filtered_param); + LOG(INFO) << "Initializing net from parameters: " << std::endl + << filtered_param.DebugString(); + // Create a copy of filtered_param with splits added where necessary. + NetParameter param; + InsertSplits(filtered_param, &param); + // Basically, build all the layers and set up their connections. 
+ name_ = param.name(); + map blob_name_to_idx; + set available_blobs; + CHECK_EQ(param.input_size() * 4, param.input_dim_size()) + << "Incorrect input blob dimension specifications."; + memory_used_ = 0; + // set the input blobs + for (int input_id = 0; input_id < param.input_size(); ++input_id) { + const int layer_id = -1; // inputs have fake layer ID -1 + AppendTop(param, layer_id, input_id, &available_blobs, &blob_name_to_idx); + } + DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + // For each layer, set up their input and output + bottom_vecs_.resize(param.layers_size()); + top_vecs_.resize(param.layers_size()); + bottom_id_vecs_.resize(param.layers_size()); + top_id_vecs_.resize(param.layers_size()); + bottom_need_backward_.resize(param.layers_size()); + for (int layer_id = 0; layer_id < param.layers_size(); ++layer_id) { + const LayerParameter& layer_param = param.layers(layer_id); + layers_.push_back(shared_ptr >( + LayerRegistry::CreateLayer(layer_param))); + layers_[layer_id]->set_net(this); + layer_names_.push_back(layer_param.name()); + LOG(INFO) << "Creating Layer " << layer_param.name(); + bool need_backward = false; + // Figure out this layer's input and output + for (int bottom_id = 0; bottom_id < layer_param.bottom_size(); + ++bottom_id) { + const int blob_id = AppendBottom(param, layer_id, bottom_id, + &available_blobs, &blob_name_to_idx); + // If a blob needs backward, this layer should provide it. + need_backward |= blob_need_backward_[blob_id]; + } + int num_top = layer_param.top_size(); + for (int top_id = 0; top_id < num_top; ++top_id) { + AppendTop(param, layer_id, top_id, &available_blobs, &blob_name_to_idx); + } + // If the layer specifies that AutoTopBlobs() -> true and the LayerParameter + // specified fewer than the required number (as specified by + // ExactNumTopBlobs() or MinTopBlobs()), allocate them here. 
+ Layer* layer = layers_[layer_id].get(); + if (layer->AutoTopBlobs()) { + const int needed_num_top = + std::max(layer->MinTopBlobs(), layer->ExactNumTopBlobs()); + for (; num_top < needed_num_top; ++num_top) { + // Add "anonymous" top blobs -- do not modify available_blobs or + // blob_name_to_idx as we don't want these blobs to be usable as input + // to other layers. + AppendTop(param, layer_id, num_top, NULL, NULL); + } + } + // After this layer is connected, set it up. + LOG(INFO) << "Setting up " << layer_names_[layer_id]; + layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]); + for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) { + if (blob_loss_weights_.size() <= top_id_vecs_[layer_id][top_id]) { + blob_loss_weights_.resize(top_id_vecs_[layer_id][top_id] + 1, Dtype(0)); + } + blob_loss_weights_[top_id_vecs_[layer_id][top_id]] = layer->loss(top_id); + LOG(INFO) << "Top shape: " << top_vecs_[layer_id][top_id]->num() << " " + << top_vecs_[layer_id][top_id]->channels() << " " + << top_vecs_[layer_id][top_id]->height() << " " + << top_vecs_[layer_id][top_id]->width() << " (" + << top_vecs_[layer_id][top_id]->count() << ")"; + if (layer->loss(top_id)) { + LOG(INFO) << " with loss weight " << layer->loss(top_id); + } + memory_used_ += top_vecs_[layer_id][top_id]->count(); + } + DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + const int blobs_lr_size = layer_param.blobs_lr_size(); + const int num_param_blobs = layers_[layer_id]->blobs().size(); + CHECK(blobs_lr_size == num_param_blobs || blobs_lr_size == 0) + << "Incorrect blobs lr size: should be either 0 " + << "or the same as the number of the layer's parameter blobs."; + if (blobs_lr_size) { + // Check if this layer needs backward operation itself + for (int param_id = 0; param_id < blobs_lr_size; ++param_id) { + const bool param_need_backward = layer_param.blobs_lr(param_id) > 0; + need_backward |= param_need_backward; + 
layers_[layer_id]->set_param_propagate_down(param_id, + param_need_backward); + } + } else if (layers_[layer_id]->blobs().size()) { + // catch: if a layer param does not specify blobs_lr, we should assume the + // learning rate to be 1. Thus we will need to perform backward. + need_backward = true; + for (int param_id = 0; param_id < num_param_blobs; ++param_id) { + layers_[layer_id]->set_param_propagate_down(param_id, true); + } + } + const int param_size = layer_param.param_size(); + CHECK(param_size == num_param_blobs || param_size == 0) + << "Incorrect param size: should be either 0 or the same as " + "the number of the layer's parameter blobs: " << num_param_blobs; + const int blob_share_mode_size = layer_param.blob_share_mode_size(); + CHECK(blob_share_mode_size == num_param_blobs || blob_share_mode_size == 0) + << "Incorrect blob_share_mode size: should be either 0 or the same as " + "the number of the layer's parameter blobs: " << num_param_blobs; + for (int param_id = 0; param_id < num_param_blobs; ++param_id) { + AppendParam(param, layer_id, param_id); + } + // Finally, set the backward flag + layer_need_backward_.push_back(need_backward); + if (need_backward) { + for (int top_id = 0; top_id < top_id_vecs_[layer_id].size(); ++top_id) { + blob_need_backward_[top_id_vecs_[layer_id][top_id]] = true; + } + } + } + // Go through the net backwards to determine which blobs contribute to the + // loss. We can skip backward computation for blobs that don't contribute + // to the loss. 
+ set blobs_under_loss; + for (int layer_id = layers_.size() - 1; layer_id >= 0; --layer_id) { + bool layer_contributes_loss = false; + for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) { + const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]]; + if (layers_[layer_id]->loss(top_id) || + (blobs_under_loss.find(blob_name) != blobs_under_loss.end())) { + layer_contributes_loss = true; + break; + } + } + if (!layer_contributes_loss) { layer_need_backward_[layer_id] = false; } + if (layer_need_backward_[layer_id]) { + LOG(INFO) << layer_names_[layer_id] << " needs backward computation."; + } else { + LOG(INFO) << layer_names_[layer_id] + << " does not need backward computation."; + } + for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size(); + ++bottom_id) { + if (layer_contributes_loss) { + const string& blob_name = + blob_names_[bottom_id_vecs_[layer_id][bottom_id]]; + blobs_under_loss.insert(blob_name); + } else { + bottom_need_backward_[layer_id][bottom_id] = false; + } + } + } + // Handle force_backward if needed. + if (param.force_backward()) { + for (int layer_id = 0; layer_id < layers_.size(); ++layer_id) { + layer_need_backward_[layer_id] = true; + for (int bottom_id = 0; + bottom_id < bottom_need_backward_[layer_id].size(); ++bottom_id) { + bottom_need_backward_[layer_id][bottom_id] = + bottom_need_backward_[layer_id][bottom_id] || + layers_[layer_id]->AllowForceBackward(bottom_id); + blob_need_backward_[bottom_id_vecs_[layer_id][bottom_id]] = + blob_need_backward_[bottom_id_vecs_[layer_id][bottom_id]] || + bottom_need_backward_[layer_id][bottom_id]; + } + for (int param_id = 0; param_id < layers_[layer_id]->blobs().size(); + ++param_id) { + layers_[layer_id]->set_param_propagate_down(param_id, true); + } + } + } + // In the end, all remaining blobs are considered output blobs. 
+ for (set::iterator it = available_blobs.begin(); + it != available_blobs.end(); ++it) { + LOG(INFO) << "This network produces output " << *it; + net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get()); + net_output_blob_indices_.push_back(blob_name_to_idx[*it]); + } + for (size_t blob_id = 0; blob_id < blob_names_.size(); ++blob_id) { + blob_names_index_[blob_names_[blob_id]] = blob_id; + } + for (size_t layer_id = 0; layer_id < layer_names_.size(); ++layer_id) { + layer_names_index_[layer_names_[layer_id]] = layer_id; + } + GetLearningRateAndWeightDecay(); + LOG(INFO) << "Network initialization done."; + LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + // Don't display debug info by default. + debug_info_ = false; +} + +template +void Net::FilterNet(const NetParameter& param, + NetParameter* param_filtered) { + NetState net_state(param.state()); + // Let the phase of the net be the current global phase provided in the Caffe + // singleton, unless explicitly provided by the state. + if (!net_state.has_phase()) { + switch (Caffe::phase()) { + case Caffe::TRAIN: + net_state.set_phase(TRAIN); + break; + case Caffe::TEST: + net_state.set_phase(TEST); + break; + default: + LOG(FATAL) << "Unknown phase: " << Caffe::phase(); + } + } + param_filtered->CopyFrom(param); + param_filtered->clear_layers(); + for (int i = 0; i < param.layers_size(); ++i) { + const LayerParameter& layer_param = param.layers(i); + const string& layer_name = layer_param.name(); + CHECK(layer_param.include_size() == 0 || layer_param.exclude_size() == 0) + << "Specify either include rules or exclude rules; not both."; + // If no include rules are specified, the layer is included by default and + // only excluded if it meets one of the exclude rules. 
+ bool layer_included = (layer_param.include_size() == 0); + for (int j = 0; layer_included && j < layer_param.exclude_size(); ++j) { + if (StateMeetsRule(net_state, layer_param.exclude(j), layer_name)) { + layer_included = false; + } + } + for (int j = 0; !layer_included && j < layer_param.include_size(); ++j) { + if (StateMeetsRule(net_state, layer_param.include(j), layer_name)) { + layer_included = true; + } + } + if (layer_included) { + param_filtered->add_layers()->CopyFrom(layer_param); + } + } +} + +template +bool Net::StateMeetsRule(const NetState& state, + const NetStateRule& rule, const string& layer_name) { + // Check whether the rule is broken due to phase. + if (rule.has_phase()) { + if (rule.phase() != state.phase()) { + LOG(INFO) << "The NetState phase (" << state.phase() + << ") differed from the phase (" << rule.phase() + << ") specified by a rule in layer " << layer_name; + return false; + } + } + // Check whether the rule is broken due to min level. + if (rule.has_min_level()) { + if (state.level() < rule.min_level()) { + LOG(INFO) << "The NetState level (" << state.level() + << ") is below the min_level (" << rule.min_level() + << ") specified by a rule in layer " << layer_name; + return false; + } + } + // Check whether the rule is broken due to max level. + if (rule.has_max_level()) { + if (state.level() > rule.max_level()) { + LOG(INFO) << "The NetState level (" << state.level() + << ") is above the max_level (" << rule.max_level() + << ") specified by a rule in layer " << layer_name; + return false; + } + } + // Check whether the rule is broken due to stage. The NetState must + // contain ALL of the rule's stages to meet it. + for (int i = 0; i < rule.stage_size(); ++i) { + // Check that the NetState contains the rule's ith stage. 
+ bool has_stage = false; + for (int j = 0; !has_stage && j < state.stage_size(); ++j) { + if (rule.stage(i) == state.stage(j)) { has_stage = true; } + } + if (!has_stage) { + LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i) + << "' specified by a rule in layer " << layer_name; + return false; + } + } + // Check whether the rule is broken due to not_stage. The NetState must + // contain NONE of the rule's not_stages to meet it. + for (int i = 0; i < rule.not_stage_size(); ++i) { + // Check that the NetState contains the rule's ith not_stage. + bool has_stage = false; + for (int j = 0; !has_stage && j < state.stage_size(); ++j) { + if (rule.not_stage(i) == state.stage(j)) { has_stage = true; } + } + if (has_stage) { + LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i) + << "' specified by a rule in layer " << layer_name; + return false; + } + } + return true; +} + +// Helper for Net::Init: add a new input or top blob to the net. (Inputs have +// layer_id == -1, tops have layer_id >= 0.) +template +void Net::AppendTop(const NetParameter& param, const int layer_id, + const int top_id, set* available_blobs, + map* blob_name_to_idx) { + shared_ptr layer_param((layer_id >= 0) ? + (new LayerParameter(param.layers(layer_id))) : NULL); + const string& blob_name = layer_param ? + (layer_param->top_size() > top_id ? 
+ layer_param->top(top_id) : "(automatic)") : param.input(top_id); + // Check if we are doing in-place computation + if (blob_name_to_idx && layer_param && layer_param->bottom_size() > top_id && + blob_name == layer_param->bottom(top_id)) { + // In-place computation + LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)"; + top_vecs_[layer_id].push_back(blobs_[(*blob_name_to_idx)[blob_name]].get()); + top_id_vecs_[layer_id].push_back((*blob_name_to_idx)[blob_name]); + } else if (blob_name_to_idx && + blob_name_to_idx->find(blob_name) != blob_name_to_idx->end()) { + // If we are not doing in-place computation but have duplicated blobs, + // raise an error. + LOG(FATAL) << "Duplicate blobs produced by multiple sources."; + } else { + // Normal output. + if (layer_param) { + LOG(INFO) << layer_param->name() << " -> " << blob_name; + } else { + LOG(INFO) << "Input " << top_id << " -> " << blob_name; + } + shared_ptr > blob_pointer(new Blob()); + const int blob_id = blobs_.size(); + blobs_.push_back(blob_pointer); + blob_names_.push_back(blob_name); + blob_need_backward_.push_back(false); + if (blob_name_to_idx) { (*blob_name_to_idx)[blob_name] = blob_id; } + if (layer_id == -1) { + // Set the (explicitly specified) dimensions of the input blob. + blob_pointer->Reshape(param.input_dim(top_id * 4), + param.input_dim(top_id * 4 + 1), + param.input_dim(top_id * 4 + 2), + param.input_dim(top_id * 4 + 3)); + net_input_blob_indices_.push_back(blob_id); + net_input_blobs_.push_back(blob_pointer.get()); + } else { + top_id_vecs_[layer_id].push_back(blob_id); + top_vecs_[layer_id].push_back(blob_pointer.get()); + } + } + if (available_blobs) { available_blobs->insert(blob_name); } +} + +// Helper for Net::Init: add a new bottom blob to the net. 
+template +int Net::AppendBottom(const NetParameter& param, + const int layer_id, const int bottom_id, + set* available_blobs, map* blob_name_to_idx) { + const LayerParameter& layer_param = param.layers(layer_id); + const string& blob_name = layer_param.bottom(bottom_id); + if (available_blobs->find(blob_name) == available_blobs->end()) { + LOG(FATAL) << "Unknown blob input " << blob_name + << " (at index " << bottom_id << ") to layer " << layer_id; + } + const int blob_id = (*blob_name_to_idx)[blob_name]; + LOG(INFO) << layer_names_[layer_id] << " <- " << blob_name; + bottom_vecs_[layer_id].push_back(blobs_[blob_id].get()); + bottom_id_vecs_[layer_id].push_back(blob_id); + available_blobs->erase(blob_name); + const bool need_backward = blob_need_backward_[blob_id]; + bottom_need_backward_[layer_id].push_back(need_backward); + return blob_id; +} + +template +void Net::AppendParam(const NetParameter& param, const int layer_id, + const int param_id) { + const LayerParameter& layer_param = layers_[layer_id]->layer_param(); + const int param_size = layer_param.param_size(); + string param_name = param_size ? layer_param.param(param_id) : ""; + if (param_name.size()) { + param_display_names_.push_back(param_name); + } else { + ostringstream param_display_name; + param_display_name << param_id; + param_display_names_.push_back(param_display_name.str()); + } + const int net_param_id = params_.size(); + params_.push_back(layers_[layer_id]->blobs()[param_id]); + param_layer_indices_.push_back(make_pair(layer_id, param_id)); + if (!param_size || !param_name.size() || (param_name.size() && + param_names_index_.find(param_name) == param_names_index_.end())) { + // This layer "owns" this parameter blob -- it is either anonymous + // (i.e., not given a param_name) or explicitly given a name that we + // haven't already seen. 
+ param_owners_.push_back(-1); + if (param_size) { + param_names_index_[param_name] = net_param_id; + } + } else { + // Named param blob with name we've seen before: share params + const int owner_net_param_id = param_names_index_[param_name]; + param_owners_.push_back(owner_net_param_id); + const pair& owner_index = + param_layer_indices_[owner_net_param_id]; + const int owner_layer_id = owner_index.first; + const int owner_param_id = owner_index.second; + LOG(INFO) << "Sharing parameters '" << param_name << "' owned by " + << "layer '" << layer_names_[owner_layer_id] << "', param " + << "index " << owner_param_id; + Blob* this_blob = layers_[layer_id]->blobs()[param_id].get(); + Blob* owner_blob = + layers_[owner_layer_id]->blobs()[owner_param_id].get(); + const int blob_share_mode_size = layer_param.blob_share_mode_size(); + if (blob_share_mode_size > param_id && + (layer_param.blob_share_mode(param_id) == + LayerParameter_DimCheckMode_PERMISSIVE)) { + // Permissive dimension checking -- only check counts are the same. + CHECK_EQ(this_blob->count(), owner_blob->count()) + << "Shared parameter blobs must have the same count."; + } else { + // Strict dimension checking -- all dims must be the same. 
+ CHECK_EQ(this_blob->num(), owner_blob->num()) + << "Shared parameter blobs must have the same num."; + CHECK_EQ(this_blob->channels(), owner_blob->channels()) + << "Shared parameter blobs must have the same channels."; + CHECK_EQ(this_blob->height(), owner_blob->height()) + << "Shared parameter blobs must have the same height."; + CHECK_EQ(this_blob->width(), owner_blob->width()) + << "Shared parameter blobs must have the same width."; + } + layers_[layer_id]->blobs()[param_id]->ShareData( + *layers_[owner_layer_id]->blobs()[owner_param_id]); + } +} + +template +void Net::GetLearningRateAndWeightDecay() { + LOG(INFO) << "Collecting Learning Rate and Weight Decay."; + for (int i = 0; i < layers_.size(); ++i) { + vector > >& layer_blobs = layers_[i]->blobs(); + // push the learning rate multipliers + if (layers_[i]->layer_param().blobs_lr_size()) { + CHECK_EQ(layers_[i]->layer_param().blobs_lr_size(), layer_blobs.size()); + for (int j = 0; j < layer_blobs.size(); ++j) { + float local_lr = layers_[i]->layer_param().blobs_lr(j); + CHECK_GE(local_lr, 0.); + params_lr_.push_back(local_lr); + } + } else { + for (int j = 0; j < layer_blobs.size(); ++j) { + params_lr_.push_back(1.); + } + } + // push the weight decay multipliers + if (layers_[i]->layer_param().weight_decay_size()) { + CHECK_EQ(layers_[i]->layer_param().weight_decay_size(), + layer_blobs.size()); + for (int j = 0; j < layer_blobs.size(); ++j) { + float local_decay = layers_[i]->layer_param().weight_decay(j); + CHECK_GE(local_decay, 0.); + params_weight_decay_.push_back(local_decay); + } + } else { + for (int j = 0; j < layer_blobs.size(); ++j) { + params_weight_decay_.push_back(1.); + } + } + } +} + +template +Dtype Net::ForwardFromTo(int start, int end) { + CHECK_GE(start, 0); + CHECK_LT(end, layers_.size()); + Dtype loss = 0; + for (int i = start; i <= end; ++i) { + // LOG(ERROR) << "Forwarding " << layer_names_[i]; + layers_[i]->Reshape(bottom_vecs_[i], top_vecs_[i]); + Dtype layer_loss = 
layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]); + loss += layer_loss; + if (debug_info_) { ForwardDebugInfo(i); } + } + return loss; +} + +template +Dtype Net::ForwardFrom(int start) { + return ForwardFromTo(start, layers_.size() - 1); +} + +template +Dtype Net::ForwardTo(int end) { + return ForwardFromTo(0, end); +} + +template +const vector*>& Net::ForwardPrefilled(Dtype* loss) { + if (loss != NULL) { + *loss = ForwardFromTo(0, layers_.size() - 1); + } else { + ForwardFromTo(0, layers_.size() - 1); + } + return net_output_blobs_; +} + +template +const vector*>& Net::Forward( + const vector*> & bottom, Dtype* loss) { + // Copy bottom to internal bottom + for (int i = 0; i < bottom.size(); ++i) { + net_input_blobs_[i]->CopyFrom(*bottom[i]); + } + return ForwardPrefilled(loss); +} + +template +string Net::Forward(const string& input_blob_protos, Dtype* loss) { + BlobProtoVector blob_proto_vec; + if (net_input_blobs_.size()) { + blob_proto_vec.ParseFromString(input_blob_protos); + CHECK_EQ(blob_proto_vec.blobs_size(), net_input_blobs_.size()) + << "Incorrect input size."; + for (int i = 0; i < blob_proto_vec.blobs_size(); ++i) { + net_input_blobs_[i]->FromProto(blob_proto_vec.blobs(i)); + } + } + ForwardPrefilled(loss); + blob_proto_vec.Clear(); + for (int i = 0; i < net_output_blobs_.size(); ++i) { + net_output_blobs_[i]->ToProto(blob_proto_vec.add_blobs()); + } + string output; + blob_proto_vec.SerializeToString(&output); + return output; +} + +template +void Net::BackwardFromTo(int start, int end) { + CHECK_GE(end, 0); + CHECK_LT(start, layers_.size()); + for (int i = start; i >= end; --i) { + if (layer_need_backward_[i]) { + layers_[i]->Backward( + top_vecs_[i], bottom_need_backward_[i], bottom_vecs_[i]); + if (debug_info_) { BackwardDebugInfo(i); } + } + } +} + +template +void Net::ForwardDebugInfo(const int layer_id) { + for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) { + const Blob& blob = *top_vecs_[layer_id][top_id]; + const string& 
blob_name = blob_names_[top_id_vecs_[layer_id][top_id]]; + const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); + LOG(INFO) << " [Forward] " + << "Layer " << layer_names_[layer_id] << ", top blob " << blob_name + << " data: " << data_abs_val_mean; + } +} + +template +void Net::BackwardDebugInfo(const int layer_id) { + const vector*>& bottom_vec = bottom_vecs_[layer_id]; + for (int bottom_id = 0; bottom_id < bottom_vec.size(); ++bottom_id) { + if (!bottom_need_backward_[layer_id][bottom_id]) { continue; } + const Blob& blob = *bottom_vec[bottom_id]; + const string& blob_name = blob_names_[bottom_id_vecs_[layer_id][bottom_id]]; + const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); + LOG(INFO) << " [Backward] " + << "Layer " << layer_names_[layer_id] << ", bottom blob " << blob_name + << " diff: " << diff_abs_val_mean; + } + for (int param_id = 0; param_id < layers_[layer_id]->blobs().size(); + ++param_id) { + if (!layers_[layer_id]->param_propagate_down(param_id)) { continue; } + const Blob& blob = *layers_[layer_id]->blobs()[param_id]; + const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); + LOG(INFO) << " [Backward] " + << "Layer " << layer_names_[layer_id] << ", param blob " << param_id + << " diff: " << diff_abs_val_mean; + } +} + +template +void Net::UpdateDebugInfo(const int param_id) { + const Blob& blob = *params_[param_id]; + const int param_owner = param_owners_[param_id]; + const string& layer_name = layer_names_[param_layer_indices_[param_id].first]; + const string& param_display_name = param_display_names_[param_id]; + const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); + if (param_owner < 0) { + const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); + LOG(INFO) << " [Update] Layer " << layer_name + << ", param " << param_display_name + << " data: " << data_abs_val_mean << "; diff: " << diff_abs_val_mean; + } else { + const string& owner_layer_name = + 
layer_names_[param_layer_indices_[param_owner].first]; + LOG(INFO) << " [Update] Layer " << layer_name + << ", param blob " << param_display_name + << " (owned by layer " << owner_layer_name << ", " + << "param " << param_display_names_[param_owners_[param_id]] << ")" + << " diff: " << diff_abs_val_mean; + } +} + +template <typename Dtype> +void Net<Dtype>::ShareTrainedLayersWith(Net* other) { + int num_source_layers = other->layers().size(); + for (int i = 0; i < num_source_layers; ++i) { + Layer<Dtype>* source_layer = other->layers()[i].get(); + const string& source_layer_name = other->layer_names()[i]; + int target_layer_id = 0; + while (target_layer_id != layer_names_.size() && + layer_names_[target_layer_id] != source_layer_name) { + ++target_layer_id; + } + if (target_layer_id == layer_names_.size()) { + DLOG(INFO) << "Ignoring source layer " << source_layer_name; + continue; + } + DLOG(INFO) << "Copying source layer " << source_layer_name; + vector<shared_ptr<Blob<Dtype> > >& target_blobs = + layers_[target_layer_id]->blobs(); + CHECK_EQ(target_blobs.size(), source_layer->blobs().size()) + << "Incompatible number of blobs for layer " << source_layer_name; + for (int j = 0; j < target_blobs.size(); ++j) { + Blob<Dtype>* source_blob = source_layer->blobs()[j].get(); + CHECK_EQ(target_blobs[j]->num(), source_blob->num()); + CHECK_EQ(target_blobs[j]->channels(), source_blob->channels()); + CHECK_EQ(target_blobs[j]->height(), source_blob->height()); + CHECK_EQ(target_blobs[j]->width(), source_blob->width()); + target_blobs[j]->ShareData(*source_blob); + } + } +} + +template <typename Dtype> +void Net<Dtype>::BackwardFrom(int start) { + BackwardFromTo(start, 0); +} + +template <typename Dtype> +void Net<Dtype>::BackwardTo(int end) { + BackwardFromTo(layers_.size() - 1, end); +} + +template <typename Dtype> +void Net<Dtype>::Backward() { + BackwardFromTo(layers_.size() - 1, 0); +} + +template <typename Dtype> +void Net<Dtype>::Reshape() { + for (int i = 0; i < layers_.size(); ++i) { + layers_[i]->Reshape(bottom_vecs_[i], top_vecs_[i]); + } +} + +template <typename Dtype> +void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) { + 
int num_source_layers = param.layers_size(); + for (int i = 0; i < num_source_layers; ++i) { + const LayerParameter& source_layer = param.layers(i); + const string& source_layer_name = source_layer.name(); + int target_layer_id = 0; + while (target_layer_id != layer_names_.size() && + layer_names_[target_layer_id] != source_layer_name) { + ++target_layer_id; + } + if (target_layer_id == layer_names_.size()) { + DLOG(INFO) << "Ignoring source layer " << source_layer_name; + continue; + } + DLOG(INFO) << "Copying source layer " << source_layer_name; + vector<shared_ptr<Blob<Dtype> > >& target_blobs = + layers_[target_layer_id]->blobs(); + CHECK_EQ(target_blobs.size(), source_layer.blobs_size()) + << "Incompatible number of blobs for layer " << source_layer_name; + for (int j = 0; j < target_blobs.size(); ++j) { + CHECK_EQ(target_blobs[j]->num(), source_layer.blobs(j).num()); + CHECK_EQ(target_blobs[j]->channels(), source_layer.blobs(j).channels()); + CHECK_EQ(target_blobs[j]->height(), source_layer.blobs(j).height()); + CHECK_EQ(target_blobs[j]->width(), source_layer.blobs(j).width()); + target_blobs[j]->FromProto(source_layer.blobs(j)); + } + } +} + +template <typename Dtype> +void Net<Dtype>::CopyTrainedLayersFrom(const string trained_filename) { + NetParameter param; + ReadNetParamsFromBinaryFileOrDie(trained_filename, &param); + CopyTrainedLayersFrom(param); +} + +template <typename Dtype> +void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) { + param->Clear(); + param->set_name(name_); + // Add bottom and top + for (int i = 0; i < net_input_blob_indices_.size(); ++i) { + param->add_input(blob_names_[net_input_blob_indices_[i]]); + } + DLOG(INFO) << "Serializing " << layers_.size() << " layers"; + for (int i = 0; i < layers_.size(); ++i) { + LayerParameter* layer_param = param->add_layers(); + for (int j = 0; j < bottom_id_vecs_[i].size(); ++j) { + layer_param->add_bottom(blob_names_[bottom_id_vecs_[i][j]]); + } + for (int j = 0; j < top_id_vecs_[i].size(); ++j) { + layer_param->add_top(blob_names_[top_id_vecs_[i][j]]); + } + 
layers_[i]->ToProto(layer_param, write_diff); + } +} + +template <typename Dtype> +void Net<Dtype>::Update() { + // First, accumulate the diffs of any shared parameters into their owner's + // diff. (Assumes that the learning rate, weight decay, etc. have already been + // accounted for in the current diff.) + for (int i = 0; i < params_.size(); ++i) { + if (param_owners_[i] < 0) { continue; } + if (debug_info_) { UpdateDebugInfo(i); } + const int count = params_[i]->count(); + const Dtype* this_diff; + Dtype* owner_diff; + switch (Caffe::mode()) { + case Caffe::CPU: + this_diff = params_[i]->cpu_diff(); + owner_diff = params_[param_owners_[i]]->mutable_cpu_diff(); + caffe_add(count, this_diff, owner_diff, owner_diff); + break; +#ifndef CPU_ONLY + case Caffe::GPU: + this_diff = params_[i]->gpu_diff(); + owner_diff = params_[param_owners_[i]]->mutable_gpu_diff(); + caffe_gpu_add(count, this_diff, owner_diff, owner_diff); + break; +#else + NO_GPU; +#endif + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } + } + // Now, update the owned parameters. 
+ for (int i = 0; i < params_.size(); ++i) { + if (param_owners_[i] >= 0) { continue; } + if (debug_info_) { UpdateDebugInfo(i); } + params_[i]->Update(); + } +} + +template <typename Dtype> +bool Net<Dtype>::has_blob(const string& blob_name) { + return blob_names_index_.find(blob_name) != blob_names_index_.end(); +} + +template <typename Dtype> +const shared_ptr<Blob<Dtype> > Net<Dtype>::blob_by_name( + const string& blob_name) { + shared_ptr<Blob<Dtype> > blob_ptr; + if (has_blob(blob_name)) { + blob_ptr = blobs_[blob_names_index_[blob_name]]; + } else { + blob_ptr.reset((Blob<Dtype>*)(NULL)); + LOG(WARNING) << "Unknown blob name " << blob_name; + } + return blob_ptr; +} + +template <typename Dtype> +bool Net<Dtype>::has_layer(const string& layer_name) { + return layer_names_index_.find(layer_name) != layer_names_index_.end(); +} + +template <typename Dtype> +const shared_ptr<Layer<Dtype> > Net<Dtype>::layer_by_name( + const string& layer_name) { + shared_ptr<Layer<Dtype> > layer_ptr; + if (has_layer(layer_name)) { + layer_ptr = layers_[layer_names_index_[layer_name]]; + } else { + layer_ptr.reset((Layer<Dtype>*)(NULL)); + LOG(WARNING) << "Unknown layer name " << layer_name; + } + return layer_ptr; +} + +INSTANTIATE_CLASS(Net); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/proto/CMakeLists.txt b/caffe-crfrnn/src/caffe/proto/CMakeLists.txt new file mode 100644 index 00000000..12e7ce0a --- /dev/null +++ b/caffe-crfrnn/src/caffe/proto/CMakeLists.txt @@ -0,0 +1,46 @@ +project( Proto ) + +# Google Protocol Buffers +find_package( Protobuf REQUIRED ) + +# As of Ubuntu 14.04 protoc is no longer a part of libprotobuf-dev package and should be installed +# separately as in: sudo apt-get install protobuf-compiler +if(PROTOBUF_PROTOC_EXECUTABLE) + message(STATUS "Found PROTOBUF Compiler: ${PROTOBUF_PROTOC_EXECUTABLE}") +else() + message(FATAL_ERROR "Could not find PROTOBUF Compiler") +endif() + +include_directories(${PROTOBUF_INCLUDE_DIR}) +file(GLOB ProtoFiles "${CMAKE_CURRENT_SOURCE_DIR}/*.proto") +PROTOBUF_GENERATE_CPP(ProtoSources ProtoHeaders ${ProtoFiles}) +PROTOBUF_GENERATE_PYTHON(ProtoSourcesPy ${ProtoFiles}) 
+ +add_custom_target(protoPy DEPENDS ${ProtoSourcesPy}) + +add_library(proto + ${ProtoSources} + ${ProtoHeaders} + ) + + +target_link_libraries(proto ${PROTOBUF_LIBRARIES}) + +# Create proto include directory +file(MAKE_DIRECTORY ${CMAKE_SOURCE_DIR}/include/caffe/proto) + +# Copy proto headers to include/caffe/proto/ +foreach(header ${ProtoHeaders}) + + ADD_CUSTOM_COMMAND(TARGET proto + COMMAND cmake -E copy ${header} + ${Caffe_INCLUDE_DIRS}/caffe/proto/ + DEPENDS ${header} +) + +endforeach(header) + +file(WRITE __init__.py) +install(PROGRAMS __init__.py DESTINATION python/caffe/proto) +install(PROGRAMS ${ProtoSourcesPy} DESTINATION python/caffe/proto) + diff --git a/caffe-crfrnn/src/caffe/proto/caffe.proto b/caffe-crfrnn/src/caffe/proto/caffe.proto new file mode 100644 index 00000000..1e407cba --- /dev/null +++ b/caffe-crfrnn/src/caffe/proto/caffe.proto @@ -0,0 +1,862 @@ +syntax = "proto2"; + +package caffe; + +message BlobProto { + optional int32 num = 1 [default = 0]; + optional int32 channels = 2 [default = 0]; + optional int32 height = 3 [default = 0]; + optional int32 width = 4 [default = 0]; + repeated float data = 5 [packed = true]; + repeated float diff = 6 [packed = true]; +} + +// The BlobProtoVector is simply a way to pass multiple blobproto instances +// around. +message BlobProtoVector { + repeated BlobProto blobs = 1; +} + +message Datum { + optional int32 channels = 1; + optional int32 height = 2; + optional int32 width = 3; + // the actual image data, in bytes + optional bytes data = 4; + optional int32 label = 5; + // Optionally, the datum could also hold float data. + repeated float float_data = 6; + // If true data contains an encoded image that need to be decoded + optional bool encoded = 7 [default = false]; +} + +message FillerParameter { + // The filler type. 
+ optional string type = 1 [default = 'constant']; + optional float value = 2 [default = 0]; // the value in constant filler + optional float min = 3 [default = 0]; // the min value in uniform filler + optional float max = 4 [default = 1]; // the max value in uniform filler + optional float mean = 5 [default = 0]; // the mean value in Gaussian filler + optional float std = 6 [default = 1]; // the std value in Gaussian filler + // The expected number of non-zero input weights for a given output in + // Gaussian filler -- the default -1 means don't perform sparsification. + optional int32 sparse = 7 [default = -1]; +} + +message NetParameter { + optional string name = 1; // consider giving the network a name + repeated LayerParameter layers = 2; // a bunch of layers. + // The input blobs to the network. + repeated string input = 3; + // The dim of the input blobs. For each input blob there should be four + // values specifying the num, channels, height and width of the input blob. + // Thus, there should be a total of (4 * #input) numbers. + repeated int32 input_dim = 4; + // Whether the network will force every layer to carry out backward operation. + // If set False, then whether to carry out backward is determined + // automatically according to the net structure and learning rates. + optional bool force_backward = 5 [default = false]; + // The current "state" of the network, including the phase, level, and stage. + // Some layers may be included/excluded depending on this state and the states + // specified in the layers' include and exclude fields. + optional NetState state = 6; +} + +// NOTE +// Update the next available ID when you add a new SolverParameter field. 
+// +// SolverParameter next available ID: 36 (last added: iter_size) +message SolverParameter { + ////////////////////////////////////////////////////////////////////////////// + // Specifying the train and test networks + // + // Exactly one train net must be specified using one of the following fields: + // train_net_param, train_net, net_param, net + // One or more test nets may be specified using any of the following fields: + // test_net_param, test_net, net_param, net + // If more than one test net field is specified (e.g., both net and + // test_net are specified), they will be evaluated in the field order given + // above: (1) test_net_param, (2) test_net, (3) net_param/net. + // A test_iter must be specified for each test_net. + // A test_level and/or a test_stage may also be specified for each test_net. + ////////////////////////////////////////////////////////////////////////////// + + // Proto filename for the train net, possibly combined with one or more + // test nets. + optional string net = 24; + // Inline train net param, possibly combined with one or more test nets. + optional NetParameter net_param = 25; + + optional string train_net = 1; // Proto filename for the train net. + repeated string test_net = 2; // Proto filenames for the test nets. + optional NetParameter train_net_param = 21; // Inline train net params. + repeated NetParameter test_net_param = 22; // Inline test net params. + + // The states for the train/test nets. Must be unspecified or + // specified once per net. + // + // By default, all states will have solver = true; + // train_state will have phase = TRAIN, + // and all test_state's will have phase = TEST. + // Other defaults are set according to the NetState defaults. + optional NetState train_state = 26; + repeated NetState test_state = 27; + + // The number of iterations for each test net. + repeated int32 test_iter = 3; + + // The number of iterations between two testing phases. 
+ optional int32 test_interval = 4 [default = 0]; + optional bool test_compute_loss = 19 [default = false]; + // If true, run an initial test pass before the first iteration, + // ensuring memory availability and printing the starting value of the loss. + optional bool test_initialization = 32 [default = true]; + optional float base_lr = 5; // The base learning rate + // the number of iterations between displaying info. If display = 0, no info + // will be displayed. + optional int32 display = 6; + // Display the loss averaged over the last average_loss iterations + optional int32 average_loss = 33 [default = 1]; + optional int32 max_iter = 7; // the maximum number of iterations + optional int32 iter_size = 35 [default = 1]; + optional string lr_policy = 8; // The learning rate decay policy. + optional float gamma = 9; // The parameter to compute the learning rate. + optional float power = 10; // The parameter to compute the learning rate. + optional float momentum = 11; // The momentum value. + optional float weight_decay = 12; // The weight decay. + // regularization types supported: L1 and L2 + // controlled by weight_decay + optional string regularization_type = 29 [default = "L2"]; + // the stepsize for learning rate policy "step" + optional int32 stepsize = 13; + // the stepsize for learning rate policy "multistep" + repeated int32 stepvalue = 34; + optional int32 snapshot = 14 [default = 0]; // The snapshot interval + optional string snapshot_prefix = 15; // The prefix for the snapshot. + // whether to snapshot diff in the results or not. Snapshotting diff will help + // debugging but the final protocol buffer size will be much larger. + optional bool snapshot_diff = 16 [default = false]; + // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. + enum SolverMode { + CPU = 0; + GPU = 1; + } + optional SolverMode solver_mode = 17 [default = GPU]; + // the device_id will that be used in GPU mode. Use device_id = 0 in default. 
+ optional int32 device_id = 18 [default = 0]; + // If non-negative, the seed with which the Solver will initialize the Caffe + // random number generator -- useful for reproducible results. Otherwise, + // (and by default) initialize using a seed derived from the system clock. + optional int64 random_seed = 20 [default = -1]; + + // Solver type + enum SolverType { + SGD = 0; + NESTEROV = 1; + ADAGRAD = 2; + } + optional SolverType solver_type = 30 [default = SGD]; + // numerical stability for AdaGrad + optional float delta = 31 [default = 1e-8]; + + // If true, print information about the state of the net that may help with + // debugging learning problems. + optional bool debug_info = 23 [default = false]; + + // If false, don't save a snapshot after training finishes. + optional bool snapshot_after_train = 28 [default = true]; +} + +// A message that stores the solver snapshots +message SolverState { + optional int32 iter = 1; // The current iteration + optional string learned_net = 2; // The file that stores the learned net. + repeated BlobProto history = 3; // The history for sgd solvers + optional int32 current_step = 4 [default = 0]; // The current step for learning rate +} + +enum Phase { + TRAIN = 0; + TEST = 1; +} + +message NetState { + optional Phase phase = 1 [default = TEST]; + optional int32 level = 2 [default = 0]; + repeated string stage = 3; +} + +message NetStateRule { + // Set phase to require the NetState have a particular phase (TRAIN or TEST) + // to meet this rule. + optional Phase phase = 1; + + // Set the minimum and/or maximum levels in which the layer should be used. + // Leave undefined to meet the rule regardless of level. + optional int32 min_level = 2; + optional int32 max_level = 3; + + // Customizable sets of stages to include or exclude. + // The net must have ALL of the specified stages and NONE of the specified + // "not_stage"s to meet the rule. + // (Use multiple NetStateRules to specify conjunctions of stages.) 
+ repeated string stage = 4; + repeated string not_stage = 5; +} + +// NOTE +// Update the next available ID when you add a new LayerParameter field. +// +// LayerParameter next available ID: 43 (last added: MULTI_STAGE_MEANFIELD) +message LayerParameter { + repeated string bottom = 2; // the name of the bottom blobs + repeated string top = 3; // the name of the top blobs + optional string name = 4; // the layer name + + // Rules controlling whether and when a layer is included in the network, + // based on the current NetState. You may specify a non-zero number of rules + // to include OR exclude, but not both. If no include or exclude rules are + // specified, the layer is always included. If the current NetState meets + // ANY (i.e., one or more) of the specified rules, the layer is + // included/excluded. + repeated NetStateRule include = 32; + repeated NetStateRule exclude = 33; + + // NOTE + // Add new LayerTypes to the enum below in lexicographical order (other than + // starting with NONE), starting with the next available ID in the comment + // line above the enum. Update the next available ID when you add a new + // LayerType. + // + // LayerType next available ID: 43 (last added: MULTI_STAGE_MEANFIELD) + enum LayerType { + // "NONE" layer type is 0th enum element so that we don't cause confusion + // by defaulting to an existent LayerType (instead, should usually error if + // the type is unspecified). 
+ NONE = 0; + ABSVAL = 35; + ACCURACY = 1; + ARGMAX = 30; + BNLL = 2; + CONCAT = 3; + CONTRASTIVE_LOSS = 37; + CONVOLUTION = 4; + CROP = 40; + DATA = 5; + DECONVOLUTION = 39; + DROPOUT = 6; + DUMMY_DATA = 32; + EUCLIDEAN_LOSS = 7; + ELTWISE = 25; + EXP = 38; + FLATTEN = 8; + HDF5_DATA = 9; + HDF5_OUTPUT = 10; + HINGE_LOSS = 28; + IM2COL = 11; + IMAGE_DATA = 12; + INFOGAIN_LOSS = 13; + INNER_PRODUCT = 14; + LRN = 15; + MEMORY_DATA = 29; + MULTI_STAGE_MEANFIELD = 42; + MULTINOMIAL_LOGISTIC_LOSS = 16; + MVN = 34; + POOLING = 17; + POWER = 26; + RELU = 18; + SIGMOID = 19; + SIGMOID_CROSS_ENTROPY_LOSS = 27; + SILENCE = 36; + SIMPLE_FAST_MEANFIELD=41; + SOFTMAX = 20; + SOFTMAX_LOSS = 21; + SPLIT = 22; + SLICE = 33; + TANH = 23; + WINDOW_DATA = 24; + THRESHOLD = 31; + } + optional LayerType type = 5; // the layer type from the enum above + + // The blobs containing the numeric parameters of the layer + repeated BlobProto blobs = 6; + // The names of the parameter blobs -- useful for sharing parameters among + // layers (but never required). + repeated string param = 1001; + // Whether to require shared weights to have the same shape, or just the same + // count -- defaults to STRICT if unspecified. + repeated DimCheckMode blob_share_mode = 1002; + enum DimCheckMode { + // STRICT (default) requires that num, channels, height, width each match. + STRICT = 0; + // PERMISSIVE requires only the count (num*channels*height*width) to match. + PERMISSIVE = 1; + } + // The ratio that is multiplied on the global learning rate. If you want to + // set the learning ratio for one blob, you need to set it for all blobs. + repeated float blobs_lr = 7; + // The weight decay that is multiplied on the global weight decay. + repeated float weight_decay = 8; + + // The amount of weight to assign each top blob in the objective. + // Each layer assigns a default value, usually of either 0 or 1, + // to each top blob. 
+ repeated float loss_weight = 35; + + optional AccuracyParameter accuracy_param = 27; + optional ArgMaxParameter argmax_param = 23; + optional ConcatParameter concat_param = 9; + optional ContrastiveLossParameter contrastive_loss_param = 40; + optional ConvolutionParameter convolution_param = 10; + optional DataParameter data_param = 11; + optional DropoutParameter dropout_param = 12; + optional DummyDataParameter dummy_data_param = 26; + optional EltwiseParameter eltwise_param = 24; + optional ExpParameter exp_param = 41; + optional HDF5DataParameter hdf5_data_param = 13; + optional HDF5OutputParameter hdf5_output_param = 14; + optional HingeLossParameter hinge_loss_param = 29; + optional ImageDataParameter image_data_param = 15; + optional InfogainLossParameter infogain_loss_param = 16; + optional InnerProductParameter inner_product_param = 17; + optional LRNParameter lrn_param = 18; + optional MultiStageMeanfieldParameter multi_stage_meanfield_param = 44; + optional MemoryDataParameter memory_data_param = 22; + optional MVNParameter mvn_param = 34; + optional PoolingParameter pooling_param = 19; + optional PowerParameter power_param = 21; + optional ReLUParameter relu_param = 30; + optional SigmoidParameter sigmoid_param = 38; + optional SoftmaxParameter softmax_param = 39; + optional SliceParameter slice_param = 31; + optional TanHParameter tanh_param = 37; + optional ThresholdParameter threshold_param = 25; + optional WindowDataParameter window_data_param = 20; + + // Parameters for data pre-processing. + optional TransformationParameter transform_param = 36; + + // Parameters shared by loss layers. + optional LossParameter loss_param = 42; + + // Note: certain layers may have more than one computational engine + // for their implementation. These layers include an Engine type and + // engine parameter for selecting the implementation. + // The default for the engine is set by the ENGINE switch at compile-time. 
+ + // DEPRECATED: The layer parameters specified as a V0LayerParameter. + // This should never be used by any code except to upgrade to the new + // LayerParameter specification. + optional V0LayerParameter layer = 1; +} + +// Message that stores parameters used to apply transformation +// to the data layer's data +message TransformationParameter { + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 1 [default = 1]; + // Specify if we want to randomly mirror data. + optional bool mirror = 2 [default = false]; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 3 [default = 0]; + // mean_file and mean_value cannot be specified at the same time + optional string mean_file = 4; + // if specified can be repeated once (would subtract it from all the channels) + // or can be repeated the same number of times as channels + // (would subtract them from the corresponding channel) + repeated float mean_value = 5; +} + +// Message that stores parameters shared by loss layers +message LossParameter { + // If specified, ignore instances with the given label. + optional int32 ignore_label = 1; + // If true, normalize each batch across all instances (including spatial + // dimensions, but not ignored instances); else, divide by batch size only. + optional bool normalize = 2 [default = true]; +} + +// Message that stores parameters used by AccuracyLayer +message AccuracyParameter { + // When computing accuracy, count as correct by comparing the true label to + // the top k scoring classes. By default, only compare to the top scoring + // class (i.e. argmax). 
+ optional uint32 top_k = 1 [default = 1]; +} + +// Message that stores parameters used by ArgMaxLayer +message ArgMaxParameter { + // If true produce pairs (argmax, maxval) + optional bool out_max_val = 1 [default = false]; + optional uint32 top_k = 2 [default = 1]; +} + +// Message that stores parameters used by ConcatLayer +message ConcatParameter { + // Concat Layer needs to specify the dimension along the concat will happen, + // the other dimensions must be the same for all the bottom blobs + // By default it will concatenate blobs along channels dimension + optional uint32 concat_dim = 1 [default = 1]; +} + +// Message that stores parameters used by ContrastiveLossLayer +message ContrastiveLossParameter { + //margin for dissimilar pair + optional float margin = 1 [default = 1.0]; +} + +// Message that stores parameters used by ConvolutionLayer +message ConvolutionParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. 
+ optional uint32 pad = 3 [default = 0]; // The padding size (equal in Y, X) + optional uint32 pad_h = 9 [default = 0]; // The padding height + optional uint32 pad_w = 10 [default = 0]; // The padding width + optional uint32 kernel_size = 4; // The kernel size (square) + optional uint32 kernel_h = 11; // The kernel height + optional uint32 kernel_w = 12; // The kernel width + optional uint32 group = 5 [default = 1]; // The group size for group conv + optional uint32 stride = 6 [default = 1]; // The stride (equal in Y, X) + optional uint32 stride_h = 13; // The stride height + optional uint32 stride_w = 14; // The stride width + optional FillerParameter weight_filler = 7; // The filler for the weight + optional FillerParameter bias_filler = 8; // The filler for the bias + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; +} + + +// Message that stores parameters used by MultiStageMeanfieldLayer +message MultiStageMeanfieldParameter { + enum Mode { + POTTS = 0; + } + optional Mode compatibility_mode = 1 [default = POTTS]; + optional float threshold = 2; + + required float theta_alpha = 3 [default = 10.]; + required float theta_beta = 4 [default = 10.]; + required float theta_gamma = 5 [default = 10.]; + + required uint32 num_iterations = 6 [default = 1]; + optional float spatial_filter_weight = 7 [default = 1]; + optional float bilateral_filter_weight = 8 [default = 1]; + + optional float forced_spatial_filter_weight = 9; + optional float forced_bilateral_filter_weight = 10; +} + + + + + +// Message that stores parameters used by DataLayer +message DataParameter { + enum DB { + LEVELDB = 0; + LMDB = 1; + } + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. 
The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 7 [default = 0]; + optional DB backend = 8 [default = LEVELDB]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; +} + +// Message that stores parameters used by DropoutLayer +message DropoutParameter { + optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio +} + +// Message that stores parameters used by DummyDataLayer. +// DummyDataLayer fills any number of arbitrarily shaped blobs with random +// (or constant) data generated by "Fillers" (see "message FillerParameter"). +message DummyDataParameter { + // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N + // num, N channels, N height, and N width fields, and must specify 0, 1 or N + // data_fillers. + // + // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used. + // If 1 data_filler is specified, it is applied to all top blobs. If N are + // specified, the ith is applied to the ith top blob. 
+ repeated FillerParameter data_filler = 1; + repeated uint32 num = 2; + repeated uint32 channels = 3; + repeated uint32 height = 4; + repeated uint32 width = 5; +} + +// Message that stores parameters used by EltwiseLayer +message EltwiseParameter { + enum EltwiseOp { + PROD = 0; + SUM = 1; + MAX = 2; + } + optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation + repeated float coeff = 2; // blob-wise coefficient for SUM operation + + // Whether to use an asymptotically slower (for >2 inputs) but stabler method + // of computing the gradient for the PROD operation. (No effect for SUM op.) + optional bool stable_prod_grad = 3 [default = true]; +} + +// Message that stores parameters used by ExpLayer +message ExpParameter { + // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = exp(shift + scale * x). + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +// Message that stores parameters used by HDF5DataLayer +message HDF5DataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 2; +} + +// Message that stores parameters used by HDF5OutputLayer +message HDF5OutputParameter { + optional string file_name = 1; +} + +message HingeLossParameter { + enum Norm { + L1 = 1; + L2 = 2; + } + // Specify the Norm to use L1 or L2 + optional Norm norm = 1 [default = L1]; +} + +// Message that stores parameters used by ImageDataLayer +message ImageDataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). 
Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 7 [default = 0]; + // Whether or not ImageLayer should shuffle the list of files at every epoch. + optional bool shuffle = 8 [default = false]; + // It will also resize images if new_height or new_width are not zero. + optional uint32 new_height = 9 [default = 0]; + optional uint32 new_width = 10 [default = 0]; + // Specify if the images are color or gray + optional bool is_color = 11 [default = true]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + optional string root_folder = 12 [default = ""]; +} + +// Message that stores parameters InfogainLossLayer +message InfogainLossParameter { + // Specify the infogain matrix source. 
+ optional string source = 1; +} + +// Message that stores parameters used by InnerProductLayer +message InnerProductParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 3; // The filler for the weight + optional FillerParameter bias_filler = 4; // The filler for the bias +} + +// Message that stores parameters used by LRNLayer +message LRNParameter { + optional uint32 local_size = 1 [default = 5]; + optional float alpha = 2 [default = 1.]; + optional float beta = 3 [default = 0.75]; + enum NormRegion { + ACROSS_CHANNELS = 0; + WITHIN_CHANNEL = 1; + } + optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; + optional float k = 5 [default = 1.]; +} + +// Message that stores parameters used by MemoryDataLayer +message MemoryDataParameter { + optional uint32 batch_size = 1; + optional uint32 channels = 2; + optional uint32 height = 3; + optional uint32 width = 4; +} + +// Message that stores parameters used by MVNLayer +message MVNParameter { + // This parameter can be set to false to normalize mean only + optional bool normalize_variance = 1 [default = true]; + + // This parameter can be set to true to perform DNN-like MVN + optional bool across_channels = 2 [default = false]; +} + +// Message that stores parameters used by PoolingLayer +message PoolingParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 1 [default = MAX]; // The pooling method + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. 
+ optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) + optional uint32 pad_h = 9 [default = 0]; // The padding height + optional uint32 pad_w = 10 [default = 0]; // The padding width + optional uint32 kernel_size = 2; // The kernel size (square) + optional uint32 kernel_h = 5; // The kernel height + optional uint32 kernel_w = 6; // The kernel width + optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) + optional uint32 stride_h = 7; // The stride height + optional uint32 stride_w = 8; // The stride width + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 11 [default = DEFAULT]; + // If global_pooling then it will pool over the size of the bottom by doing + // kernel_h = bottom->height and kernel_w = bottom->width + optional bool global_pooling = 12 [default = false]; +} + +// Message that stores parameters used by PowerLayer +message PowerParameter { + // PowerLayer computes outputs y = (shift + scale * x) ^ power. + optional float power = 1 [default = 1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +// Message that stores parameters used by ReLULayer +message ReLUParameter { + // Allow non-zero slope for negative inputs to speed up optimization + // Described in: + // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities + // improve neural network acoustic models. In ICML Workshop on Deep Learning + // for Audio, Speech, and Language Processing. 
+ optional float negative_slope = 1 [default = 0]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 2 [default = DEFAULT]; +} + +// Message that stores parameters used by SigmoidLayer +message SigmoidParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +// Message that stores parameters used by SliceLayer +message SliceParameter { + // SliceLayer needs to know which dimension to slice across. + // Currently, SliceLayer only supports slicing across num (dim 0) + // and channels (dim 1). + // By default, SliceLayer slices across channels. + optional uint32 slice_dim = 1 [default = 1]; + repeated uint32 slice_point = 2; +} + +// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer +message SoftmaxParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +// Message that stores parameters used by TanHLayer +message TanHParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +// Message that stores parameters used by ThresholdLayer +message ThresholdParameter { + optional float threshold = 1 [default = 0]; // Strictly positive values +} + +// Message that stores parameters used by WindowDataLayer +message WindowDataParameter { + // Specify the data source. + optional string source = 1; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // Specify the batch size. + optional uint32 batch_size = 4; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 5 [default = 0]; + // Specify if we want to randomly mirror data. 
+ optional bool mirror = 6 [default = false]; + // Foreground (object) overlap threshold + optional float fg_threshold = 7 [default = 0.5]; + // Background (non-object) overlap threshold + optional float bg_threshold = 8 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float fg_fraction = 9 [default = 0.25]; + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 context_pad = 10 [default = 0]; + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string crop_mode = 11 [default = "warp"]; + // cache_images: will load all images in memory for faster access + optional bool cache_images = 12 [default = false]; + // append root_folder to locate images + optional string root_folder = 13 [default = ""]; +} + +// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters +// in Caffe. We keep this message type around for legacy support. +message V0LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the string to specify the layer type + + // Parameters to specify layers with inner products. 
+ optional uint32 num_output = 3; // The number of outputs for the layer + optional bool biasterm = 4 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 5; // The filler for the weight + optional FillerParameter bias_filler = 6; // The filler for the bias + + optional uint32 pad = 7 [default = 0]; // The padding size + optional uint32 kernelsize = 8; // The kernel size + optional uint32 group = 9 [default = 1]; // The group size for group conv + optional uint32 stride = 10 [default = 1]; // The stride + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 11 [default = MAX]; // The pooling method + optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio + + optional uint32 local_size = 13 [default = 5]; // for local response norm + optional float alpha = 14 [default = 1.]; // for local response norm + optional float beta = 15 [default = 0.75]; // for local response norm + optional float k = 22 [default = 1.]; + + // For data layers, specify the data source + optional string source = 16; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 17 [default = 1]; + optional string meanfile = 18; + // For data layers, specify the batch size. + optional uint32 batchsize = 19; + // For data layers, specify if we would like to randomly crop an image. + optional uint32 cropsize = 20 [default = 0]; + // For data layers, specify if we want to randomly mirror data. + optional bool mirror = 21 [default = false]; + + // The blobs containing the numeric parameters of the layer + repeated BlobProto blobs = 50; + // The ratio that is multiplied on the global learning rate. If you want to + // set the learning ratio for one blob, you need to set it for all blobs. 
+ repeated float blobs_lr = 51; + // The weight decay that is multiplied on the global weight decay. + repeated float weight_decay = 52; + + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 53 [default = 0]; + + // Fields related to detection (det_*) + // foreground (object) overlap threshold + optional float det_fg_threshold = 54 [default = 0.5]; + // background (non-object) overlap threshold + optional float det_bg_threshold = 55 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float det_fg_fraction = 56 [default = 0.25]; + + // optional bool OBSOLETE_can_clobber = 57 [default = true]; + + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 det_context_pad = 58 [default = 0]; + + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string det_crop_mode = 59 [default = "warp"]; + + // For ReshapeLayer, one needs to specify the new dimensions. + optional int32 new_num = 60 [default = 0]; + optional int32 new_channels = 61 [default = 0]; + optional int32 new_height = 62 [default = 0]; + optional int32 new_width = 63 [default = 0]; + + // Whether or not ImageLayer should shuffle the list of files at every epoch. + // It will also resize images if new_height or new_width are not zero. + optional bool shuffle_images = 64 [default = false]; + + // For ConcatLayer, one needs to specify the dimension for concatenation, and + // the other dimensions must be the same for all the bottom blobs. + // By default it will concatenate blobs along the channels dimension. 
+ optional uint32 concat_dim = 65 [default = 1]; + + optional HDF5OutputParameter hdf5_output_param = 1001; +} diff --git a/caffe-crfrnn/src/caffe/proto/caffe_pretty_print.proto b/caffe-crfrnn/src/caffe/proto/caffe_pretty_print.proto new file mode 100644 index 00000000..6f0a5f6b --- /dev/null +++ b/caffe-crfrnn/src/caffe/proto/caffe_pretty_print.proto @@ -0,0 +1,18 @@ +syntax = "proto2"; + +package caffe; + +import "caffe.proto"; + +// A near-duplicate of NetParameter with fields re-numbered to beautify +// automatic prototext dumps. The main practical purpose is to print inputs +// before layers, because having inputs at the end looks weird. +// NetParameterPrettyPrint should never be used in code except for conversion +// FROM NetParameter and subsequent dumping to proto text file. +message NetParameterPrettyPrint { + optional string name = 1; + optional bool force_backward = 2 [default = false]; + repeated string input = 3; + repeated int32 input_dim = 4; + repeated LayerParameter layers = 5; +} diff --git a/caffe-crfrnn/src/caffe/solver.cpp b/caffe-crfrnn/src/caffe/solver.cpp new file mode 100644 index 00000000..349f3e47 --- /dev/null +++ b/caffe-crfrnn/src/caffe/solver.cpp @@ -0,0 +1,833 @@ +#include <cstdio> + +#include <algorithm> +#include <string> +#include <vector> + +#include "caffe/net.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/util/upgrade_proto.hpp" + +namespace caffe { + +template <typename Dtype> +Solver<Dtype>::Solver(const SolverParameter& param) + : net_() { + Init(param); +} + +template <typename Dtype> +Solver<Dtype>::Solver(const string& param_file) + : net_() { + SolverParameter param; + ReadProtoFromTextFileOrDie(param_file, &param); + Init(param); +} + +template <typename Dtype> +void Solver<Dtype>::Init(const SolverParameter& param) { + LOG(INFO) << "Initializing solver from parameters: " << std::endl + << param.DebugString(); + param_ = param; + CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative."; + if 
(param_.random_seed() >= 0) { + Caffe::set_random_seed(param_.random_seed()); + } + // Scaffolding code + InitTrainNet(); + InitTestNets(); + LOG(INFO) << "Solver scaffolding done."; + iter_ = 0; + current_step_ = 0; +} + +template <typename Dtype> +void Solver<Dtype>::InitTrainNet() { + const int num_train_nets = param_.has_net() + param_.has_net_param() + + param_.has_train_net() + param_.has_train_net_param(); + const string& field_names = "net, net_param, train_net, train_net_param"; + CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net " + << "using one of these fields: " << field_names; + CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than " + << "one of these fields specifying a train_net: " << field_names; + NetParameter net_param; + if (param_.has_train_net_param()) { + LOG(INFO) << "Creating training net specified in train_net_param."; + net_param.CopyFrom(param_.train_net_param()); + } else if (param_.has_train_net()) { + LOG(INFO) << "Creating training net from train_net file: " + << param_.train_net(); + ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param); + } + if (param_.has_net_param()) { + LOG(INFO) << "Creating training net specified in net_param."; + net_param.CopyFrom(param_.net_param()); + } + if (param_.has_net()) { + LOG(INFO) << "Creating training net from net file: " << param_.net(); + ReadNetParamsFromTextFileOrDie(param_.net(), &net_param); + } + // Set the correct NetState. We start with the solver defaults (lowest + // precedence); then, merge in any NetState specified by the net_param itself; + // finally, merge in any NetState specified by the train_state (highest + // precedence). 
+ NetState net_state; + net_state.set_phase(TRAIN); + net_state.MergeFrom(net_param.state()); + net_state.MergeFrom(param_.train_state()); + net_param.mutable_state()->CopyFrom(net_state); + net_.reset(new Net<Dtype>(net_param)); +} + +template <typename Dtype> +void Solver<Dtype>::InitTestNets() { + const bool has_net_param = param_.has_net_param(); + const bool has_net_file = param_.has_net(); + const int num_generic_nets = has_net_param + has_net_file; + CHECK_LE(num_generic_nets, 1) + << "Both net_param and net_file may not be specified."; + const int num_test_net_params = param_.test_net_param_size(); + const int num_test_net_files = param_.test_net_size(); + const int num_test_nets = num_test_net_params + num_test_net_files; + if (num_generic_nets) { + CHECK_GE(param_.test_iter_size(), num_test_nets) + << "test_iter must be specified for each test network."; + } else { + CHECK_EQ(param_.test_iter_size(), num_test_nets) + << "test_iter must be specified for each test network."; + } + // If we have a generic net (specified by net or net_param, rather than + // test_net or test_net_param), we may have an unlimited number of actual + // test networks -- the actual number is given by the number of remaining + // test_iters after any test nets specified by test_net_param and/or test_net + // are evaluated. 
+ const int num_generic_net_instances = param_.test_iter_size() - num_test_nets; + const int num_test_net_instances = num_test_nets + num_generic_net_instances; + if (param_.test_state_size()) { + CHECK_EQ(param_.test_state_size(), num_test_net_instances) + << "test_state must be unspecified or specified once per test net."; + } + if (num_test_net_instances) { + CHECK_GT(param_.test_interval(), 0); + } + int test_net_id = 0; + vector<string> sources(num_test_net_instances); + vector<NetParameter> net_params(num_test_net_instances); + for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) { + sources[test_net_id] = "test_net_param"; + net_params[test_net_id].CopyFrom(param_.test_net_param(i)); + } + for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) { + sources[test_net_id] = "test_net file: " + param_.test_net(i); + ReadNetParamsFromTextFileOrDie(param_.test_net(i), + &net_params[test_net_id]); + } + const int remaining_test_nets = param_.test_iter_size() - test_net_id; + if (has_net_param) { + for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) { + sources[test_net_id] = "net_param"; + net_params[test_net_id].CopyFrom(param_.net_param()); + } + } + if (has_net_file) { + for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) { + sources[test_net_id] = "net file: " + param_.net(); + ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]); + } + } + test_nets_.resize(num_test_net_instances); + for (int i = 0; i < num_test_net_instances; ++i) { + // Set the correct NetState. We start with the solver defaults (lowest + // precedence); then, merge in any NetState specified by the net_param + // itself; finally, merge in any NetState specified by the test_state + // (highest precedence). 
+ NetState net_state; + net_state.set_phase(TEST); + net_state.MergeFrom(net_params[i].state()); + if (param_.test_state_size()) { + net_state.MergeFrom(param_.test_state(i)); + } + net_params[i].mutable_state()->CopyFrom(net_state); + LOG(INFO) + << "Creating test net (#" << i << ") specified by " << sources[i]; + test_nets_[i].reset(new Net<Dtype>(net_params[i])); + } +} + +template <typename Dtype> +void Solver<Dtype>::Step(int iters) { + vector<Blob<Dtype>*> bottom_vec; + const int start_iter = iter_; + const int stop_iter = iter_ + iters; + int average_loss = this->param_.average_loss(); + vector<Dtype> losses; + Dtype smoothed_loss = 0; + + for (; iter_ < stop_iter; ++iter_) { + // zero-init the params + for (int i = 0; i < net_->params().size(); ++i) { + shared_ptr<Blob<Dtype> > blob = net_->params()[i]; + switch (Caffe::mode()) { + case Caffe::CPU: + caffe_set(blob->count(), static_cast<Dtype>(0), + blob->mutable_cpu_diff()); + break; + case Caffe::GPU: +#ifndef CPU_ONLY + caffe_gpu_set(blob->count(), static_cast<Dtype>(0), + blob->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + } + + if (param_.test_interval() && iter_ % param_.test_interval() == 0 + && (iter_ > 0 || param_.test_initialization())) { + TestAll(); + } + + const bool display = param_.display() && iter_ % param_.display() == 0; + net_->set_debug_info(display && param_.debug_info()); + // accumulate the loss and gradient + Dtype loss = 0; + for (int i = 0; i < param_.iter_size(); ++i) { + loss += net_->ForwardBackward(bottom_vec); + } + loss /= param_.iter_size(); + // average the loss across iterations for smoothed reporting + if (losses.size() < average_loss) { + losses.push_back(loss); + int size = losses.size(); + smoothed_loss = (smoothed_loss * (size - 1) + loss) / size; + } else { + int idx = (iter_ - start_iter) % average_loss; + smoothed_loss += (loss - losses[idx]) / average_loss; + losses[idx] = loss; + } + if (display) { + LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss; + const vector<Blob<Dtype>*>& result = net_->output_blobs(); + int 
score_index = 0; + for (int j = 0; j < result.size(); ++j) { + const Dtype* result_vec = result[j]->cpu_data(); + const string& output_name = + net_->blob_names()[net_->output_blob_indices()[j]]; + const Dtype loss_weight = + net_->blob_loss_weights()[net_->output_blob_indices()[j]]; + for (int k = 0; k < result[j]->count(); ++k) { + ostringstream loss_msg_stream; + if (loss_weight) { + loss_msg_stream << " (* " << loss_weight + << " = " << loss_weight * result_vec[k] << " loss)"; + } + LOG(INFO) << " Train net output #" + << score_index++ << ": " << output_name << " = " + << result_vec[k] << loss_msg_stream.str(); + } + } + } + ComputeUpdateValue(); + net_->Update(); + + // Save a snapshot if needed. + if (param_.snapshot() && (iter_ + 1) % param_.snapshot() == 0) { + Snapshot(); + } + } +} + +template <typename Dtype> +void Solver<Dtype>::Solve(const char* resume_file) { + Caffe::set_phase(Caffe::TRAIN); + LOG(INFO) << "Solving " << net_->name(); + LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); + + if (resume_file) { + LOG(INFO) << "Restoring previous solver status from " << resume_file; + Restore(resume_file); + } + + // For a network that is trained by the solver, no bottom or top vecs + // should be given, and we will just provide dummy vecs. + Step(param_.max_iter() - iter_); + // If we haven't already, save a snapshot after optimization, unless + // overridden by setting snapshot_after_train := false + if (param_.snapshot_after_train() + && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) { + Snapshot(); + } + // After the optimization is done, run an additional train and test pass to + // display the train and test loss/outputs if appropriate (based on the + // display and test_interval settings, respectively). Unlike in the rest of + // training, for the train net we only run a forward pass as we've already + // updated the parameters "max_iter" times -- this final pass is only done to + // display the loss, which is computed in the forward pass. 
+ if (param_.display() && iter_ % param_.display() == 0) { + Dtype loss; + net_->ForwardPrefilled(&loss); + LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss; + } + if (param_.test_interval() && iter_ % param_.test_interval() == 0) { + TestAll(); + } + LOG(INFO) << "Optimization Done."; +} + + +template <typename Dtype> +void Solver<Dtype>::TestAll() { + for (int test_net_id = 0; test_net_id < test_nets_.size(); ++test_net_id) { + Test(test_net_id); + } +} + +template <typename Dtype> +void Solver<Dtype>::Test(const int test_net_id) { + LOG(INFO) << "Iteration " << iter_ + << ", Testing net (#" << test_net_id << ")"; + // We need to set phase to test before running. + Caffe::set_phase(Caffe::TEST); + CHECK_NOTNULL(test_nets_[test_net_id].get())-> + ShareTrainedLayersWith(net_.get()); + vector<Dtype> test_score; + vector<int> test_score_output_id; + vector<Blob<Dtype>*> bottom_vec; + const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id]; + Dtype loss = 0; + for (int i = 0; i < param_.test_iter(test_net_id); ++i) { + Dtype iter_loss; + const vector<Blob<Dtype>*>& result = + test_net->Forward(bottom_vec, &iter_loss); + if (param_.test_compute_loss()) { + loss += iter_loss; + } + if (i == 0) { + for (int j = 0; j < result.size(); ++j) { + const Dtype* result_vec = result[j]->cpu_data(); + for (int k = 0; k < result[j]->count(); ++k) { + test_score.push_back(result_vec[k]); + test_score_output_id.push_back(j); + } + } + } else { + int idx = 0; + for (int j = 0; j < result.size(); ++j) { + const Dtype* result_vec = result[j]->cpu_data(); + for (int k = 0; k < result[j]->count(); ++k) { + test_score[idx++] += result_vec[k]; + } + } + } + } + if (param_.test_compute_loss()) { + loss /= param_.test_iter(test_net_id); + LOG(INFO) << "Test loss: " << loss; + } + for (int i = 0; i < test_score.size(); ++i) { + const int output_blob_index = + test_net->output_blob_indices()[test_score_output_id[i]]; + const string& output_name = test_net->blob_names()[output_blob_index]; + const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index]; + 
ostringstream loss_msg_stream; + const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id); + if (loss_weight) { + loss_msg_stream << " (* " << loss_weight + << " = " << loss_weight * mean_score << " loss)"; + } + LOG(INFO) << " Test net output #" << i << ": " << output_name << " = " + << mean_score << loss_msg_stream.str(); + } + Caffe::set_phase(Caffe::TRAIN); +} + + +template <typename Dtype> +void Solver<Dtype>::Snapshot() { + NetParameter net_param; + // For intermediate results, we will also dump the gradient values. + net_->ToProto(&net_param, param_.snapshot_diff()); + string filename(param_.snapshot_prefix()); + string model_filename, snapshot_filename; + const int kBufferSize = 20; + char iter_str_buffer[kBufferSize]; + // Add one to iter_ to get the number of iterations that have completed. + snprintf(iter_str_buffer, kBufferSize, "_iter_%d", iter_ + 1); + filename += iter_str_buffer; + model_filename = filename + ".caffemodel"; + LOG(INFO) << "Snapshotting to " << model_filename; + WriteProtoToBinaryFile(net_param, model_filename.c_str()); + SolverState state; + SnapshotSolverState(&state); + state.set_iter(iter_ + 1); + state.set_learned_net(model_filename); + state.set_current_step(current_step_); + snapshot_filename = filename + ".solverstate"; + LOG(INFO) << "Snapshotting solver state to " << snapshot_filename; + WriteProtoToBinaryFile(state, snapshot_filename.c_str()); +} + +template <typename Dtype> +void Solver<Dtype>::Restore(const char* state_file) { + SolverState state; + NetParameter net_param; + ReadProtoFromBinaryFile(state_file, &state); + if (state.has_learned_net()) { + ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param); + net_->CopyTrainedLayersFrom(net_param); + } + iter_ = state.iter(); + current_step_ = state.current_step(); + RestoreSolverState(state); +} + + +// Return the current learning rate. The currently implemented learning rate +// policies are as follows: +// - fixed: always return base_lr. 
+// - step: return base_lr * gamma ^ (floor(iter / stepsize)) +// - exp: return base_lr * gamma ^ iter +// - inv: return base_lr * (1 + gamma * iter) ^ (- power) +// - multistep: similar to step but it allows non uniform steps defined by +// stepvalue +// - poly: the effective learning rate follows a polynomial decay, to be +// zero by the max_iter. return base_lr * (1 - iter/max_iter) ^ (power) +// - sigmoid: the effective learning rate follows a sigmoid decay +// return base_lr * (1/(1 + exp(-gamma * (iter - stepsize)))) +// +// where base_lr, max_iter, gamma, stepsize, stepvalue and power are defined +// in the solver parameter protocol buffer, and iter is the current iteration. +template <typename Dtype> +Dtype SGDSolver<Dtype>::GetLearningRate() { + Dtype rate; + const string& lr_policy = this->param_.lr_policy(); + if (lr_policy == "fixed") { + rate = this->param_.base_lr(); + } else if (lr_policy == "step") { + this->current_step_ = this->iter_ / this->param_.stepsize(); + rate = this->param_.base_lr() * + pow(this->param_.gamma(), this->current_step_); + } else if (lr_policy == "exp") { + rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_); + } else if (lr_policy == "inv") { + rate = this->param_.base_lr() * + pow(Dtype(1) + this->param_.gamma() * this->iter_, + - this->param_.power()); + } else if (lr_policy == "multistep") { + if (this->current_step_ < this->param_.stepvalue_size() && + this->iter_ >= this->param_.stepvalue(this->current_step_)) { + this->current_step_++; + LOG(INFO) << "MultiStep Status: Iteration " << + this->iter_ << ", step = " << this->current_step_; + } + rate = this->param_.base_lr() * + pow(this->param_.gamma(), this->current_step_); + } else if (lr_policy == "poly") { + rate = this->param_.base_lr() * pow(Dtype(1.) - + (Dtype(this->iter_) / Dtype(this->param_.max_iter())), + this->param_.power()); + } else if (lr_policy == "sigmoid") { + rate = this->param_.base_lr() * (Dtype(1.) / + (Dtype(1.) 
+ exp(-this->param_.gamma() * (Dtype(this->iter_) - + Dtype(this->param_.stepsize()))))); + } else { + LOG(FATAL) << "Unknown learning rate policy: " << lr_policy; + } + return rate; +} + +template <typename Dtype> +void SGDSolver<Dtype>::PreSolve() { + // Initialize the history + vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params(); + history_.clear(); + update_.clear(); + temp_.clear(); + for (int i = 0; i < net_params.size(); ++i) { + const Blob<Dtype>* net_param = net_params[i].get(); + history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>( + net_param->num(), net_param->channels(), net_param->height(), + net_param->width()))); + update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>( + net_param->num(), net_param->channels(), net_param->height(), + net_param->width()))); + temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>( + net_param->num(), net_param->channels(), net_param->height(), + net_param->width()))); + } +} + + +template <typename Dtype> +void SGDSolver<Dtype>::ComputeUpdateValue() { + vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params(); + vector<float>& net_params_lr = this->net_->params_lr(); + vector<float>& net_params_weight_decay = this->net_->params_weight_decay(); + // get the learning rate + Dtype rate = GetLearningRate(); + if (this->param_.display() && this->iter_ % this->param_.display() == 0) { + LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; + } + Dtype momentum = this->param_.momentum(); + Dtype weight_decay = this->param_.weight_decay(); + string regularization_type = this->param_.regularization_type(); + switch (Caffe::mode()) { + case Caffe::CPU: + for (int param_id = 0; param_id < net_params.size(); ++param_id) { + // Compute the value to history, and then copy them to the blob's diff. 
+ Dtype local_rate = rate * net_params_lr[param_id] + / this->param_.iter_size(); + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + + if (local_decay) { + if (regularization_type == "L2") { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } else if (regularization_type == "L1") { + caffe_cpu_sign(net_params[param_id]->count(), + net_params[param_id]->cpu_data(), + temp_[param_id]->mutable_cpu_data()); + caffe_axpy(net_params[param_id]->count(), + local_decay, + temp_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } + } + + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->cpu_diff(), momentum, + history_[param_id]->mutable_cpu_data()); + // copy + caffe_copy(net_params[param_id]->count(), + history_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } + break; + case Caffe::GPU: +#ifndef CPU_ONLY + for (int param_id = 0; param_id < net_params.size(); ++param_id) { + // Compute the value to history, and then copy them to the blob's diff. 
+ Dtype local_rate = rate * net_params_lr[param_id] + / this->param_.iter_size(); + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + + if (local_decay) { + if (regularization_type == "L2") { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } else if (regularization_type == "L1") { + caffe_gpu_sign(net_params[param_id]->count(), + net_params[param_id]->gpu_data(), + temp_[param_id]->mutable_gpu_data()); + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + temp_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } + } + + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->gpu_diff(), momentum, + history_[param_id]->mutable_gpu_data()); + // copy + caffe_copy(net_params[param_id]->count(), + history_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } +#else + NO_GPU; +#endif + break; + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +template <typename Dtype> +void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state) { + state->clear_history(); + for (int i = 0; i < history_.size(); ++i) { + // Add history + BlobProto* history_blob = state->add_history(); + history_[i]->ToProto(history_blob); + } +} + +template <typename Dtype> +void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) { + CHECK_EQ(state.history_size(), history_.size()) + << "Incorrect length of history blobs."; + LOG(INFO) << "SGDSolver: restoring history"; + for (int i = 0; i < history_.size(); ++i) { + history_[i]->FromProto(state.history(i)); + } +} + +template <typename Dtype> +void NesterovSolver<Dtype>::ComputeUpdateValue() { + vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params(); + vector<float>& net_params_lr = this->net_->params_lr(); + vector<float>& net_params_weight_decay = this->net_->params_weight_decay(); + // get the learning rate 
+ Dtype rate = this->GetLearningRate(); + if (this->param_.display() && this->iter_ % this->param_.display() == 0) { + LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; + } + Dtype momentum = this->param_.momentum(); + Dtype weight_decay = this->param_.weight_decay(); + string regularization_type = this->param_.regularization_type(); + switch (Caffe::mode()) { + case Caffe::CPU: + for (int param_id = 0; param_id < net_params.size(); ++param_id) { + // save history momentum for stepping back + caffe_copy(net_params[param_id]->count(), + this->history_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + Dtype local_rate = rate * net_params_lr[param_id]; + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + + if (local_decay) { + if (regularization_type == "L2") { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } else if (regularization_type == "L1") { + caffe_cpu_sign(net_params[param_id]->count(), + net_params[param_id]->cpu_data(), + this->temp_[param_id]->mutable_cpu_data()); + caffe_axpy(net_params[param_id]->count(), + local_decay, + this->temp_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } + } + + // update history + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->cpu_diff(), momentum, + this->history_[param_id]->mutable_cpu_data()); + + // compute update: step back then over step + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, + this->history_[param_id]->cpu_data(), -momentum, + this->update_[param_id]->mutable_cpu_data()); + + // copy + caffe_copy(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } + break; + case Caffe::GPU: +#ifndef CPU_ONLY + for (int 
param_id = 0; param_id < net_params.size(); ++param_id) { + // save history momentum for stepping back + caffe_copy(net_params[param_id]->count(), + this->history_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + Dtype local_rate = rate * net_params_lr[param_id]; + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + + if (local_decay) { + if (regularization_type == "L2") { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } else if (regularization_type == "L1") { + caffe_gpu_sign(net_params[param_id]->count(), + net_params[param_id]->gpu_data(), + this->temp_[param_id]->mutable_gpu_data()); + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + this->temp_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } + } + + // update history + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->gpu_diff(), momentum, + this->history_[param_id]->mutable_gpu_data()); + + // compute update: step back then over step + caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, + this->history_[param_id]->gpu_data(), -momentum, + this->update_[param_id]->mutable_gpu_data()); + + // copy + caffe_copy(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } +#else + NO_GPU; +#endif + break; + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +template +void AdaGradSolver::ComputeUpdateValue() { + vector > >& net_params = this->net_->params(); + vector& net_params_lr = this->net_->params_lr(); + vector& net_params_weight_decay = this->net_->params_weight_decay(); + // get the learning rate + Dtype rate = this->GetLearningRate(); + Dtype delta = this->param_.delta(); + if 
(this->param_.display() && this->iter_ % this->param_.display() == 0) { + LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; + } + Dtype weight_decay = this->param_.weight_decay(); + string regularization_type = this->param_.regularization_type(); + switch (Caffe::mode()) { + case Caffe::CPU: + for (int param_id = 0; param_id < net_params.size(); ++param_id) { + Dtype local_rate = rate * net_params_lr[param_id]; + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + + if (local_decay) { + if (regularization_type == "L2") { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } else if (regularization_type == "L1") { + caffe_cpu_sign(net_params[param_id]->count(), + net_params[param_id]->cpu_data(), + this->temp_[param_id]->mutable_cpu_data()); + caffe_axpy(net_params[param_id]->count(), + local_decay, + this->temp_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } + } + + // compute square of gradient in update + caffe_powx(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), Dtype(2), + this->update_[param_id]->mutable_cpu_data()); + + // update history + caffe_add(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), + this->history_[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + + // prepare update + caffe_powx(net_params[param_id]->count(), + this->history_[param_id]->cpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_cpu_data()); + + caffe_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_cpu_data()); + + caffe_div(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), + this->update_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + // scale and copy + 
caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->cpu_data(), Dtype(0), + net_params[param_id]->mutable_cpu_diff()); + } + break; + case Caffe::GPU: +#ifndef CPU_ONLY + for (int param_id = 0; param_id < net_params.size(); ++param_id) { + Dtype local_rate = rate * net_params_lr[param_id]; + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + + if (local_decay) { + if (regularization_type == "L2") { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } else if (regularization_type == "L1") { + caffe_gpu_sign(net_params[param_id]->count(), + net_params[param_id]->gpu_data(), + this->temp_[param_id]->mutable_gpu_data()); + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + this->temp_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } + } + + // compute square of gradient in update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); + + // update history + caffe_gpu_add(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + this->history_[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + + // prepare update + caffe_gpu_powx(net_params[param_id]->count(), + this->history_[param_id]->gpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_div(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), + this->update_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + // scale and copy + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->gpu_data(), 
Dtype(0), + net_params[param_id]->mutable_gpu_diff()); + } +#else + NO_GPU; +#endif + break; + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +INSTANTIATE_CLASS(Solver); +INSTANTIATE_CLASS(SGDSolver); +INSTANTIATE_CLASS(NesterovSolver); +INSTANTIATE_CLASS(AdaGradSolver); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/syncedmem.cpp b/caffe-crfrnn/src/caffe/syncedmem.cpp new file mode 100644 index 00000000..7617ccfb --- /dev/null +++ b/caffe-crfrnn/src/caffe/syncedmem.cpp @@ -0,0 +1,113 @@ +#include + +#include "caffe/common.hpp" +#include "caffe/syncedmem.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +SyncedMemory::~SyncedMemory() { + if (cpu_ptr_ && own_cpu_data_) { + CaffeFreeHost(cpu_ptr_); + } + +#ifndef CPU_ONLY + if (gpu_ptr_) { + CUDA_CHECK(cudaFree(gpu_ptr_)); + } +#endif // CPU_ONLY +} + +inline void SyncedMemory::to_cpu() { + switch (head_) { + case UNINITIALIZED: + CaffeMallocHost(&cpu_ptr_, size_); + caffe_memset(size_, 0, cpu_ptr_); + head_ = HEAD_AT_CPU; + own_cpu_data_ = true; + break; + case HEAD_AT_GPU: +#ifndef CPU_ONLY + if (cpu_ptr_ == NULL) { + CaffeMallocHost(&cpu_ptr_, size_); + own_cpu_data_ = true; + } + caffe_gpu_memcpy(size_, gpu_ptr_, cpu_ptr_); + head_ = SYNCED; +#else + NO_GPU; +#endif + break; + case HEAD_AT_CPU: + case SYNCED: + break; + } +} + +inline void SyncedMemory::to_gpu() { +#ifndef CPU_ONLY + switch (head_) { + case UNINITIALIZED: + CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + caffe_gpu_memset(size_, 0, gpu_ptr_); + head_ = HEAD_AT_GPU; + break; + case HEAD_AT_CPU: + if (gpu_ptr_ == NULL) { + CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + } + caffe_gpu_memcpy(size_, cpu_ptr_, gpu_ptr_); + head_ = SYNCED; + break; + case HEAD_AT_GPU: + case SYNCED: + break; + } +#else + NO_GPU; +#endif +} + +const void* SyncedMemory::cpu_data() { + to_cpu(); + return (const void*)cpu_ptr_; +} + +void SyncedMemory::set_cpu_data(void* data) { + CHECK(data); + if (own_cpu_data_) { + 
CaffeFreeHost(cpu_ptr_); + } + cpu_ptr_ = data; + head_ = HEAD_AT_CPU; + own_cpu_data_ = false; +} + +const void* SyncedMemory::gpu_data() { +#ifndef CPU_ONLY + to_gpu(); + return (const void*)gpu_ptr_; +#else + NO_GPU; +#endif +} + +void* SyncedMemory::mutable_cpu_data() { + to_cpu(); + head_ = HEAD_AT_CPU; + return cpu_ptr_; +} + +void* SyncedMemory::mutable_gpu_data() { +#ifndef CPU_ONLY + to_gpu(); + head_ = HEAD_AT_GPU; + return gpu_ptr_; +#else + NO_GPU; +#endif +} + + +} // namespace caffe + diff --git a/caffe-crfrnn/src/caffe/test/CMakeLists.txt b/caffe-crfrnn/src/caffe/test/CMakeLists.txt new file mode 100644 index 00000000..35a803f2 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/CMakeLists.txt @@ -0,0 +1,36 @@ +# This option builds only the selected test files and excludes all others +# Usage example: +# cmake -DBUILD_only_tests="common,net,blob,im2col_kernel" +set(BUILD_only_tests "" CACHE STRING "Blank or comma-separated list of test files to build without 'test_' prefix and extension") +caffe_leave_only_selected_tests(test_srcs ${BUILD_only_tests}) +caffe_leave_only_selected_tests(test_cuda ${BUILD_only_tests}) + +# For the 'make runtest' target we don't need to embed test data paths in +# source files, because the test target is executed in the source directory. +# That's why the lines below are commented out. 
TODO: remove them + +# definition needed to include CMake generated files +#add_definitions(-DCMAKE_BUILD) + +# generates test_data/sample_data_list.txt.gen.cmake +#caffe_configure_testdatafile(test_data/sample_data_list.txt) + +set(the_target test.testbin) +set(test_args --gtest_shuffle) + +if(HAVE_CUDA) + caffe_cuda_compile(test_cuda_objs ${test_cuda}) + list(APPEND test_srcs ${test_cuda_objs} ${test_cuda}) +else() + list(APPEND test_args --gtest_filter="-*GPU*") +endif() + +# ---[ Adding test target +add_executable(${the_target} EXCLUDE_FROM_ALL ${test_srcs}) +target_link_libraries(${the_target} gtest ${Caffe_LINK}) +caffe_default_properties(${the_target}) +caffe_set_runtime_directory(${the_target} "${PROJECT_BINARY_DIR}/test") + +# ---[ Adding runtest +add_custom_target(runtest COMMAND ${the_target} ${test_args} + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}) diff --git a/caffe-crfrnn/src/caffe/test/cmake_test_defines.hpp.in b/caffe-crfrnn/src/caffe/test/cmake_test_defines.hpp.in new file mode 100644 index 00000000..870eaf5c --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/cmake_test_defines.hpp.in @@ -0,0 +1,4 @@ +#define CUDA_TEST_DEVICE @CUDA_TEST_DEVICE@ +#define CMAKE_SOURCE_DIR "@CMAKE_SOURCE_DIR@/src/" +#define EXAMPLES_SOURCE_DIR "@CMAKE_SOURCE_DIR@/examples/" +#define CMAKE_EXT ".gen.cmake" diff --git a/caffe-crfrnn/src/caffe/test/test_accuracy_layer.cpp b/caffe-crfrnn/src/caffe/test/test_accuracy_layer.cpp new file mode 100644 index 00000000..fa59fab1 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_accuracy_layer.cpp @@ -0,0 +1,140 @@ +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/util/rng.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class AccuracyLayerTest : public ::testing::Test { + protected: + AccuracyLayerTest() + : blob_bottom_data_(new Blob(100, 
10, 1, 1)), + blob_bottom_label_(new Blob(100, 1, 1, 1)), + blob_top_(new Blob()), + top_k_(3) { + // fill the probability values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_data_); + + const unsigned int prefetch_rng_seed = caffe_rng_rand(); + shared_ptr rng(new Caffe::RNG(prefetch_rng_seed)); + caffe::rng_t* prefetch_rng = + static_cast(rng->generator()); + Dtype* label_data = blob_bottom_label_->mutable_cpu_data(); + for (int i = 0; i < 100; ++i) { + label_data[i] = (*prefetch_rng)() % 10; + } + + blob_bottom_vec_.push_back(blob_bottom_data_); + blob_bottom_vec_.push_back(blob_bottom_label_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~AccuracyLayerTest() { + delete blob_bottom_data_; + delete blob_bottom_label_; + delete blob_top_; + } + Blob* const blob_bottom_data_; + Blob* const blob_bottom_label_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; + int top_k_; +}; + +TYPED_TEST_CASE(AccuracyLayerTest, TestDtypes); + +TYPED_TEST(AccuracyLayerTest, TestSetup) { + LayerParameter layer_param; + AccuracyLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 1); + EXPECT_EQ(this->blob_top_->channels(), 1); + EXPECT_EQ(this->blob_top_->height(), 1); + EXPECT_EQ(this->blob_top_->width(), 1); +} + +TYPED_TEST(AccuracyLayerTest, TestSetupTopK) { + LayerParameter layer_param; + AccuracyParameter* accuracy_param = + layer_param.mutable_accuracy_param(); + accuracy_param->set_top_k(5); + AccuracyLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 1); + EXPECT_EQ(this->blob_top_->channels(), 1); + EXPECT_EQ(this->blob_top_->height(), 1); + EXPECT_EQ(this->blob_top_->width(), 1); +} + +TYPED_TEST(AccuracyLayerTest, TestForwardCPU) { + LayerParameter layer_param; + Caffe::set_mode(Caffe::CPU); + AccuracyLayer layer(layer_param); + 
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + TypeParam max_value; + int max_id; + int num_correct_labels = 0; + for (int i = 0; i < 100; ++i) { + max_value = -FLT_MAX; + max_id = 0; + for (int j = 0; j < 10; ++j) { + if (this->blob_bottom_data_->data_at(i, j, 0, 0) > max_value) { + max_value = this->blob_bottom_data_->data_at(i, j, 0, 0); + max_id = j; + } + } + if (max_id == this->blob_bottom_label_->data_at(i, 0, 0, 0)) { + ++num_correct_labels; + } + } + EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0), + num_correct_labels / 100.0, 1e-4); +} + +TYPED_TEST(AccuracyLayerTest, TestForwardCPUTopK) { + LayerParameter layer_param; + AccuracyParameter* accuracy_param = layer_param.mutable_accuracy_param(); + accuracy_param->set_top_k(this->top_k_); + AccuracyLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + TypeParam current_value; + int current_rank; + int num_correct_labels = 0; + for (int i = 0; i < 100; ++i) { + for (int j = 0; j < 10; ++j) { + current_value = this->blob_bottom_data_->data_at(i, j, 0, 0); + current_rank = 0; + for (int k = 0; k < 10; ++k) { + if (this->blob_bottom_data_->data_at(i, k, 0, 0) > current_value) { + ++current_rank; + } + } + if (current_rank < this->top_k_ && + j == this->blob_bottom_label_->data_at(i, 0, 0, 0)) { + ++num_correct_labels; + } + } + } + + EXPECT_NEAR(this->blob_top_->data_at(0, 0, 0, 0), + num_correct_labels / 100.0, 1e-4); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_argmax_layer.cpp b/caffe-crfrnn/src/caffe/test/test_argmax_layer.cpp new file mode 100644 index 00000000..3487d42f --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_argmax_layer.cpp @@ -0,0 +1,169 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include 
"caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class ArgMaxLayerTest : public ::testing::Test { + protected: + ArgMaxLayerTest() + : blob_bottom_(new Blob(10, 20, 1, 1)), + blob_top_(new Blob()), + top_k_(5) { + Caffe::set_mode(Caffe::CPU); + Caffe::set_random_seed(1701); + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~ArgMaxLayerTest() { delete blob_bottom_; delete blob_top_; } + Blob* const blob_bottom_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; + size_t top_k_; +}; + +TYPED_TEST_CASE(ArgMaxLayerTest, TestDtypes); + +TYPED_TEST(ArgMaxLayerTest, TestSetup) { + LayerParameter layer_param; + ArgMaxLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num()); + EXPECT_EQ(this->blob_top_->channels(), 1); +} + +TYPED_TEST(ArgMaxLayerTest, TestSetupMaxVal) { + LayerParameter layer_param; + ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param(); + argmax_param->set_out_max_val(true); + ArgMaxLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num()); + EXPECT_EQ(this->blob_top_->channels(), 2); +} + +TYPED_TEST(ArgMaxLayerTest, TestCPU) { + LayerParameter layer_param; + ArgMaxLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + const TypeParam* bottom_data = this->blob_bottom_->cpu_data(); + const TypeParam* top_data = this->blob_top_->cpu_data(); + int max_ind; + TypeParam max_val; + int num = this->blob_bottom_->num(); + int dim = this->blob_bottom_->count() / num; + for (int i = 0; i < num; ++i) { 
+ EXPECT_GE(top_data[i], 0); + EXPECT_LE(top_data[i], dim); + max_ind = top_data[i]; + max_val = bottom_data[i * dim + max_ind]; + for (int j = 0; j < dim; ++j) { + EXPECT_LE(bottom_data[i * dim + j], max_val); + } + } +} + +TYPED_TEST(ArgMaxLayerTest, TestCPUMaxVal) { + LayerParameter layer_param; + ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param(); + argmax_param->set_out_max_val(true); + ArgMaxLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + const TypeParam* bottom_data = this->blob_bottom_->cpu_data(); + const TypeParam* top_data = this->blob_top_->cpu_data(); + int max_ind; + TypeParam max_val; + int num = this->blob_bottom_->num(); + int dim = this->blob_bottom_->count() / num; + for (int i = 0; i < num; ++i) { + EXPECT_GE(top_data[i], 0); + EXPECT_LE(top_data[i], dim); + max_ind = top_data[i * 2]; + max_val = top_data[i * 2 + 1]; + EXPECT_EQ(bottom_data[i * dim + max_ind], max_val); + for (int j = 0; j < dim; ++j) { + EXPECT_LE(bottom_data[i * dim + j], max_val); + } + } +} + +TYPED_TEST(ArgMaxLayerTest, TestCPUTopK) { + LayerParameter layer_param; + ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param(); + argmax_param->set_top_k(this->top_k_); + ArgMaxLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + int max_ind; + TypeParam max_val; + int num = this->blob_bottom_->num(); + int dim = this->blob_bottom_->count() / num; + for (int i = 0; i < num; ++i) { + EXPECT_GE(this->blob_top_->data_at(i, 0, 0, 0), 0); + EXPECT_LE(this->blob_top_->data_at(i, 0, 0, 0), dim); + for (int j = 0; j < this->top_k_; ++j) { + max_ind = this->blob_top_->data_at(i, 0, j, 0); + max_val = this->blob_bottom_->data_at(i, max_ind, 0, 0); + int count = 0; + for (int k = 0; k < dim; ++k) { + if 
(this->blob_bottom_->data_at(i, k, 0, 0) > max_val) { + ++count; + } + } + EXPECT_EQ(j, count); + } + } +} + +TYPED_TEST(ArgMaxLayerTest, TestCPUMaxValTopK) { + LayerParameter layer_param; + ArgMaxParameter* argmax_param = layer_param.mutable_argmax_param(); + argmax_param->set_out_max_val(true); + argmax_param->set_top_k(this->top_k_); + ArgMaxLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + int max_ind; + TypeParam max_val; + int num = this->blob_bottom_->num(); + int dim = this->blob_bottom_->count() / num; + for (int i = 0; i < num; ++i) { + EXPECT_GE(this->blob_top_->data_at(i, 0, 0, 0), 0); + EXPECT_LE(this->blob_top_->data_at(i, 0, 0, 0), dim); + for (int j = 0; j < this->top_k_; ++j) { + max_ind = this->blob_top_->data_at(i, 0, j, 0); + max_val = this->blob_top_->data_at(i, 1, j, 0); + EXPECT_EQ(this->blob_bottom_->data_at(i, max_ind, 0, 0), max_val); + int count = 0; + for (int k = 0; k < dim; ++k) { + if (this->blob_bottom_->data_at(i, k, 0, 0) > max_val) { + ++count; + } + } + EXPECT_EQ(j, count); + } + } +} + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_benchmark.cpp b/caffe-crfrnn/src/caffe/test/test_benchmark.cpp new file mode 100644 index 00000000..43aaa639 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_benchmark.cpp @@ -0,0 +1,90 @@ +#include // for usleep + +#include "gtest/gtest.h" + +#include "caffe/common.hpp" +#include "caffe/util/benchmark.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +const float kMillisecondsThreshold = 30; + +template +class BenchmarkTest : public MultiDeviceTest {}; + +TYPED_TEST_CASE(BenchmarkTest, TestDtypesAndDevices); + +TYPED_TEST(BenchmarkTest, TestTimerConstructor) { + Timer timer; + EXPECT_TRUE(timer.initted()); + EXPECT_FALSE(timer.running()); + EXPECT_FALSE(timer.has_run_at_least_once()); +} + +TYPED_TEST(BenchmarkTest, 
TestTimerStart) { + Timer timer; + timer.Start(); + EXPECT_TRUE(timer.initted()); + EXPECT_TRUE(timer.running()); + EXPECT_TRUE(timer.has_run_at_least_once()); + timer.Start(); + EXPECT_TRUE(timer.initted()); + EXPECT_TRUE(timer.running()); + EXPECT_TRUE(timer.has_run_at_least_once()); + timer.Stop(); + timer.Start(); + EXPECT_TRUE(timer.initted()); + EXPECT_TRUE(timer.running()); + EXPECT_TRUE(timer.has_run_at_least_once()); +} + +TYPED_TEST(BenchmarkTest, TestTimerStop) { + Timer timer; + timer.Stop(); + EXPECT_TRUE(timer.initted()); + EXPECT_FALSE(timer.running()); + EXPECT_FALSE(timer.has_run_at_least_once()); + timer.Start(); + timer.Stop(); + EXPECT_TRUE(timer.initted()); + EXPECT_FALSE(timer.running()); + EXPECT_TRUE(timer.has_run_at_least_once()); + timer.Stop(); + EXPECT_TRUE(timer.initted()); + EXPECT_FALSE(timer.running()); + EXPECT_TRUE(timer.has_run_at_least_once()); +} + +TYPED_TEST(BenchmarkTest, TestTimerMilliSeconds) { + Timer timer; + EXPECT_EQ(timer.MilliSeconds(), 0); + EXPECT_TRUE(timer.initted()); + EXPECT_FALSE(timer.running()); + EXPECT_FALSE(timer.has_run_at_least_once()); + timer.Start(); + usleep(300 * 1000); + EXPECT_GE(timer.MilliSeconds(), 300 - kMillisecondsThreshold); + EXPECT_LE(timer.MilliSeconds(), 300 + kMillisecondsThreshold); + EXPECT_TRUE(timer.initted()); + EXPECT_FALSE(timer.running()); + EXPECT_TRUE(timer.has_run_at_least_once()); +} + +TYPED_TEST(BenchmarkTest, TestTimerSeconds) { + Timer timer; + EXPECT_EQ(timer.Seconds(), 0); + EXPECT_TRUE(timer.initted()); + EXPECT_FALSE(timer.running()); + EXPECT_FALSE(timer.has_run_at_least_once()); + timer.Start(); + usleep(300 * 1000); + EXPECT_GE(timer.Seconds(), 0.3 - kMillisecondsThreshold / 1000.); + EXPECT_LE(timer.Seconds(), 0.3 + kMillisecondsThreshold / 1000.); + EXPECT_TRUE(timer.initted()); + EXPECT_FALSE(timer.running()); + EXPECT_TRUE(timer.has_run_at_least_once()); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_blob.cpp 
b/caffe-crfrnn/src/caffe/test/test_blob.cpp new file mode 100644 index 00000000..adf7a4d3 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_blob.cpp @@ -0,0 +1,57 @@ +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class BlobSimpleTest : public ::testing::Test { + protected: + BlobSimpleTest() + : blob_(new Blob()), + blob_preshaped_(new Blob(2, 3, 4, 5)) {} + virtual ~BlobSimpleTest() { delete blob_; delete blob_preshaped_; } + Blob* const blob_; + Blob* const blob_preshaped_; +}; + +TYPED_TEST_CASE(BlobSimpleTest, TestDtypes); + +TYPED_TEST(BlobSimpleTest, TestInitialization) { + EXPECT_TRUE(this->blob_); + EXPECT_TRUE(this->blob_preshaped_); + EXPECT_EQ(this->blob_preshaped_->num(), 2); + EXPECT_EQ(this->blob_preshaped_->channels(), 3); + EXPECT_EQ(this->blob_preshaped_->height(), 4); + EXPECT_EQ(this->blob_preshaped_->width(), 5); + EXPECT_EQ(this->blob_preshaped_->count(), 120); + EXPECT_EQ(this->blob_->num(), 0); + EXPECT_EQ(this->blob_->channels(), 0); + EXPECT_EQ(this->blob_->height(), 0); + EXPECT_EQ(this->blob_->width(), 0); + EXPECT_EQ(this->blob_->count(), 0); +} + +TYPED_TEST(BlobSimpleTest, TestPointersCPUGPU) { + EXPECT_TRUE(this->blob_preshaped_->gpu_data()); + EXPECT_TRUE(this->blob_preshaped_->cpu_data()); + EXPECT_TRUE(this->blob_preshaped_->mutable_gpu_data()); + EXPECT_TRUE(this->blob_preshaped_->mutable_cpu_data()); +} + +TYPED_TEST(BlobSimpleTest, TestReshape) { + this->blob_->Reshape(2, 3, 4, 5); + EXPECT_EQ(this->blob_->num(), 2); + EXPECT_EQ(this->blob_->channels(), 3); + EXPECT_EQ(this->blob_->height(), 4); + EXPECT_EQ(this->blob_->width(), 5); + EXPECT_EQ(this->blob_->count(), 120); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_caffe_main.cpp b/caffe-crfrnn/src/caffe/test/test_caffe_main.cpp new file mode 100644 index 00000000..c8caf5ac --- /dev/null +++ 
b/caffe-crfrnn/src/caffe/test/test_caffe_main.cpp @@ -0,0 +1,40 @@ +// The main caffe test code. Your test cpp code should include this hpp +// to allow a main function to be compiled into the binary. + +#include "caffe/caffe.hpp" +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { +#ifndef CPU_ONLY + cudaDeviceProp CAFFE_TEST_CUDA_PROP; +#endif +} + +#ifndef CPU_ONLY +using caffe::CAFFE_TEST_CUDA_PROP; +#endif + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + caffe::GlobalInit(&argc, &argv); +#ifndef CPU_ONLY + // Before starting testing, let's first print out some CUDA device info. + int device; + cudaGetDeviceCount(&device); + cout << "Cuda number of devices: " << device << endl; + if (argc > 1) { + // Use the given device + device = atoi(argv[1]); + cudaSetDevice(device); + cout << "Setting to use device " << device << endl; + } else if (CUDA_TEST_DEVICE >= 0) { + // Use the device assigned in build configuration; but with a lower priority + device = CUDA_TEST_DEVICE; + } + cudaGetDevice(&device); + cout << "Current device id: " << device << endl; + cudaGetDeviceProperties(&CAFFE_TEST_CUDA_PROP, device); +#endif + // invoke the test. + return RUN_ALL_TESTS(); +} diff --git a/caffe-crfrnn/src/caffe/test/test_common.cpp b/caffe-crfrnn/src/caffe/test/test_common.cpp new file mode 100644 index 00000000..0b3639c7 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_common.cpp @@ -0,0 +1,73 @@ +#include + +#include "gtest/gtest.h" + +#include "caffe/common.hpp" +#include "caffe/syncedmem.hpp" +#include "caffe/util/math_functions.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +class CommonTest : public ::testing::Test {}; + +#ifndef CPU_ONLY // GPU Caffe singleton test. 
+ +TEST_F(CommonTest, TestCublasHandlerGPU) { + int cuda_device_id; + CUDA_CHECK(cudaGetDevice(&cuda_device_id)); + EXPECT_TRUE(Caffe::cublas_handle()); +} + +#endif + +TEST_F(CommonTest, TestBrewMode) { + Caffe::set_mode(Caffe::CPU); + EXPECT_EQ(Caffe::mode(), Caffe::CPU); + Caffe::set_mode(Caffe::GPU); + EXPECT_EQ(Caffe::mode(), Caffe::GPU); +} + +TEST_F(CommonTest, TestPhase) { + Caffe::set_phase(Caffe::TRAIN); + EXPECT_EQ(Caffe::phase(), Caffe::TRAIN); + Caffe::set_phase(Caffe::TEST); + EXPECT_EQ(Caffe::phase(), Caffe::TEST); +} + +TEST_F(CommonTest, TestRandSeedCPU) { + SyncedMemory data_a(10 * sizeof(int)); + SyncedMemory data_b(10 * sizeof(int)); + Caffe::set_random_seed(1701); + caffe_rng_bernoulli(10, 0.5, static_cast(data_a.mutable_cpu_data())); + + Caffe::set_random_seed(1701); + caffe_rng_bernoulli(10, 0.5, static_cast(data_b.mutable_cpu_data())); + + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(static_cast(data_a.cpu_data())[i], + static_cast(data_b.cpu_data())[i]); + } +} + +#ifndef CPU_ONLY // GPU Caffe singleton test. 
+ +TEST_F(CommonTest, TestRandSeedGPU) { + SyncedMemory data_a(10 * sizeof(unsigned int)); + SyncedMemory data_b(10 * sizeof(unsigned int)); + Caffe::set_random_seed(1701); + CURAND_CHECK(curandGenerate(Caffe::curand_generator(), + static_cast(data_a.mutable_gpu_data()), 10)); + Caffe::set_random_seed(1701); + CURAND_CHECK(curandGenerate(Caffe::curand_generator(), + static_cast(data_b.mutable_gpu_data()), 10)); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(((const unsigned int*)(data_a.cpu_data()))[i], + ((const unsigned int*)(data_b.cpu_data()))[i]); + } +} + +#endif + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_concat_layer.cpp b/caffe-crfrnn/src/caffe/test/test_concat_layer.cpp new file mode 100644 index 00000000..f14f1d2f --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_concat_layer.cpp @@ -0,0 +1,122 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class ConcatLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + ConcatLayerTest() + : blob_bottom_0(new Blob(2, 3, 6, 5)), + blob_bottom_1(new Blob(2, 5, 6, 5)), + blob_bottom_2(new Blob(5, 3, 6, 5)), + blob_top_(new Blob()) {} + virtual void SetUp() { + // fill the values + shared_ptr > filler; + FillerParameter filler_param; + filler_param.set_value(1.); + filler.reset(new ConstantFiller(filler_param)); + filler->Fill(this->blob_bottom_0); + filler_param.set_value(2.); + filler.reset(new ConstantFiller(filler_param)); + filler->Fill(this->blob_bottom_1); + filler_param.set_value(3.); + filler.reset(new ConstantFiller(filler_param)); + filler->Fill(this->blob_bottom_2); + blob_bottom_vec_0.push_back(blob_bottom_0); + blob_bottom_vec_0.push_back(blob_bottom_1); + 
blob_bottom_vec_1.push_back(blob_bottom_0); + blob_bottom_vec_1.push_back(blob_bottom_2); + blob_top_vec_.push_back(blob_top_); + } + + virtual ~ConcatLayerTest() { + delete blob_bottom_0; delete blob_bottom_1; + delete blob_bottom_2; delete blob_top_; + } + + Blob* const blob_bottom_0; + Blob* const blob_bottom_1; + Blob* const blob_bottom_2; + Blob* const blob_top_; + vector*> blob_bottom_vec_0, blob_bottom_vec_1; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(ConcatLayerTest, TestDtypesAndDevices); + +TYPED_TEST(ConcatLayerTest, TestSetupNum) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.mutable_concat_param()->set_concat_dim(0); + ConcatLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_1, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), + this->blob_bottom_0->num() + this->blob_bottom_2->num()); + EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_0->channels()); + EXPECT_EQ(this->blob_top_->height(), this->blob_bottom_0->height()); + EXPECT_EQ(this->blob_top_->width(), this->blob_bottom_0->width()); +} + +TYPED_TEST(ConcatLayerTest, TestSetupChannels) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConcatLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_0, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_0->num()); + EXPECT_EQ(this->blob_top_->channels(), + this->blob_bottom_0->channels()+this->blob_bottom_1->channels()); + EXPECT_EQ(this->blob_top_->height(), this->blob_bottom_0->height()); + EXPECT_EQ(this->blob_top_->width(), this->blob_bottom_0->width()); +} + + +TYPED_TEST(ConcatLayerTest, TestNum) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConcatLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_0, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_0, this->blob_top_vec_); + for (int n = 0; n < this->blob_top_->num(); ++n) { + for (int c = 0; c < 
this->blob_bottom_0->channels(); ++c) { + for (int h = 0; h < this->blob_top_->height(); ++h) { + for (int w = 0; w < this->blob_top_->width(); ++w) { + EXPECT_EQ(this->blob_top_->data_at(n, c, h, w), + this->blob_bottom_vec_0[0]->data_at(n, c, h, w)); + } + } + } + for (int c = 0; c < this->blob_bottom_1->channels(); ++c) { + for (int h = 0; h < this->blob_top_->height(); ++h) { + for (int w = 0; w < this->blob_top_->width(); ++w) { + EXPECT_EQ(this->blob_top_->data_at(n, c+3, h, w), + this->blob_bottom_vec_0[1]->data_at(n, c, h, w)); + } + } + } + } +} + +TYPED_TEST(ConcatLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConcatLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + checker.CheckGradient(&layer, this->blob_bottom_vec_0, + this->blob_top_vec_); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_contrastive_loss_layer.cpp b/caffe-crfrnn/src/caffe/test/test_contrastive_loss_layer.cpp new file mode 100644 index 00000000..d269fbc2 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_contrastive_loss_layer.cpp @@ -0,0 +1,102 @@ +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class ContrastiveLossLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + ContrastiveLossLayerTest() + : blob_bottom_data_i_(new Blob(128, 10, 1, 1)), + blob_bottom_data_j_(new Blob(128, 10, 1, 1)), + blob_bottom_y_(new Blob(128, 1, 1, 1)), + blob_top_loss_(new Blob()) { + // fill the values + FillerParameter filler_param; + filler_param.set_mean(0.0); + filler_param.set_std(0.3); // distances~=1.0 to test both sides of margin + GaussianFiller filler(filler_param); + 
filler.Fill(this->blob_bottom_data_i_); + blob_bottom_vec_.push_back(blob_bottom_data_i_); + filler.Fill(this->blob_bottom_data_j_); + blob_bottom_vec_.push_back(blob_bottom_data_j_); + for (int i = 0; i < blob_bottom_y_->count(); ++i) { + blob_bottom_y_->mutable_cpu_data()[i] = caffe_rng_rand() % 2; // 0 or 1 + } + blob_bottom_vec_.push_back(blob_bottom_y_); + blob_top_vec_.push_back(blob_top_loss_); + } + virtual ~ContrastiveLossLayerTest() { + delete blob_bottom_data_i_; + delete blob_bottom_data_j_; + delete blob_bottom_y_; + delete blob_top_loss_; + } + + Blob<Dtype>* const blob_bottom_data_i_; + Blob<Dtype>* const blob_bottom_data_j_; + Blob<Dtype>* const blob_bottom_y_; + Blob<Dtype>* const blob_top_loss_; + vector<Blob<Dtype>*> blob_bottom_vec_; + vector<Blob<Dtype>*> blob_top_vec_; +}; + +TYPED_TEST_CASE(ContrastiveLossLayerTest, TestDtypesAndDevices); + +TYPED_TEST(ContrastiveLossLayerTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ContrastiveLossLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // manually compute to compare + const Dtype margin = layer_param.contrastive_loss_param().margin(); + const int num = this->blob_bottom_data_i_->num(); + const int channels = this->blob_bottom_data_i_->channels(); + Dtype loss(0); + for (int i = 0; i < num; ++i) { + Dtype dist_sq(0); + for (int j = 0; j < channels; ++j) { + Dtype diff = this->blob_bottom_data_i_->cpu_data()[i*channels+j] - + this->blob_bottom_data_j_->cpu_data()[i*channels+j]; + dist_sq += diff*diff; + } + if (this->blob_bottom_y_->cpu_data()[i]) { // similar pairs + loss += dist_sq; + } else { + loss += std::max(margin-dist_sq, Dtype(0)); + } + } + loss /= static_cast<Dtype>(num) * Dtype(2); + EXPECT_NEAR(this->blob_top_loss_->cpu_data()[0], loss, 1e-6); +} + +TYPED_TEST(ContrastiveLossLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ContrastiveLossLayer<Dtype> 
layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + GradientChecker<Dtype> checker(1e-2, 1e-2, 1701); + // check the gradient for the first two bottom layers + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 1); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_convolution_layer.cpp b/caffe-crfrnn/src/caffe/test/test_convolution_layer.cpp new file mode 100644 index 00000000..c1fe3b58 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_convolution_layer.cpp @@ -0,0 +1,704 @@ +#include <cstring> +#include <vector> + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +// Reference convolution for checking results: +// accumulate through explicit loops over input, output, and filters. 
+template <typename Dtype> +void caffe_conv(const Blob<Dtype>* in, ConvolutionParameter* conv_param, + const vector<shared_ptr<Blob<Dtype> > >& weights, + Blob<Dtype>* out) { + // Kernel size, stride, and pad + int kernel_h, kernel_w; + if (conv_param->has_kernel_size()) { + kernel_h = kernel_w = conv_param->kernel_size(); + } else { + kernel_h = conv_param->kernel_h(); + kernel_w = conv_param->kernel_w(); + } + int pad_h, pad_w; + if (!conv_param->has_pad_h()) { + pad_h = pad_w = conv_param->pad(); + } else { + pad_h = conv_param->pad_h(); + pad_w = conv_param->pad_w(); + } + int stride_h, stride_w; + if (!conv_param->has_stride_h()) { + stride_h = stride_w = conv_param->stride(); + } else { + stride_h = conv_param->stride_h(); + stride_w = conv_param->stride_w(); + } + // Groups + int groups = conv_param->group(); + int o_g = out->channels() / groups; + int k_g = in->channels() / groups; + int o_head, k_head; + // Convolution + const Dtype* in_data = in->cpu_data(); + const Dtype* weight_data = weights[0]->cpu_data(); + Dtype* out_data = out->mutable_cpu_data(); + for (int n = 0; n < out->num(); n++) { + for (int g = 0; g < groups; g++) { + o_head = o_g * g; + k_head = k_g * g; + for (int o = 0; o < o_g; o++) { + for (int k = 0; k < k_g; k++) { + for (int y = 0; y < out->height(); y++) { + for (int x = 0; x < out->width(); x++) { + for (int p = 0; p < kernel_h; p++) { + for (int q = 0; q < kernel_w; q++) { + int in_y = y * stride_h - pad_h + p; + int in_x = x * stride_w - pad_w + q; + if (in_y >= 0 && in_y < in->height() + && in_x >= 0 && in_x < in->width()) { + out_data[out->offset(n, o + o_head, y, x)] += + in_data[in->offset(n, k + k_head, in_y, in_x)] + * weight_data[weights[0]->offset(o + o_head, k, p, q)]; + } + } + } + } + } + } + } + } + } + // Bias + if (conv_param->bias_term()) { + const Dtype* bias_data = weights[1]->cpu_data(); + for (int n = 0; n < out->num(); n++) { + for (int o = 0; o < out->channels(); o++) { + for (int y = 0; y < out->height(); y++) { + for (int x = 0; x < out->width(); x++) { + 
out_data[out->offset(n, o, y, x)] += bias_data[o]; + } + } + } + } + } +} + +template void caffe_conv(const Blob<float>* in, + ConvolutionParameter* conv_param, + const vector<shared_ptr<Blob<float> > >& weights, + Blob<float>* out); +template void caffe_conv(const Blob<double>* in, + ConvolutionParameter* conv_param, + const vector<shared_ptr<Blob<double> > >& weights, + Blob<double>* out); + +template <typename TypeParam> +class ConvolutionLayerTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + ConvolutionLayerTest() + : blob_bottom_(new Blob<Dtype>(2, 3, 6, 4)), + blob_bottom_2_(new Blob<Dtype>(2, 3, 6, 4)), + blob_top_(new Blob<Dtype>()), + blob_top_2_(new Blob<Dtype>()) {} + virtual void SetUp() { + // fill the values + FillerParameter filler_param; + filler_param.set_value(1.); + GaussianFiller<Dtype> filler(filler_param); + filler.Fill(this->blob_bottom_); + filler.Fill(this->blob_bottom_2_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + + virtual ~ConvolutionLayerTest() { + delete blob_bottom_; + delete blob_bottom_2_; + delete blob_top_; + delete blob_top_2_; + } + + virtual Blob<Dtype>* MakeReferenceTop(Blob<Dtype>* top) { + this->ref_blob_top_.reset(new Blob<Dtype>()); + this->ref_blob_top_->ReshapeLike(*top); + return this->ref_blob_top_.get(); + } + + Blob<Dtype>* const blob_bottom_; + Blob<Dtype>* const blob_bottom_2_; + Blob<Dtype>* const blob_top_; + Blob<Dtype>* const blob_top_2_; + shared_ptr<Blob<Dtype> > ref_blob_top_; + vector<Blob<Dtype>*> blob_bottom_vec_; + vector<Blob<Dtype>*> blob_top_vec_; +}; + +TYPED_TEST_CASE(ConvolutionLayerTest, TestDtypesAndDevices); + +TYPED_TEST(ConvolutionLayerTest, TestSetup) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + convolution_param->set_num_output(4); + this->blob_bottom_vec_.push_back(this->blob_bottom_2_); + this->blob_top_vec_.push_back(this->blob_top_2_); + shared_ptr<Layer<Dtype> > layer( + new ConvolutionLayer<Dtype>(layer_param)); + 
layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 2); + EXPECT_EQ(this->blob_top_->channels(), 4); + EXPECT_EQ(this->blob_top_->height(), 2); + EXPECT_EQ(this->blob_top_->width(), 1); + EXPECT_EQ(this->blob_top_2_->num(), 2); + EXPECT_EQ(this->blob_top_2_->channels(), 4); + EXPECT_EQ(this->blob_top_2_->height(), 2); + EXPECT_EQ(this->blob_top_2_->width(), 1); + // setting group should not change the shape + convolution_param->set_num_output(3); + convolution_param->set_group(3); + layer.reset(new ConvolutionLayer<Dtype>(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 2); + EXPECT_EQ(this->blob_top_->channels(), 3); + EXPECT_EQ(this->blob_top_->height(), 2); + EXPECT_EQ(this->blob_top_->width(), 1); + EXPECT_EQ(this->blob_top_2_->num(), 2); + EXPECT_EQ(this->blob_top_2_->channels(), 3); + EXPECT_EQ(this->blob_top_2_->height(), 2); + EXPECT_EQ(this->blob_top_2_->width(), 1); +} + +TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) { + typedef typename TypeParam::Dtype Dtype; + this->blob_bottom_vec_.push_back(this->blob_bottom_2_); + this->blob_top_vec_.push_back(this->blob_top_2_); + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + convolution_param->set_num_output(4); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("constant"); + convolution_param->mutable_bias_filler()->set_value(0.1); + shared_ptr<Layer<Dtype> > layer( + new ConvolutionLayer<Dtype>(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Check against reference convolution. 
+ const Dtype* top_data; + const Dtype* ref_top_data; + caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_)); + top_data = this->blob_top_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + } + caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_2_)); + top_data = this->blob_top_2_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + } +} + +TYPED_TEST(ConvolutionLayerTest, Test1x1Convolution) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(1); + convolution_param->set_stride(1); + convolution_param->set_num_output(4); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("constant"); + convolution_param->mutable_bias_filler()->set_value(0.1); + shared_ptr<Layer<Dtype> > layer( + new ConvolutionLayer<Dtype>(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Check against reference convolution. 
+ const Dtype* top_data; + const Dtype* ref_top_data; + caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_)); + top_data = this->blob_top_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + } +} + +TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolutionGroup) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + convolution_param->set_num_output(3); + convolution_param->set_group(3); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("constant"); + convolution_param->mutable_bias_filler()->set_value(0.1); + shared_ptr<Layer<Dtype> > layer( + new ConvolutionLayer<Dtype>(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Check against reference convolution. + const Dtype* top_data; + const Dtype* ref_top_data; + caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_)); + top_data = this->blob_top_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + } +} + +TYPED_TEST(ConvolutionLayerTest, TestSobelConvolution) { + // Test separable convolution by computing the Sobel operator + // as a single filter then comparing the result + // as the convolution of two rectangular filters. + typedef typename TypeParam::Dtype Dtype; + // Fill bottoms with identical Gaussian noise. 
+ shared_ptr<GaussianFiller<Dtype> > filler; + FillerParameter filler_param; + filler_param.set_value(1.); + filler.reset(new GaussianFiller<Dtype>(filler_param)); + filler->Fill(this->blob_bottom_); + this->blob_bottom_2_->CopyFrom(*this->blob_bottom_); + // Compute Sobel G_x operator as 3 x 3 convolution. + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + convolution_param->set_num_output(1); + convolution_param->set_bias_term(false); + shared_ptr<Layer<Dtype> > layer( + new ConvolutionLayer<Dtype>(layer_param)); + layer->blobs().resize(1); + layer->blobs()[0].reset(new Blob<Dtype>(1, 3, 3, 3)); + Dtype* weights = layer->blobs()[0]->mutable_cpu_data(); + for (int c = 0; c < 3; ++c) { + int i = c * 9; // 3 x 3 filter + weights[i + 0] = -1; + weights[i + 1] = 0; + weights[i + 2] = 1; + weights[i + 3] = -2; + weights[i + 4] = 0; + weights[i + 5] = 2; + weights[i + 6] = -1; + weights[i + 7] = 0; + weights[i + 8] = 1; + } + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Compute Sobel G_x operator as separable 3 x 1 and 1 x 3 convolutions. 
+ // (1) the [1 2 1] column filter + vector<Blob<Dtype>*> sep_blob_bottom_vec; + vector<Blob<Dtype>*> sep_blob_top_vec; + shared_ptr<Blob<Dtype> > blob_sep(new Blob<Dtype>()); + sep_blob_bottom_vec.push_back(this->blob_bottom_2_); + sep_blob_top_vec.push_back(this->blob_top_2_); + convolution_param->clear_kernel_size(); + convolution_param->clear_stride(); + convolution_param->set_kernel_h(3); + convolution_param->set_kernel_w(1); + convolution_param->set_stride_h(2); + convolution_param->set_stride_w(1); + convolution_param->set_num_output(1); + convolution_param->set_bias_term(false); + layer.reset(new ConvolutionLayer<Dtype>(layer_param)); + layer->blobs().resize(1); + layer->blobs()[0].reset(new Blob<Dtype>(1, 3, 3, 1)); + Dtype* weights_1 = layer->blobs()[0]->mutable_cpu_data(); + for (int c = 0; c < 3; ++c) { + int i = c * 3; // 3 x 1 filter + weights_1[i + 0] = 1; + weights_1[i + 1] = 2; + weights_1[i + 2] = 1; + } + layer->SetUp(sep_blob_bottom_vec, sep_blob_top_vec); + layer->Forward(sep_blob_bottom_vec, sep_blob_top_vec); + // (2) the [-1 0 1] row filter + blob_sep->CopyFrom(*this->blob_top_2_, false, true); + sep_blob_bottom_vec.clear(); + sep_blob_bottom_vec.push_back(blob_sep.get()); + convolution_param->set_kernel_h(1); + convolution_param->set_kernel_w(3); + convolution_param->set_stride_h(1); + convolution_param->set_stride_w(2); + convolution_param->set_num_output(1); + convolution_param->set_bias_term(false); + layer.reset(new ConvolutionLayer<Dtype>(layer_param)); + layer->blobs().resize(1); + layer->blobs()[0].reset(new Blob<Dtype>(1, 3, 1, 3)); + Dtype* weights_2 = layer->blobs()[0]->mutable_cpu_data(); + for (int c = 0; c < 3; ++c) { + int i = c * 3; // 1 x 3 filter + weights_2[i + 0] = -1; + weights_2[i + 1] = 0; + weights_2[i + 2] = 1; + } + layer->SetUp(sep_blob_bottom_vec, sep_blob_top_vec); + layer->Forward(sep_blob_bottom_vec, sep_blob_top_vec); + // Test equivalence of full and separable filters. 
+ const Dtype* top_data = this->blob_top_->cpu_data(); + const Dtype* sep_top_data = this->blob_top_2_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], sep_top_data[i], 1e-4); + } +} + +TYPED_TEST(ConvolutionLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + this->blob_bottom_vec_.push_back(this->blob_bottom_2_); + this->blob_top_vec_.push_back(this->blob_top_2_); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + convolution_param->set_num_output(2); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("gaussian"); + ConvolutionLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(ConvolutionLayerTest, Test1x1Gradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + this->blob_bottom_vec_.push_back(this->blob_bottom_2_); + this->blob_top_vec_.push_back(this->blob_top_2_); + convolution_param->set_kernel_size(1); + convolution_param->set_stride(1); + convolution_param->set_num_output(2); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("gaussian"); + ConvolutionLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(ConvolutionLayerTest, TestGradientGroup) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(3); + 
convolution_param->set_stride(2); + convolution_param->set_num_output(3); + convolution_param->set_group(3); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("gaussian"); + ConvolutionLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +#ifdef USE_CUDNN + +template <typename TypeParam> +class CuDNNConvolutionLayerTest : public ::testing::Test { + protected: + CuDNNConvolutionLayerTest() + : blob_bottom_(new Blob<TypeParam>(2, 3, 6, 4)), + blob_bottom_2_(new Blob<TypeParam>(2, 3, 6, 4)), + blob_top_(new Blob<TypeParam>()), + blob_top_2_(new Blob<TypeParam>()) {} + virtual void SetUp() { + // fill the values + FillerParameter filler_param; + filler_param.set_value(1.); + GaussianFiller<TypeParam> filler(filler_param); + filler.Fill(this->blob_bottom_); + filler.Fill(this->blob_bottom_2_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + + virtual ~CuDNNConvolutionLayerTest() { + delete blob_bottom_; + delete blob_bottom_2_; + delete blob_top_; + delete blob_top_2_; + } + + virtual Blob<TypeParam>* MakeReferenceTop(Blob<TypeParam>* top) { + this->ref_blob_top_.reset(new Blob<TypeParam>()); + this->ref_blob_top_->ReshapeLike(*top); + return this->ref_blob_top_.get(); + } + + Blob<TypeParam>* const blob_bottom_; + Blob<TypeParam>* const blob_bottom_2_; + Blob<TypeParam>* const blob_top_; + Blob<TypeParam>* const blob_top_2_; + shared_ptr<Blob<TypeParam> > ref_blob_top_; + vector<Blob<TypeParam>*> blob_bottom_vec_; + vector<Blob<TypeParam>*> blob_top_vec_; +}; + +TYPED_TEST_CASE(CuDNNConvolutionLayerTest, TestDtypes); + +TYPED_TEST(CuDNNConvolutionLayerTest, TestSetupCuDNN) { + Caffe::set_mode(Caffe::GPU); + this->blob_bottom_vec_.push_back(this->blob_bottom_2_); + this->blob_top_vec_.push_back(this->blob_top_2_); + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + convolution_param->set_num_output(4); + 
this->blob_bottom_vec_.push_back(this->blob_bottom_2_); + this->blob_top_vec_.push_back(this->blob_top_2_); + shared_ptr<Layer<TypeParam> > layer( + new CuDNNConvolutionLayer<TypeParam>(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 2); + EXPECT_EQ(this->blob_top_->channels(), 4); + EXPECT_EQ(this->blob_top_->height(), 2); + EXPECT_EQ(this->blob_top_->width(), 1); + EXPECT_EQ(this->blob_top_2_->num(), 2); + EXPECT_EQ(this->blob_top_2_->channels(), 4); + EXPECT_EQ(this->blob_top_2_->height(), 2); + EXPECT_EQ(this->blob_top_2_->width(), 1); + // setting group should not change the shape + convolution_param->set_num_output(3); + convolution_param->set_group(3); + layer.reset(new CuDNNConvolutionLayer<TypeParam>(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 2); + EXPECT_EQ(this->blob_top_->channels(), 3); + EXPECT_EQ(this->blob_top_->height(), 2); + EXPECT_EQ(this->blob_top_->width(), 1); + EXPECT_EQ(this->blob_top_2_->num(), 2); + EXPECT_EQ(this->blob_top_2_->channels(), 3); + EXPECT_EQ(this->blob_top_2_->height(), 2); + EXPECT_EQ(this->blob_top_2_->width(), 1); +} + +TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionCuDNN) { + Caffe::set_mode(Caffe::GPU); + this->blob_bottom_vec_.push_back(this->blob_bottom_2_); + this->blob_top_vec_.push_back(this->blob_top_2_); + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + convolution_param->set_num_output(4); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("constant"); + convolution_param->mutable_bias_filler()->set_value(0.1); + shared_ptr<Layer<TypeParam> > layer( + new CuDNNConvolutionLayer<TypeParam>(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); 
+ // Check against reference convolution. + const TypeParam* top_data; + const TypeParam* ref_top_data; + caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_)); + top_data = this->blob_top_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + } + caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_2_)); + top_data = this->blob_top_2_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + } +} + +TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionGroupCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + convolution_param->set_num_output(3); + convolution_param->set_group(3); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("constant"); + convolution_param->mutable_bias_filler()->set_value(0.1); + shared_ptr<Layer<TypeParam> > layer( + new CuDNNConvolutionLayer<TypeParam>(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Check against reference convolution. 
+ const TypeParam* top_data; + const TypeParam* ref_top_data; + caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(), + this->MakeReferenceTop(this->blob_top_)); + top_data = this->blob_top_->cpu_data(); + ref_top_data = this->ref_blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4); + } +} + +TYPED_TEST(CuDNNConvolutionLayerTest, TestSobelConvolutionCuDNN) { + // Test separable convolution by computing the Sobel operator + // as a single filter then comparing the result + // as the convolution of two rectangular filters. + Caffe::set_mode(Caffe::GPU); + // Fill bottoms with identical Gaussian noise. + shared_ptr<GaussianFiller<TypeParam> > filler; + FillerParameter filler_param; + filler_param.set_value(1.); + filler.reset(new GaussianFiller<TypeParam>(filler_param)); + filler->Fill(this->blob_bottom_); + this->blob_bottom_2_->CopyFrom(*this->blob_bottom_); + // Compute Sobel G_x operator as 3 x 3 convolution. + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + convolution_param->set_num_output(1); + convolution_param->set_bias_term(false); + shared_ptr<Layer<TypeParam> > layer( + new CuDNNConvolutionLayer<TypeParam>(layer_param)); + layer->blobs().resize(1); + layer->blobs()[0].reset(new Blob<TypeParam>(1, 3, 3, 3)); + TypeParam* weights = layer->blobs()[0]->mutable_cpu_data(); + for (int c = 0; c < 3; ++c) { + int i = c * 9; // 3 x 3 filter + weights[i + 0] = -1; + weights[i + 1] = 0; + weights[i + 2] = 1; + weights[i + 3] = -2; + weights[i + 4] = 0; + weights[i + 5] = 2; + weights[i + 6] = -1; + weights[i + 7] = 0; + weights[i + 8] = 1; + } + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Compute Sobel G_x operator as separable 3 x 1 and 1 x 3 convolutions. 
+ // (1) the [1 2 1] column filter + vector<Blob<TypeParam>*> sep_blob_bottom_vec; + vector<Blob<TypeParam>*> sep_blob_top_vec; + shared_ptr<Blob<TypeParam> > blob_sep(new Blob<TypeParam>()); + sep_blob_bottom_vec.push_back(this->blob_bottom_2_); + sep_blob_top_vec.push_back(this->blob_top_2_); + convolution_param->clear_kernel_size(); + convolution_param->clear_stride(); + convolution_param->set_kernel_h(3); + convolution_param->set_kernel_w(1); + convolution_param->set_stride_h(2); + convolution_param->set_stride_w(1); + convolution_param->set_num_output(1); + convolution_param->set_bias_term(false); + layer.reset(new CuDNNConvolutionLayer<TypeParam>(layer_param)); + layer->blobs().resize(1); + layer->blobs()[0].reset(new Blob<TypeParam>(1, 3, 3, 1)); + TypeParam* weights_1 = layer->blobs()[0]->mutable_cpu_data(); + for (int c = 0; c < 3; ++c) { + int i = c * 3; // 3 x 1 filter + weights_1[i + 0] = 1; + weights_1[i + 1] = 2; + weights_1[i + 2] = 1; + } + layer->SetUp(sep_blob_bottom_vec, sep_blob_top_vec); + layer->Forward(sep_blob_bottom_vec, sep_blob_top_vec); + // (2) the [-1 0 1] row filter + blob_sep->CopyFrom(*this->blob_top_2_, false, true); + sep_blob_bottom_vec.clear(); + sep_blob_bottom_vec.push_back(blob_sep.get()); + convolution_param->set_kernel_h(1); + convolution_param->set_kernel_w(3); + convolution_param->set_stride_h(1); + convolution_param->set_stride_w(2); + convolution_param->set_num_output(1); + convolution_param->set_bias_term(false); + layer.reset(new CuDNNConvolutionLayer<TypeParam>(layer_param)); + layer->blobs().resize(1); + layer->blobs()[0].reset(new Blob<TypeParam>(1, 3, 1, 3)); + TypeParam* weights_2 = layer->blobs()[0]->mutable_cpu_data(); + for (int c = 0; c < 3; ++c) { + int i = c * 3; // 1 x 3 filter + weights_2[i + 0] = -1; + weights_2[i + 1] = 0; + weights_2[i + 2] = 1; + } + layer->SetUp(sep_blob_bottom_vec, sep_blob_top_vec); + layer->Forward(sep_blob_bottom_vec, sep_blob_top_vec); + // Test equivalence of full and separable filters. 
+ const TypeParam* top_data = this->blob_top_->cpu_data(); + const TypeParam* sep_top_data = this->blob_top_2_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + EXPECT_NEAR(top_data[i], sep_top_data[i], 1e-4); + } +} + +TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + this->blob_bottom_vec_.push_back(this->blob_bottom_2_); + this->blob_top_vec_.push_back(this->blob_top_2_); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + convolution_param->set_num_output(2); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("gaussian"); + CuDNNConvolutionLayer<TypeParam> layer(layer_param); + GradientChecker<TypeParam> checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientGroupCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + convolution_param->set_num_output(3); + convolution_param->set_group(3); + convolution_param->mutable_weight_filler()->set_type("gaussian"); + convolution_param->mutable_bias_filler()->set_type("gaussian"); + CuDNNConvolutionLayer<TypeParam> layer(layer_param); + GradientChecker<TypeParam> checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +#endif + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_data/generate_sample_data.py b/caffe-crfrnn/src/caffe/test/test_data/generate_sample_data.py new file mode 100644 index 00000000..e5dbc340 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_data/generate_sample_data.py @@ -0,0 +1,51 @@ +""" +Generate data 
used in the HDF5DataLayer test. +""" +import os +import numpy as np +import h5py + +num_cols = 8 +num_rows = 10 +height = 6 +width = 5 +total_size = num_cols * num_rows * height * width + +data = np.arange(total_size) +data = data.reshape(num_rows, num_cols, height, width) +data = data.astype('float32') + +# We had a bug where data was copied into label, but the tests weren't +# catching it, so let's make label 1-indexed. +label = 1 + np.arange(num_rows)[:, np.newaxis] +label = label.astype('float32') + +# We add an extra label2 dataset to test HDF5 layer's ability +# to handle arbitrary number of output ("top") Blobs. +label2 = label + 1 + +print(data) +print(label) + +with h5py.File(os.path.dirname(__file__) + '/sample_data.h5', 'w') as f: + f['data'] = data + f['label'] = label + f['label2'] = label2 + +with h5py.File(os.path.dirname(__file__) + '/sample_data_2_gzip.h5', 'w') as f: + f.create_dataset( + 'data', data=data + total_size, + compression='gzip', compression_opts=1 + ) + f.create_dataset( + 'label', data=label, + compression='gzip', compression_opts=1 + ) + f.create_dataset( + 'label2', data=label2, + compression='gzip', compression_opts=1 + ) + +with open(os.path.dirname(__file__) + '/sample_data_list.txt', 'w') as f: + f.write(os.path.dirname(__file__) + '/sample_data.h5\n') + f.write(os.path.dirname(__file__) + '/sample_data_2_gzip.h5\n') diff --git a/caffe-crfrnn/src/caffe/test/test_data/sample_data.h5 b/caffe-crfrnn/src/caffe/test/test_data/sample_data.h5 new file mode 100644 index 00000000..236e66b0 Binary files /dev/null and b/caffe-crfrnn/src/caffe/test/test_data/sample_data.h5 differ diff --git a/caffe-crfrnn/src/caffe/test/test_data/sample_data_2_gzip.h5 b/caffe-crfrnn/src/caffe/test/test_data/sample_data_2_gzip.h5 new file mode 100644 index 00000000..a138e036 Binary files /dev/null and b/caffe-crfrnn/src/caffe/test/test_data/sample_data_2_gzip.h5 differ diff --git a/caffe-crfrnn/src/caffe/test/test_data/sample_data_list.txt 
b/caffe-crfrnn/src/caffe/test/test_data/sample_data_list.txt new file mode 100644 index 00000000..cdf343fc --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_data/sample_data_list.txt @@ -0,0 +1,2 @@ +src/caffe/test/test_data/sample_data.h5 +src/caffe/test/test_data/sample_data_2_gzip.h5 diff --git a/caffe-crfrnn/src/caffe/test/test_data/sample_data_list.txt.in b/caffe-crfrnn/src/caffe/test/test_data/sample_data_list.txt.in new file mode 100644 index 00000000..9860ef58 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_data/sample_data_list.txt.in @@ -0,0 +1,2 @@ +@CMAKE_SOURCE_DIR@/src/caffe/test/test_data/sample_data.h5 +@CMAKE_SOURCE_DIR@/src/caffe/test/test_data/sample_data_2_gzip.h5 \ No newline at end of file diff --git a/caffe-crfrnn/src/caffe/test/test_data_layer.cpp b/caffe-crfrnn/src/caffe/test/test_data_layer.cpp new file mode 100644 index 00000000..32f5d41e --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_data_layer.cpp @@ -0,0 +1,353 @@ +#include <string> +#include <vector> + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/dataset_factory.hpp" +#include "caffe/filler.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/io.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template <typename TypeParam> +class DataLayerTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + DataLayerTest() + : backend_(DataParameter_DB_LEVELDB), + blob_top_data_(new Blob<Dtype>()), + blob_top_label_(new Blob<Dtype>()), + seed_(1701) {} + virtual void SetUp() { + filename_.reset(new string()); + MakeTempDir(filename_.get()); + *filename_ += "/db"; + blob_top_vec_.push_back(blob_top_data_); + blob_top_vec_.push_back(blob_top_label_); + } + + // Fill the LevelDB with data: if unique_pixels, each pixel is unique but + // all images are the same; else each image is unique but all pixels within + // an image are the same. 
+ void Fill(const bool unique_pixels, DataParameter_DB backend) { + backend_ = backend; + LOG(INFO) << "Using temporary dataset " << *filename_; + shared_ptr > dataset = + DatasetFactory(backend_); + CHECK(dataset->open(*filename_, Dataset::New)); + for (int i = 0; i < 5; ++i) { + Datum datum; + datum.set_label(i); + datum.set_channels(2); + datum.set_height(3); + datum.set_width(4); + std::string* data = datum.mutable_data(); + for (int j = 0; j < 24; ++j) { + int datum = unique_pixels ? j : i; + data->push_back(static_cast(datum)); + } + stringstream ss; + ss << i; + CHECK(dataset->put(ss.str(), datum)); + } + CHECK(dataset->commit()); + dataset->close(); + } + + void TestRead() { + const Dtype scale = 3; + LayerParameter param; + DataParameter* data_param = param.mutable_data_param(); + data_param->set_batch_size(5); + data_param->set_source(filename_->c_str()); + data_param->set_backend(backend_); + + TransformationParameter* transform_param = + param.mutable_transform_param(); + transform_param->set_scale(scale); + + DataLayer layer(param); + layer.SetUp(blob_bottom_vec_, blob_top_vec_); + EXPECT_EQ(blob_top_data_->num(), 5); + EXPECT_EQ(blob_top_data_->channels(), 2); + EXPECT_EQ(blob_top_data_->height(), 3); + EXPECT_EQ(blob_top_data_->width(), 4); + EXPECT_EQ(blob_top_label_->num(), 5); + EXPECT_EQ(blob_top_label_->channels(), 1); + EXPECT_EQ(blob_top_label_->height(), 1); + EXPECT_EQ(blob_top_label_->width(), 1); + + for (int iter = 0; iter < 100; ++iter) { + layer.Forward(blob_bottom_vec_, blob_top_vec_); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(i, blob_top_label_->cpu_data()[i]); + } + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 24; ++j) { + EXPECT_EQ(scale * i, blob_top_data_->cpu_data()[i * 24 + j]) + << "debug: iter " << iter << " i " << i << " j " << j; + } + } + } + } + + void TestReadCrop() { + const Dtype scale = 3; + LayerParameter param; + Caffe::set_random_seed(1701); + + DataParameter* data_param = param.mutable_data_param(); + 
+ data_param->set_batch_size(5); + data_param->set_source(filename_->c_str()); + data_param->set_backend(backend_); + + TransformationParameter* transform_param = + param.mutable_transform_param(); + transform_param->set_scale(scale); + transform_param->set_crop_size(1); + + DataLayer layer(param); + layer.SetUp(blob_bottom_vec_, blob_top_vec_); + EXPECT_EQ(blob_top_data_->num(), 5); + EXPECT_EQ(blob_top_data_->channels(), 2); + EXPECT_EQ(blob_top_data_->height(), 1); + EXPECT_EQ(blob_top_data_->width(), 1); + EXPECT_EQ(blob_top_label_->num(), 5); + EXPECT_EQ(blob_top_label_->channels(), 1); + EXPECT_EQ(blob_top_label_->height(), 1); + EXPECT_EQ(blob_top_label_->width(), 1); + + for (int iter = 0; iter < 2; ++iter) { + layer.Forward(blob_bottom_vec_, blob_top_vec_); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(i, blob_top_label_->cpu_data()[i]); + } + int num_with_center_value = 0; + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 2; ++j) { + const Dtype center_value = scale * (j ? 17 : 5); + num_with_center_value += + (center_value == blob_top_data_->cpu_data()[i * 2 + j]); + // At TEST time, check that we always get center value. + if (Caffe::phase() == Caffe::TEST) { + EXPECT_EQ(center_value, this->blob_top_data_->cpu_data()[i * 2 + j]) + << "debug: iter " << iter << " i " << i << " j " << j; + } + } + } + // At TRAIN time, check that we did not get the center crop all 10 times. + // (In a correct implementation this check fails only with negligible + // probability; we call set_random_seed to make the outcome reproducible.) 
+ if (Caffe::phase() == Caffe::TRAIN) { + EXPECT_LT(num_with_center_value, 10); + } + } + } + + void TestReadCropTrainSequenceSeeded() { + LayerParameter param; + DataParameter* data_param = param.mutable_data_param(); + data_param->set_batch_size(5); + data_param->set_source(filename_->c_str()); + data_param->set_backend(backend_); + + TransformationParameter* transform_param = + param.mutable_transform_param(); + transform_param->set_crop_size(1); + transform_param->set_mirror(true); + + // Get crop sequence with Caffe seed 1701. + Caffe::set_random_seed(seed_); + vector > crop_sequence; + { + DataLayer layer1(param); + layer1.SetUp(blob_bottom_vec_, blob_top_vec_); + for (int iter = 0; iter < 2; ++iter) { + layer1.Forward(blob_bottom_vec_, blob_top_vec_); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(i, blob_top_label_->cpu_data()[i]); + } + vector iter_crop_sequence; + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 2; ++j) { + iter_crop_sequence.push_back( + blob_top_data_->cpu_data()[i * 2 + j]); + } + } + crop_sequence.push_back(iter_crop_sequence); + } + } // destroy 1st data layer and unlock the dataset + + // Get crop sequence after reseeding Caffe with 1701. + // Check that the sequence is the same as the original. 
+ Caffe::set_random_seed(seed_); + DataLayer layer2(param); + layer2.SetUp(blob_bottom_vec_, blob_top_vec_); + for (int iter = 0; iter < 2; ++iter) { + layer2.Forward(blob_bottom_vec_, blob_top_vec_); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(i, blob_top_label_->cpu_data()[i]); + } + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 2; ++j) { + EXPECT_EQ(crop_sequence[iter][i * 2 + j], + blob_top_data_->cpu_data()[i * 2 + j]) + << "debug: iter " << iter << " i " << i << " j " << j; + } + } + } + } + + void TestReadCropTrainSequenceUnseeded() { + LayerParameter param; + DataParameter* data_param = param.mutable_data_param(); + data_param->set_batch_size(5); + data_param->set_source(filename_->c_str()); + data_param->set_backend(backend_); + + TransformationParameter* transform_param = + param.mutable_transform_param(); + transform_param->set_crop_size(1); + transform_param->set_mirror(true); + + // Get crop sequence with Caffe seed 1701, srand seed 1701. + Caffe::set_random_seed(seed_); + srand(seed_); + vector > crop_sequence; + { + DataLayer layer1(param); + layer1.SetUp(blob_bottom_vec_, blob_top_vec_); + for (int iter = 0; iter < 2; ++iter) { + layer1.Forward(blob_bottom_vec_, blob_top_vec_); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(i, blob_top_label_->cpu_data()[i]); + } + vector iter_crop_sequence; + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 2; ++j) { + iter_crop_sequence.push_back( + blob_top_data_->cpu_data()[i * 2 + j]); + } + } + crop_sequence.push_back(iter_crop_sequence); + } + } // destroy 1st data layer and unlock the dataset + + // Get crop sequence continuing from previous Caffe RNG state; reseed + // srand with 1701. Check that the sequence differs from the original. 
+ srand(seed_); + DataLayer layer2(param); + layer2.SetUp(blob_bottom_vec_, blob_top_vec_); + for (int iter = 0; iter < 2; ++iter) { + layer2.Forward(blob_bottom_vec_, blob_top_vec_); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(i, blob_top_label_->cpu_data()[i]); + } + int num_sequence_matches = 0; + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 2; ++j) { + num_sequence_matches += (crop_sequence[iter][i * 2 + j] == + blob_top_data_->cpu_data()[i * 2 + j]); + } + } + EXPECT_LT(num_sequence_matches, 10); + } + } + + virtual ~DataLayerTest() { delete blob_top_data_; delete blob_top_label_; } + + DataParameter_DB backend_; + shared_ptr filename_; + Blob* const blob_top_data_; + Blob* const blob_top_label_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; + int seed_; +}; + +TYPED_TEST_CASE(DataLayerTest, TestDtypesAndDevices); + +TYPED_TEST(DataLayerTest, TestReadLevelDB) { + const bool unique_pixels = false; // all pixels the same; images different + this->Fill(unique_pixels, DataParameter_DB_LEVELDB); + this->TestRead(); +} + +TYPED_TEST(DataLayerTest, TestReadCropTrainLevelDB) { + Caffe::set_phase(Caffe::TRAIN); + const bool unique_pixels = true; // all images the same; pixels different + this->Fill(unique_pixels, DataParameter_DB_LEVELDB); + this->TestReadCrop(); +} + +// Test that the sequence of random crops is consistent when using +// Caffe::set_random_seed. +TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceSeededLevelDB) { + Caffe::set_phase(Caffe::TRAIN); + const bool unique_pixels = true; // all images the same; pixels different + this->Fill(unique_pixels, DataParameter_DB_LEVELDB); + this->TestReadCropTrainSequenceSeeded(); +} + +// Test that the sequence of random crops differs across iterations when +// Caffe::set_random_seed isn't called (and seeds from srand are ignored). 
+TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceUnseededLevelDB) { + Caffe::set_phase(Caffe::TRAIN); + const bool unique_pixels = true; // all images the same; pixels different + this->Fill(unique_pixels, DataParameter_DB_LEVELDB); + this->TestReadCropTrainSequenceUnseeded(); +} + +TYPED_TEST(DataLayerTest, TestReadCropTestLevelDB) { + Caffe::set_phase(Caffe::TEST); + const bool unique_pixels = true; // all images the same; pixels different + this->Fill(unique_pixels, DataParameter_DB_LEVELDB); + this->TestReadCrop(); +} + +TYPED_TEST(DataLayerTest, TestReadLMDB) { + const bool unique_pixels = false; // all pixels the same; images different + this->Fill(unique_pixels, DataParameter_DB_LMDB); + this->TestRead(); +} + +TYPED_TEST(DataLayerTest, TestReadCropTrainLMDB) { + Caffe::set_phase(Caffe::TRAIN); + const bool unique_pixels = true; // all images the same; pixels different + this->Fill(unique_pixels, DataParameter_DB_LMDB); + this->TestReadCrop(); +} + +// Test that the sequence of random crops is consistent when using +// Caffe::set_random_seed. +TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceSeededLMDB) { + Caffe::set_phase(Caffe::TRAIN); + const bool unique_pixels = true; // all images the same; pixels different + this->Fill(unique_pixels, DataParameter_DB_LMDB); + this->TestReadCropTrainSequenceSeeded(); +} + +// Test that the sequence of random crops differs across iterations when +// Caffe::set_random_seed isn't called (and seeds from srand are ignored). 
+TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceUnseededLMDB) { + Caffe::set_phase(Caffe::TRAIN); + const bool unique_pixels = true; // all images the same; pixels different + this->Fill(unique_pixels, DataParameter_DB_LMDB); + this->TestReadCropTrainSequenceUnseeded(); +} + +TYPED_TEST(DataLayerTest, TestReadCropTestLMDB) { + Caffe::set_phase(Caffe::TEST); + const bool unique_pixels = true; // all images the same; pixels different + this->Fill(unique_pixels, DataParameter_DB_LMDB); + this->TestReadCrop(); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_data_transformer.cpp b/caffe-crfrnn/src/caffe/test/test_data_transformer.cpp new file mode 100644 index 00000000..28c72410 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_data_transformer.cpp @@ -0,0 +1,360 @@ +#include +#include + +#include "gtest/gtest.h" +#include "leveldb/db.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/data_transformer.hpp" +#include "caffe/filler.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/io.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +void FillDatum(const int label, const int channels, const int height, + const int width, const bool unique_pixels, Datum * datum) { + datum->set_label(label); + datum->set_channels(channels); + datum->set_height(height); + datum->set_width(width); + int size = channels * height * width; + std::string* data = datum->mutable_data(); + for (int j = 0; j < size; ++j) { + int datum = unique_pixels ? j : label; + data->push_back(static_cast(datum)); + } +} + +template +class DataTransformTest : public ::testing::Test { + protected: + DataTransformTest() + : seed_(1701), + num_iter_(10) {} + + int NumSequenceMatches(const TransformationParameter transform_param, + const Datum& datum) { + // Get crop sequence with Caffe seed 1701. 
+ DataTransformer* transformer = + new DataTransformer(transform_param); + const int crop_size = transform_param.crop_size(); + Caffe::set_random_seed(seed_); + transformer->InitRand(); + Blob* blob = + new Blob(1, datum.channels(), datum.height(), datum.width()); + if (transform_param.crop_size() > 0) { + blob->Reshape(1, datum.channels(), crop_size, crop_size); + } + + vector > crop_sequence; + for (int iter = 0; iter < this->num_iter_; ++iter) { + vector iter_crop_sequence; + transformer->Transform(datum, blob); + for (int j = 0; j < blob->count(); ++j) { + iter_crop_sequence.push_back(blob->cpu_data()[j]); + } + crop_sequence.push_back(iter_crop_sequence); + } + // Check if the sequence differs from the previous + int num_sequence_matches = 0; + for (int iter = 0; iter < this->num_iter_; ++iter) { + vector iter_crop_sequence = crop_sequence[iter]; + transformer->Transform(datum, blob); + for (int j = 0; j < blob->count(); ++j) { + num_sequence_matches += + (crop_sequence[iter][j] == blob->cpu_data()[j]); + } + } + return num_sequence_matches; + } + + virtual ~DataTransformTest() { } + + int seed_; + int num_iter_; +}; + +TYPED_TEST_CASE(DataTransformTest, TestDtypes); + +TYPED_TEST(DataTransformTest, TestEmptyTransform) { + TransformationParameter transform_param; + const bool unique_pixels = false; // all pixels the same equal to label + const int label = 0; + const int channels = 3; + const int height = 4; + const int width = 5; + + Datum datum; + FillDatum(label, channels, height, width, unique_pixels, &datum); + Blob* blob = new Blob(1, channels, height, width); + DataTransformer* transformer = + new DataTransformer(transform_param); + transformer->InitRand(); + transformer->Transform(datum, blob); + EXPECT_EQ(blob->num(), 1); + EXPECT_EQ(blob->channels(), datum.channels()); + EXPECT_EQ(blob->height(), datum.height()); + EXPECT_EQ(blob->width(), datum.width()); + for (int j = 0; j < blob->count(); ++j) { + EXPECT_EQ(blob->cpu_data()[j], label); + } +} + 
+TYPED_TEST(DataTransformTest, TestEmptyTransformUniquePixels) { + TransformationParameter transform_param; + const bool unique_pixels = true; // pixels are consecutive ints [0,size] + const int label = 0; + const int channels = 3; + const int height = 4; + const int width = 5; + + Datum datum; + FillDatum(label, channels, height, width, unique_pixels, &datum); + Blob* blob = new Blob(1, 3, 4, 5); + DataTransformer* transformer = + new DataTransformer(transform_param); + transformer->InitRand(); + transformer->Transform(datum, blob); + EXPECT_EQ(blob->num(), 1); + EXPECT_EQ(blob->channels(), datum.channels()); + EXPECT_EQ(blob->height(), datum.height()); + EXPECT_EQ(blob->width(), datum.width()); + for (int j = 0; j < blob->count(); ++j) { + EXPECT_EQ(blob->cpu_data()[j], j); + } +} + +TYPED_TEST(DataTransformTest, TestCropSize) { + TransformationParameter transform_param; + const bool unique_pixels = false; // all pixels the same equal to label + const int label = 0; + const int channels = 3; + const int height = 4; + const int width = 5; + const int crop_size = 2; + + transform_param.set_crop_size(crop_size); + Datum datum; + FillDatum(label, channels, height, width, unique_pixels, &datum); + DataTransformer* transformer = + new DataTransformer(transform_param); + transformer->InitRand(); + Blob* blob = + new Blob(1, channels, crop_size, crop_size); + for (int iter = 0; iter < this->num_iter_; ++iter) { + transformer->Transform(datum, blob); + EXPECT_EQ(blob->num(), 1); + EXPECT_EQ(blob->channels(), datum.channels()); + EXPECT_EQ(blob->height(), crop_size); + EXPECT_EQ(blob->width(), crop_size); + for (int j = 0; j < blob->count(); ++j) { + EXPECT_EQ(blob->cpu_data()[j], label); + } + } +} + +TYPED_TEST(DataTransformTest, TestCropTrain) { + TransformationParameter transform_param; + const bool unique_pixels = true; // pixels are consecutive ints [0,size] + const int label = 0; + const int channels = 3; + const int height = 4; + const int width = 5; + const int 
crop_size = 2; + const int size = channels * crop_size * crop_size; + + transform_param.set_crop_size(crop_size); + Datum datum; + FillDatum(label, channels, height, width, unique_pixels, &datum); + Caffe::set_phase(Caffe::TRAIN); + int num_matches = this->NumSequenceMatches(transform_param, datum); + EXPECT_LT(num_matches, size * this->num_iter_); +} + +TYPED_TEST(DataTransformTest, TestCropTest) { + TransformationParameter transform_param; + const bool unique_pixels = true; // pixels are consecutive ints [0,size] + const int label = 0; + const int channels = 3; + const int height = 4; + const int width = 5; + const int crop_size = 2; + const int size = channels * crop_size * crop_size; + + transform_param.set_crop_size(crop_size); + Datum datum; + FillDatum(label, channels, height, width, unique_pixels, &datum); + Caffe::set_phase(Caffe::TEST); + int num_matches = this->NumSequenceMatches(transform_param, datum); + EXPECT_EQ(num_matches, size * this->num_iter_); +} + +TYPED_TEST(DataTransformTest, TestMirrorTrain) { + TransformationParameter transform_param; + const bool unique_pixels = true; // pixels are consecutive ints [0,size] + const int label = 0; + const int channels = 3; + const int height = 4; + const int width = 5; + const int size = channels * height * width; + + transform_param.set_mirror(true); + Datum datum; + FillDatum(label, channels, height, width, unique_pixels, &datum); + Caffe::set_phase(Caffe::TRAIN); + int num_matches = this->NumSequenceMatches(transform_param, datum); + EXPECT_LT(num_matches, size * this->num_iter_); +} + +TYPED_TEST(DataTransformTest, TestMirrorTest) { + TransformationParameter transform_param; + const bool unique_pixels = true; // pixels are consecutive ints [0,size] + const int label = 0; + const int channels = 3; + const int height = 4; + const int width = 5; + const int size = channels * height * width; + + transform_param.set_mirror(true); + Datum datum; + FillDatum(label, channels, height, width, unique_pixels, 
&datum); + Caffe::set_phase(Caffe::TEST); + int num_matches = this->NumSequenceMatches(transform_param, datum); + EXPECT_LT(num_matches, size * this->num_iter_); +} + +TYPED_TEST(DataTransformTest, TestCropMirrorTrain) { + TransformationParameter transform_param; + const bool unique_pixels = true; // pixels are consecutive ints [0,size] + const int label = 0; + const int channels = 3; + const int height = 4; + const int width = 5; + const int crop_size = 2; + + Datum datum; + FillDatum(label, channels, height, width, unique_pixels, &datum); + transform_param.set_crop_size(crop_size); + Caffe::set_phase(Caffe::TRAIN); + int num_matches_crop = this->NumSequenceMatches(transform_param, datum); + + transform_param.set_mirror(true); + int num_matches_crop_mirror = + this->NumSequenceMatches(transform_param, datum); + // When doing crop and mirror we expect less num_matches than just crop + EXPECT_LE(num_matches_crop_mirror, num_matches_crop); +} + +TYPED_TEST(DataTransformTest, TestCropMirrorTest) { + TransformationParameter transform_param; + const bool unique_pixels = true; // pixels are consecutive ints [0,size] + const int label = 0; + const int channels = 3; + const int height = 4; + const int width = 5; + const int crop_size = 2; + + Datum datum; + FillDatum(label, channels, height, width, unique_pixels, &datum); + transform_param.set_crop_size(crop_size); + Caffe::set_phase(Caffe::TEST); + int num_matches_crop = this->NumSequenceMatches(transform_param, datum); + + transform_param.set_mirror(true); + int num_matches_crop_mirror = + this->NumSequenceMatches(transform_param, datum); + // When doing crop and mirror we expect less num_matches than just crop + EXPECT_LT(num_matches_crop_mirror, num_matches_crop); +} + + +TYPED_TEST(DataTransformTest, TestMeanValue) { + TransformationParameter transform_param; + const bool unique_pixels = false; // pixels are equal to label + const int label = 0; + const int channels = 3; + const int height = 4; + const int width = 5; 
+ const int mean_value = 2; + + transform_param.add_mean_value(mean_value); + Datum datum; + FillDatum(label, channels, height, width, unique_pixels, &datum); + Blob* blob = new Blob(1, channels, height, width); + DataTransformer* transformer = + new DataTransformer(transform_param); + transformer->InitRand(); + transformer->Transform(datum, blob); + for (int j = 0; j < blob->count(); ++j) { + EXPECT_EQ(blob->cpu_data()[j], label - mean_value); + } +} + +TYPED_TEST(DataTransformTest, TestMeanValues) { + TransformationParameter transform_param; + const bool unique_pixels = false; // pixels are equal to label + const int label = 0; + const int channels = 3; + const int height = 4; + const int width = 5; + + transform_param.add_mean_value(0); + transform_param.add_mean_value(1); + transform_param.add_mean_value(2); + Datum datum; + FillDatum(label, channels, height, width, unique_pixels, &datum); + Blob* blob = new Blob(1, channels, height, width); + DataTransformer* transformer = + new DataTransformer(transform_param); + transformer->InitRand(); + transformer->Transform(datum, blob); + for (int c = 0; c < channels; ++c) { + for (int j = 0; j < height * width; ++j) { + EXPECT_EQ(blob->cpu_data()[blob->offset(0, c) + j], label - c); + } + } +} + +TYPED_TEST(DataTransformTest, TestMeanFile) { + TransformationParameter transform_param; + const bool unique_pixels = true; // pixels are consecutive ints [0,size] + const int label = 0; + const int channels = 3; + const int height = 4; + const int width = 5; + const int size = channels * height * width; + + // Create a mean file + string* mean_file = new string(); + MakeTempFilename(mean_file); + BlobProto blob_mean; + blob_mean.set_num(1); + blob_mean.set_channels(channels); + blob_mean.set_height(height); + blob_mean.set_width(width); + + for (int j = 0; j < size; ++j) { + blob_mean.add_data(j); + } + + LOG(INFO) << "Using temporary mean_file " << *mean_file; + WriteProtoToBinaryFile(blob_mean, *mean_file); + + 
transform_param.set_mean_file(*mean_file); + Datum datum; + FillDatum(label, channels, height, width, unique_pixels, &datum); + Blob* blob = new Blob(1, channels, height, width); + DataTransformer* transformer = + new DataTransformer(transform_param); + transformer->InitRand(); + transformer->Transform(datum, blob); + for (int j = 0; j < blob->count(); ++j) { + EXPECT_EQ(blob->cpu_data()[j], 0); + } +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_dataset.cpp b/caffe-crfrnn/src/caffe/test/test_dataset.cpp new file mode 100644 index 00000000..6645ca22 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_dataset.cpp @@ -0,0 +1,794 @@ +#include +#include + +#include "caffe/util/io.hpp" + +#include "gtest/gtest.h" + +#include "caffe/dataset_factory.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +namespace DatasetTest_internal { + +template +struct TestData { + static T TestValue(); + static T TestAltValue(); + static bool equals(const T& a, const T& b); +}; + +template <> +string TestData::TestValue() { + return "world"; +} + +template <> +string TestData::TestAltValue() { + return "bar"; +} + +template <> +bool TestData::equals(const string& a, const string& b) { + return a == b; +} + +template <> +vector TestData >::TestValue() { + string str = "world"; + vector val(str.data(), str.data() + str.size()); + return val; +} + +template <> +vector TestData >::TestAltValue() { + string str = "bar"; + vector val(str.data(), str.data() + str.size()); + return val; +} + +template <> +bool TestData >::equals(const vector& a, + const vector& b) { + if (a.size() != b.size()) { + return false; + } + for (size_t i = 0; i < a.size(); ++i) { + if (a.at(i) != b.at(i)) { + return false; + } + } + + return true; +} + +template <> +Datum TestData::TestValue() { + Datum datum; + datum.set_channels(3); + datum.set_height(32); + datum.set_width(32); + datum.set_data(string(32 * 32 * 3 * 4, ' ')); + datum.set_label(0); + return datum; +} + 
+template <> +Datum TestData::TestAltValue() { + Datum datum; + datum.set_channels(1); + datum.set_height(64); + datum.set_width(64); + datum.set_data(string(64 * 64 * 1 * 4, ' ')); + datum.set_label(1); + return datum; +} + +template <> +bool TestData::equals(const Datum& a, const Datum& b) { + string serialized_a; + a.SerializeToString(&serialized_a); + + string serialized_b; + b.SerializeToString(&serialized_b); + + return serialized_a == serialized_b; +} + +} // namespace DatasetTest_internal + +#define UNPACK_TYPES \ + typedef typename TypeParam::value_type value_type; \ + const DataParameter_DB backend = TypeParam::backend; + +template +class DatasetTest : public ::testing::Test { + protected: + typedef typename TypeParam::value_type value_type; + + string DBName() { + string filename; + MakeTempDir(&filename); + filename += "/db"; + return filename; + } + + string TestKey() { + return "hello"; + } + + value_type TestValue() { + return DatasetTest_internal::TestData::TestValue(); + } + + string TestAltKey() { + return "foo"; + } + + value_type TestAltValue() { + return DatasetTest_internal::TestData::TestAltValue(); + } + + template + bool equals(const T& a, const T& b) { + return DatasetTest_internal::TestData::equals(a, b); + } +}; + +struct StringLeveldb { + typedef string value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB StringLeveldb::backend = DataParameter_DB_LEVELDB; + +struct StringLmdb { + typedef string value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB StringLmdb::backend = DataParameter_DB_LMDB; + +struct VectorLeveldb { + typedef vector value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB VectorLeveldb::backend = DataParameter_DB_LEVELDB; + +struct VectorLmdb { + typedef vector value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB VectorLmdb::backend = DataParameter_DB_LMDB; + +struct DatumLeveldb { + typedef Datum value_type; + 
static const DataParameter_DB backend; +}; +const DataParameter_DB DatumLeveldb::backend = DataParameter_DB_LEVELDB; + +struct DatumLmdb { + typedef Datum value_type; + static const DataParameter_DB backend; +}; +const DataParameter_DB DatumLmdb::backend = DataParameter_DB_LMDB; + +typedef ::testing::Types TestTypes; + +TYPED_TEST_CASE(DatasetTest, TestTypes); + +TYPED_TEST(DatasetTest, TestNewDoesntExistPasses) { + UNPACK_TYPES; + + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(this->DBName(), + Dataset::New)); + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestNewExistsFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + dataset->close(); + + EXPECT_FALSE(dataset->open(name, Dataset::New)); +} + +TYPED_TEST(DatasetTest, TestReadOnlyExistsPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + dataset->close(); + + EXPECT_TRUE(dataset->open(name, Dataset::ReadOnly)); + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestReadOnlyDoesntExistFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_FALSE(dataset->open(name, Dataset::ReadOnly)); +} + +TYPED_TEST(DatasetTest, TestReadWriteExistsPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + dataset->close(); + + EXPECT_TRUE(dataset->open(name, Dataset::ReadWrite)); + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestReadWriteDoesntExistPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::ReadWrite)); + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestKeys) { + UNPACK_TYPES; + + string name 
= this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key1 = this->TestKey(); + value_type value1 = this->TestValue(); + + EXPECT_TRUE(dataset->put(key1, value1)); + + string key2 = this->TestAltKey(); + value_type value2 = this->TestAltValue(); + + EXPECT_TRUE(dataset->put(key2, value2)); + + EXPECT_TRUE(dataset->commit()); + + vector keys; + dataset->keys(&keys); + + EXPECT_EQ(2, keys.size()); + + EXPECT_TRUE(this->equals(keys.at(0), key1) || + this->equals(keys.at(0), key2)); + EXPECT_TRUE(this->equals(keys.at(1), key1) || + this->equals(keys.at(1), key2)); + EXPECT_FALSE(this->equals(keys.at(0), keys.at(1))); +} + +TYPED_TEST(DatasetTest, TestFirstKey) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + value_type value = this->TestValue(); + + string key1 = "01"; + EXPECT_TRUE(dataset->put(key1, value)); + + string key2 = "02"; + EXPECT_TRUE(dataset->put(key2, value)); + + string key3 = "03"; + EXPECT_TRUE(dataset->put(key3, value)); + + EXPECT_TRUE(dataset->commit()); + + string first_key; + dataset->first_key(&first_key); + + EXPECT_TRUE(this->equals(first_key, key1)); +} + +TYPED_TEST(DatasetTest, TestLastKey) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + value_type value = this->TestValue(); + + string key1 = "01"; + EXPECT_TRUE(dataset->put(key1, value)); + + string key2 = "02"; + EXPECT_TRUE(dataset->put(key2, value)); + + string key3 = "03"; + EXPECT_TRUE(dataset->put(key3, value)); + + EXPECT_TRUE(dataset->commit()); + + string last_key; + dataset->last_key(&last_key); + + EXPECT_TRUE(this->equals(last_key, key3)); +} + +TYPED_TEST(DatasetTest, TestFirstLastKeys) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + 
DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + value_type value = this->TestValue(); + + string key1 = "01"; + EXPECT_TRUE(dataset->put(key1, value)); + + string key2 = "02"; + EXPECT_TRUE(dataset->put(key2, value)); + + string key3 = "03"; + EXPECT_TRUE(dataset->put(key3, value)); + + EXPECT_TRUE(dataset->commit()); + + string first_key; + dataset->first_key(&first_key); + string last_key; + dataset->last_key(&last_key); + + EXPECT_TRUE(this->equals(first_key, key1)); + EXPECT_TRUE(this->equals(last_key, key3)); +} + +TYPED_TEST(DatasetTest, TestFirstLastKeysUnOrdered) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + value_type value = this->TestValue(); + + string key3 = "03"; + EXPECT_TRUE(dataset->put(key3, value)); + + string key1 = "01"; + EXPECT_TRUE(dataset->put(key1, value)); + + string key2 = "02"; + EXPECT_TRUE(dataset->put(key2, value)); + + EXPECT_TRUE(dataset->commit()); + + string first_key; + dataset->first_key(&first_key); + string last_key; + dataset->last_key(&last_key); + + EXPECT_TRUE(this->equals(first_key, key1)); + EXPECT_TRUE(this->equals(last_key, key3)); +} + +TYPED_TEST(DatasetTest, TestKeysNoCommit) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key1 = this->TestKey(); + value_type value1 = this->TestValue(); + + EXPECT_TRUE(dataset->put(key1, value1)); + + string key2 = this->TestAltKey(); + value_type value2 = this->TestAltValue(); + + EXPECT_TRUE(dataset->put(key2, value2)); + + vector keys; + dataset->keys(&keys); + + EXPECT_EQ(0, keys.size()); +} + +TYPED_TEST(DatasetTest, TestIterators) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + const int kNumExamples = 4; + for 
(int i = 0; i < kNumExamples; ++i) { + stringstream ss; + ss << i; + string key = ss.str(); + ss << " here be data"; + value_type value = this->TestValue(); + EXPECT_TRUE(dataset->put(key, value)); + } + EXPECT_TRUE(dataset->commit()); + + int count = 0; + typedef typename Dataset::const_iterator Iter; + for (Iter iter = dataset->begin(); iter != dataset->end(); ++iter) { + (void)iter; + ++count; + } + + EXPECT_EQ(kNumExamples, count); +} + +TYPED_TEST(DatasetTest, TestIteratorsPreIncrement) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key1 = this->TestAltKey(); + value_type value1 = this->TestAltValue(); + + string key2 = this->TestKey(); + value_type value2 = this->TestValue(); + + EXPECT_TRUE(dataset->put(key1, value1)); + EXPECT_TRUE(dataset->put(key2, value2)); + EXPECT_TRUE(dataset->commit()); + + typename Dataset::const_iterator iter1 = + dataset->begin(); + + EXPECT_FALSE(dataset->end() == iter1); + + EXPECT_TRUE(this->equals(iter1->key, key1)); + + typename Dataset::const_iterator iter2 = ++iter1; + + EXPECT_FALSE(dataset->end() == iter1); + EXPECT_FALSE(dataset->end() == iter2); + + EXPECT_TRUE(this->equals(iter2->key, key2)); + + typename Dataset::const_iterator iter3 = ++iter2; + + EXPECT_TRUE(dataset->end() == iter3); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestIteratorsPostIncrement) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key1 = this->TestAltKey(); + value_type value1 = this->TestAltValue(); + + string key2 = this->TestKey(); + value_type value2 = this->TestValue(); + + EXPECT_TRUE(dataset->put(key1, value1)); + EXPECT_TRUE(dataset->put(key2, value2)); + EXPECT_TRUE(dataset->commit()); + + typename Dataset::const_iterator iter1 = + dataset->begin(); + + EXPECT_FALSE(dataset->end() == iter1); + + 
EXPECT_TRUE(this->equals(iter1->key, key1)); + + typename Dataset::const_iterator iter2 = iter1++; + + EXPECT_FALSE(dataset->end() == iter1); + EXPECT_FALSE(dataset->end() == iter2); + + EXPECT_TRUE(this->equals(iter2->key, key1)); + EXPECT_TRUE(this->equals(iter1->key, key2)); + + typename Dataset::const_iterator iter3 = iter1++; + + EXPECT_FALSE(dataset->end() == iter3); + EXPECT_TRUE(this->equals(iter3->key, key2)); + EXPECT_TRUE(dataset->end() == iter1); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestNewPutPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + EXPECT_TRUE(dataset->commit()); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestNewCommitPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + EXPECT_TRUE(dataset->commit()); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestNewGetPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + EXPECT_TRUE(dataset->commit()); + + value_type new_value; + + EXPECT_TRUE(dataset->get(key, &new_value)); + + EXPECT_TRUE(this->equals(value, new_value)); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestNewGetNoCommitFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + value_type new_value; + + 
EXPECT_FALSE(dataset->get(key, &new_value)); +} + + +TYPED_TEST(DatasetTest, TestReadWritePutPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::ReadWrite)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + EXPECT_TRUE(dataset->commit()); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestReadWriteCommitPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::ReadWrite)); + + EXPECT_TRUE(dataset->commit()); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestReadWriteGetPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + EXPECT_TRUE(dataset->commit()); + + value_type new_value; + + EXPECT_TRUE(dataset->get(key, &new_value)); + + EXPECT_TRUE(this->equals(value, new_value)); + + dataset->close(); +} + +TYPED_TEST(DatasetTest, TestReadWriteGetNoCommitFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + value_type new_value; + + EXPECT_FALSE(dataset->get(key, &new_value)); +} + +TYPED_TEST(DatasetTest, TestReadOnlyPutFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + dataset->close(); + + EXPECT_TRUE(dataset->open(name, Dataset::ReadOnly)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + 
EXPECT_FALSE(dataset->put(key, value)); +} + +TYPED_TEST(DatasetTest, TestReadOnlyCommitFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + dataset->close(); + + EXPECT_TRUE(dataset->open(name, Dataset::ReadOnly)); + + EXPECT_FALSE(dataset->commit()); +} + +TYPED_TEST(DatasetTest, TestReadOnlyGetPasses) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + EXPECT_TRUE(dataset->commit()); + + dataset->close(); + + EXPECT_TRUE(dataset->open(name, Dataset::ReadOnly)); + + value_type new_value; + + EXPECT_TRUE(dataset->get(key, &new_value)); + + EXPECT_TRUE(this->equals(value, new_value)); +} + +TYPED_TEST(DatasetTest, TestReadOnlyGetNoCommitFails) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + + EXPECT_TRUE(dataset->put(key, value)); + + dataset->close(); + + EXPECT_TRUE(dataset->open(name, Dataset::ReadOnly)); + + value_type new_value; + + EXPECT_FALSE(dataset->get(key, &new_value)); +} + +TYPED_TEST(DatasetTest, TestCreateManyItersShortScope) { + UNPACK_TYPES; + + string name = this->DBName(); + shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + EXPECT_TRUE(dataset->put(key, value)); + EXPECT_TRUE(dataset->commit()); + + for (int i = 0; i < 1000; ++i) { + typename Dataset::const_iterator iter = + dataset->begin(); + } +} + +TYPED_TEST(DatasetTest, TestCreateManyItersLongScope) { + UNPACK_TYPES; + + string name = this->DBName(); + 
shared_ptr > dataset = + DatasetFactory(backend); + EXPECT_TRUE(dataset->open(name, Dataset::New)); + + string key = this->TestKey(); + value_type value = this->TestValue(); + EXPECT_TRUE(dataset->put(key, value)); + EXPECT_TRUE(dataset->commit()); + + vector::const_iterator> iters; + for (int i = 0; i < 1000; ++i) { + iters.push_back(dataset->begin()); + } +} + +#undef UNPACK_TYPES + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_dummy_data_layer.cpp b/caffe-crfrnn/src/caffe/test/test_dummy_data_layer.cpp new file mode 100644 index 00000000..99548352 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_dummy_data_layer.cpp @@ -0,0 +1,196 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class DummyDataLayerTest : public ::testing::Test { + protected: + DummyDataLayerTest() + : blob_top_a_(new Blob()), + blob_top_b_(new Blob()), + blob_top_c_(new Blob()) {} + + virtual void SetUp() { + blob_bottom_vec_.clear(); + blob_top_vec_.clear(); + blob_top_vec_.push_back(blob_top_a_); + blob_top_vec_.push_back(blob_top_b_); + blob_top_vec_.push_back(blob_top_c_); + } + + virtual ~DummyDataLayerTest() { + delete blob_top_a_; + delete blob_top_b_; + delete blob_top_c_; + } + + Blob* const blob_top_a_; + Blob* const blob_top_b_; + Blob* const blob_top_c_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(DummyDataLayerTest, TestDtypes); + +TYPED_TEST(DummyDataLayerTest, TestOneTopConstant) { + Caffe::set_mode(Caffe::CPU); + LayerParameter param; + DummyDataParameter* dummy_data_param = param.mutable_dummy_data_param(); + dummy_data_param->add_num(5); + dummy_data_param->add_channels(3); + dummy_data_param->add_height(2); + dummy_data_param->add_width(4); + this->blob_top_vec_.resize(1); + DummyDataLayer layer(param); + 
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_a_->num(), 5); + EXPECT_EQ(this->blob_top_a_->channels(), 3); + EXPECT_EQ(this->blob_top_a_->height(), 2); + EXPECT_EQ(this->blob_top_a_->width(), 4); + EXPECT_EQ(this->blob_top_b_->count(), 0); + EXPECT_EQ(this->blob_top_c_->count(), 0); + for (int i = 0; i < this->blob_top_vec_.size(); ++i) { + for (int j = 0; j < this->blob_top_vec_[i]->count(); ++j) { + EXPECT_EQ(0, this->blob_top_vec_[i]->cpu_data()[j]); + } + } + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < this->blob_top_vec_.size(); ++i) { + for (int j = 0; j < this->blob_top_vec_[i]->count(); ++j) { + EXPECT_EQ(0, this->blob_top_vec_[i]->cpu_data()[j]); + } + } +} + +TYPED_TEST(DummyDataLayerTest, TestTwoTopConstant) { + Caffe::set_mode(Caffe::CPU); + LayerParameter param; + DummyDataParameter* dummy_data_param = param.mutable_dummy_data_param(); + dummy_data_param->add_num(5); + dummy_data_param->add_channels(3); + dummy_data_param->add_height(2); + dummy_data_param->add_width(4); + dummy_data_param->add_num(5); + // Don't explicitly set number of channels or height for 2nd top blob; should + // default to first channels and height (as we check later). 
+ dummy_data_param->add_height(1); + FillerParameter* data_filler_param = dummy_data_param->add_data_filler(); + data_filler_param->set_value(7); + this->blob_top_vec_.resize(2); + DummyDataLayer layer(param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_a_->num(), 5); + EXPECT_EQ(this->blob_top_a_->channels(), 3); + EXPECT_EQ(this->blob_top_a_->height(), 2); + EXPECT_EQ(this->blob_top_a_->width(), 4); + EXPECT_EQ(this->blob_top_b_->num(), 5); + EXPECT_EQ(this->blob_top_b_->channels(), 3); + EXPECT_EQ(this->blob_top_b_->height(), 1); + EXPECT_EQ(this->blob_top_b_->width(), 4); + EXPECT_EQ(this->blob_top_c_->count(), 0); + for (int i = 0; i < this->blob_top_vec_.size(); ++i) { + for (int j = 0; j < this->blob_top_vec_[i]->count(); ++j) { + EXPECT_EQ(7, this->blob_top_vec_[i]->cpu_data()[j]); + } + } + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < this->blob_top_vec_.size(); ++i) { + for (int j = 0; j < this->blob_top_vec_[i]->count(); ++j) { + EXPECT_EQ(7, this->blob_top_vec_[i]->cpu_data()[j]); + } + } +} + +TYPED_TEST(DummyDataLayerTest, TestThreeTopConstantGaussianConstant) { + Caffe::set_mode(Caffe::CPU); + LayerParameter param; + DummyDataParameter* dummy_data_param = param.mutable_dummy_data_param(); + dummy_data_param->add_num(5); + dummy_data_param->add_channels(3); + dummy_data_param->add_height(2); + dummy_data_param->add_width(4); + FillerParameter* data_filler_param_a = dummy_data_param->add_data_filler(); + data_filler_param_a->set_value(7); + FillerParameter* data_filler_param_b = dummy_data_param->add_data_filler(); + data_filler_param_b->set_type("gaussian"); + TypeParam gaussian_mean = 3.0; + TypeParam gaussian_std = 0.01; + data_filler_param_b->set_mean(gaussian_mean); + data_filler_param_b->set_std(gaussian_std); + FillerParameter* data_filler_param_c = dummy_data_param->add_data_filler(); + data_filler_param_c->set_value(9); + DummyDataLayer layer(param); + 
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_a_->num(), 5); + EXPECT_EQ(this->blob_top_a_->channels(), 3); + EXPECT_EQ(this->blob_top_a_->height(), 2); + EXPECT_EQ(this->blob_top_a_->width(), 4); + EXPECT_EQ(this->blob_top_b_->num(), 5); + EXPECT_EQ(this->blob_top_b_->channels(), 3); + EXPECT_EQ(this->blob_top_b_->height(), 2); + EXPECT_EQ(this->blob_top_b_->width(), 4); + EXPECT_EQ(this->blob_top_c_->num(), 5); + EXPECT_EQ(this->blob_top_c_->channels(), 3); + EXPECT_EQ(this->blob_top_c_->height(), 2); + EXPECT_EQ(this->blob_top_c_->width(), 4); + for (int i = 0; i < this->blob_top_a_->count(); ++i) { + EXPECT_EQ(7, this->blob_top_a_->cpu_data()[i]); + } + // Blob b uses a Gaussian filler, so SetUp should not have initialized it. + // Blob b's data should therefore be the default Blob data value: 0. + for (int i = 0; i < this->blob_top_b_->count(); ++i) { + EXPECT_EQ(0, this->blob_top_b_->cpu_data()[i]); + } + for (int i = 0; i < this->blob_top_c_->count(); ++i) { + EXPECT_EQ(9, this->blob_top_c_->cpu_data()[i]); + } + + // Do a Forward pass to fill in Blob b with Gaussian data. + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < this->blob_top_a_->count(); ++i) { + EXPECT_EQ(7, this->blob_top_a_->cpu_data()[i]); + } + // Check that the Gaussian's data has been filled in with values within + // 10 standard deviations of the mean. Record the first and last samples + // to check that they're different after the next Forward pass. 
+ for (int i = 0; i < this->blob_top_b_->count(); ++i) { + EXPECT_NEAR(gaussian_mean, this->blob_top_b_->cpu_data()[i], + gaussian_std * 10); + } + const TypeParam first_gaussian_sample = this->blob_top_b_->cpu_data()[0]; + const TypeParam last_gaussian_sample = + this->blob_top_b_->cpu_data()[this->blob_top_b_->count() - 1]; + for (int i = 0; i < this->blob_top_c_->count(); ++i) { + EXPECT_EQ(9, this->blob_top_c_->cpu_data()[i]); + } + + // Do another Forward pass to fill in Blob b with Gaussian data again, + // checking that we get different values. + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < this->blob_top_a_->count(); ++i) { + EXPECT_EQ(7, this->blob_top_a_->cpu_data()[i]); + } + for (int i = 0; i < this->blob_top_b_->count(); ++i) { + EXPECT_NEAR(gaussian_mean, this->blob_top_b_->cpu_data()[i], + gaussian_std * 10); + } + EXPECT_NE(first_gaussian_sample, this->blob_top_b_->cpu_data()[0]); + EXPECT_NE(last_gaussian_sample, + this->blob_top_b_->cpu_data()[this->blob_top_b_->count() - 1]); + for (int i = 0; i < this->blob_top_c_->count(); ++i) { + EXPECT_EQ(9, this->blob_top_c_->cpu_data()[i]); + } +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_eltwise_layer.cpp b/caffe-crfrnn/src/caffe/test/test_eltwise_layer.cpp new file mode 100644 index 00000000..be0c1347 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_eltwise_layer.cpp @@ -0,0 +1,209 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class EltwiseLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + EltwiseLayerTest() + : blob_bottom_a_(new Blob(2, 3, 4, 5)), + blob_bottom_b_(new Blob(2, 3, 4, 5)), + blob_bottom_c_(new Blob(2, 3, 4, 5)), + 
blob_top_(new Blob()) { + // fill the values + Caffe::set_random_seed(1701); + FillerParameter filler_param; + UniformFiller filler(filler_param); + filler.Fill(this->blob_bottom_a_); + filler.Fill(this->blob_bottom_b_); + filler.Fill(this->blob_bottom_c_); + blob_bottom_vec_.push_back(blob_bottom_a_); + blob_bottom_vec_.push_back(blob_bottom_b_); + blob_bottom_vec_.push_back(blob_bottom_c_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~EltwiseLayerTest() { + delete blob_bottom_a_; + delete blob_bottom_b_; + delete blob_bottom_c_; + delete blob_top_; + } + Blob* const blob_bottom_a_; + Blob* const blob_bottom_b_; + Blob* const blob_bottom_c_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(EltwiseLayerTest, TestDtypesAndDevices); + +TYPED_TEST(EltwiseLayerTest, TestSetUp) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + EltwiseParameter* eltwise_param = layer_param.mutable_eltwise_param(); + eltwise_param->set_operation(EltwiseParameter_EltwiseOp_PROD); + shared_ptr > layer( + new EltwiseLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 2); + EXPECT_EQ(this->blob_top_->channels(), 3); + EXPECT_EQ(this->blob_top_->height(), 4); + EXPECT_EQ(this->blob_top_->width(), 5); +} + +TYPED_TEST(EltwiseLayerTest, TestProd) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + EltwiseParameter* eltwise_param = layer_param.mutable_eltwise_param(); + eltwise_param->set_operation(EltwiseParameter_EltwiseOp_PROD); + shared_ptr > layer( + new EltwiseLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + const Dtype* in_data_a = this->blob_bottom_a_->cpu_data(); + const Dtype* in_data_b = 
this->blob_bottom_b_->cpu_data(); + const Dtype* in_data_c = this->blob_bottom_c_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_EQ(data[i], in_data_a[i] * in_data_b[i] * in_data_c[i]); + } +} + +TYPED_TEST(EltwiseLayerTest, TestSum) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + EltwiseParameter* eltwise_param = layer_param.mutable_eltwise_param(); + eltwise_param->set_operation(EltwiseParameter_EltwiseOp_SUM); + shared_ptr > layer( + new EltwiseLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + const Dtype* in_data_a = this->blob_bottom_a_->cpu_data(); + const Dtype* in_data_b = this->blob_bottom_b_->cpu_data(); + const Dtype* in_data_c = this->blob_bottom_c_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_EQ(data[i], in_data_a[i] + in_data_b[i] + in_data_c[i]); + } +} + +TYPED_TEST(EltwiseLayerTest, TestSumCoeff) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + EltwiseParameter* eltwise_param = layer_param.mutable_eltwise_param(); + eltwise_param->set_operation(EltwiseParameter_EltwiseOp_SUM); + eltwise_param->add_coeff(1); + eltwise_param->add_coeff(-0.5); + eltwise_param->add_coeff(2); + shared_ptr > layer( + new EltwiseLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + const Dtype* in_data_a = this->blob_bottom_a_->cpu_data(); + const Dtype* in_data_b = this->blob_bottom_b_->cpu_data(); + const Dtype* in_data_c = this->blob_bottom_c_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_NEAR(data[i], in_data_a[i] - 0.5*in_data_b[i] + 2*in_data_c[i], + 1e-4); + } +} + +TYPED_TEST(EltwiseLayerTest, 
TestStableProdGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + EltwiseParameter* eltwise_param = layer_param.mutable_eltwise_param(); + eltwise_param->set_operation(EltwiseParameter_EltwiseOp_PROD); + eltwise_param->set_stable_prod_grad(true); + EltwiseLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(EltwiseLayerTest, TestUnstableProdGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + EltwiseParameter* eltwise_param = layer_param.mutable_eltwise_param(); + eltwise_param->set_operation(EltwiseParameter_EltwiseOp_PROD); + eltwise_param->set_stable_prod_grad(false); + EltwiseLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(EltwiseLayerTest, TestSumGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + EltwiseParameter* eltwise_param = layer_param.mutable_eltwise_param(); + eltwise_param->set_operation(EltwiseParameter_EltwiseOp_SUM); + EltwiseLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(EltwiseLayerTest, TestSumCoeffGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + EltwiseParameter* eltwise_param = layer_param.mutable_eltwise_param(); + eltwise_param->set_operation(EltwiseParameter_EltwiseOp_SUM); + eltwise_param->add_coeff(1); + eltwise_param->add_coeff(-0.5); + eltwise_param->add_coeff(2); + EltwiseLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(EltwiseLayerTest, TestMax) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + 
EltwiseParameter* eltwise_param = layer_param.mutable_eltwise_param(); + eltwise_param->set_operation(EltwiseParameter_EltwiseOp_MAX); + shared_ptr > layer( + new EltwiseLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + const Dtype* in_data_a = this->blob_bottom_a_->cpu_data(); + const Dtype* in_data_b = this->blob_bottom_b_->cpu_data(); + const Dtype* in_data_c = this->blob_bottom_c_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_EQ(data[i], + std::max(in_data_a[i], std::max(in_data_b[i], in_data_c[i]))); + } +} + +TYPED_TEST(EltwiseLayerTest, TestMaxGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + EltwiseParameter* eltwise_param = layer_param.mutable_eltwise_param(); + eltwise_param->set_operation(EltwiseParameter_EltwiseOp_MAX); + EltwiseLayer layer(layer_param); + GradientChecker checker(1e-4, 1e-3); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_euclidean_loss_layer.cpp b/caffe-crfrnn/src/caffe/test/test_euclidean_loss_layer.cpp new file mode 100644 index 00000000..1949742b --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_euclidean_loss_layer.cpp @@ -0,0 +1,91 @@ +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class EuclideanLossLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + EuclideanLossLayerTest() + : blob_bottom_data_(new Blob(10, 5, 1, 1)), + blob_bottom_label_(new Blob(10, 5, 1, 1)), 
+ blob_top_loss_(new Blob()) { + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_data_); + blob_bottom_vec_.push_back(blob_bottom_data_); + filler.Fill(this->blob_bottom_label_); + blob_bottom_vec_.push_back(blob_bottom_label_); + blob_top_vec_.push_back(blob_top_loss_); + } + virtual ~EuclideanLossLayerTest() { + delete blob_bottom_data_; + delete blob_bottom_label_; + delete blob_top_loss_; + } + + void TestForward() { + // Get the loss without a specified objective weight -- should be + // equivalent to explicitly specifying a weight of 1. + LayerParameter layer_param; + EuclideanLossLayer layer_weight_1(layer_param); + layer_weight_1.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype loss_weight_1 = + layer_weight_1.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + // Get the loss again with a different objective weight; check that it is + // scaled appropriately. + const Dtype kLossWeight = 3.7; + layer_param.add_loss_weight(kLossWeight); + EuclideanLossLayer layer_weight_2(layer_param); + layer_weight_2.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype loss_weight_2 = + layer_weight_2.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype kErrorMargin = 1e-5; + EXPECT_NEAR(loss_weight_1 * kLossWeight, loss_weight_2, kErrorMargin); + // Make sure the loss is non-trivial. 
+ const Dtype kNonTrivialAbsThresh = 1e-1; + EXPECT_GE(fabs(loss_weight_1), kNonTrivialAbsThresh); + } + + Blob* const blob_bottom_data_; + Blob* const blob_bottom_label_; + Blob* const blob_top_loss_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(EuclideanLossLayerTest, TestDtypesAndDevices); + +TYPED_TEST(EuclideanLossLayerTest, TestForward) { + this->TestForward(); +} + +TYPED_TEST(EuclideanLossLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + const Dtype kLossWeight = 3.7; + layer_param.add_loss_weight(kLossWeight); + EuclideanLossLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + GradientChecker checker(1e-2, 1e-2, 1701); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_filler.cpp b/caffe-crfrnn/src/caffe/test/test_filler.cpp new file mode 100644 index 00000000..e04b0fd2 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_filler.cpp @@ -0,0 +1,145 @@ +#include + +#include "gtest/gtest.h" + +#include "caffe/filler.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class ConstantFillerTest : public ::testing::Test { + protected: + ConstantFillerTest() + : blob_(new Blob(2, 3, 4, 5)), + filler_param_() { + filler_param_.set_value(10.); + filler_.reset(new ConstantFiller(filler_param_)); + filler_->Fill(blob_); + } + virtual ~ConstantFillerTest() { delete blob_; } + Blob* const blob_; + FillerParameter filler_param_; + shared_ptr > filler_; +}; + +TYPED_TEST_CASE(ConstantFillerTest, TestDtypes); + +TYPED_TEST(ConstantFillerTest, TestFill) { + EXPECT_TRUE(this->blob_); + const int count = this->blob_->count(); + const TypeParam* data = this->blob_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_GE(data[i], this->filler_param_.value()); + } +} + + +template +class UniformFillerTest : public 
::testing::Test { + protected: + UniformFillerTest() + : blob_(new Blob(2, 3, 4, 5)), + filler_param_() { + filler_param_.set_min(1.); + filler_param_.set_max(2.); + filler_.reset(new UniformFiller(filler_param_)); + filler_->Fill(blob_); + } + virtual ~UniformFillerTest() { delete blob_; } + Blob* const blob_; + FillerParameter filler_param_; + shared_ptr > filler_; +}; + +TYPED_TEST_CASE(UniformFillerTest, TestDtypes); + +TYPED_TEST(UniformFillerTest, TestFill) { + EXPECT_TRUE(this->blob_); + const int count = this->blob_->count(); + const TypeParam* data = this->blob_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_GE(data[i], this->filler_param_.min()); + EXPECT_LE(data[i], this->filler_param_.max()); + } +} + +template +class PositiveUnitballFillerTest : public ::testing::Test { + protected: + PositiveUnitballFillerTest() + : blob_(new Blob(2, 3, 4, 5)), + filler_param_() { + filler_.reset(new PositiveUnitballFiller(filler_param_)); + filler_->Fill(blob_); + } + virtual ~PositiveUnitballFillerTest() { delete blob_; } + Blob* const blob_; + FillerParameter filler_param_; + shared_ptr > filler_; +}; + +TYPED_TEST_CASE(PositiveUnitballFillerTest, TestDtypes); + +TYPED_TEST(PositiveUnitballFillerTest, TestFill) { + EXPECT_TRUE(this->blob_); + const int num = this->blob_->num(); + const int count = this->blob_->count(); + const int dim = count / num; + const TypeParam* data = this->blob_->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_GE(data[i], 0); + EXPECT_LE(data[i], 1); + } + for (int i = 0; i < num; ++i) { + TypeParam sum = 0; + for (int j = 0; j < dim; ++j) { + sum += data[i * dim + j]; + } + EXPECT_GE(sum, 0.999); + EXPECT_LE(sum, 1.001); + } +} + +template +class GaussianFillerTest : public ::testing::Test { + protected: + GaussianFillerTest() + : blob_(new Blob(2, 3, 4, 5)), + filler_param_() { + filler_param_.set_mean(10.); + filler_param_.set_std(0.1); + filler_.reset(new GaussianFiller(filler_param_)); + filler_->Fill(blob_); + } + 
virtual ~GaussianFillerTest() { delete blob_; } + Blob* const blob_; + FillerParameter filler_param_; + shared_ptr > filler_; +}; + +TYPED_TEST_CASE(GaussianFillerTest, TestDtypes); + +TYPED_TEST(GaussianFillerTest, TestFill) { + EXPECT_TRUE(this->blob_); + const int count = this->blob_->count(); + const TypeParam* data = this->blob_->cpu_data(); + TypeParam mean = 0.; + TypeParam var = 0.; + for (int i = 0; i < count; ++i) { + mean += data[i]; + var += (data[i] - this->filler_param_.mean()) * + (data[i] - this->filler_param_.mean()); + } + mean /= count; + var /= count; + // Very loose test. + EXPECT_GE(mean, this->filler_param_.mean() - this->filler_param_.std() * 5); + EXPECT_LE(mean, this->filler_param_.mean() + this->filler_param_.std() * 5); + TypeParam target_var = this->filler_param_.std() * this->filler_param_.std(); + EXPECT_GE(var, target_var / 5.); + EXPECT_LE(var, target_var * 5.); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_flatten_layer.cpp b/caffe-crfrnn/src/caffe/test/test_flatten_layer.cpp new file mode 100644 index 00000000..3042d293 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_flatten_layer.cpp @@ -0,0 +1,75 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class FlattenLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + protected: + FlattenLayerTest() + : blob_bottom_(new Blob(2, 3, 6, 5)), + blob_top_(new Blob()) { + Caffe::set_random_seed(1701); + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~FlattenLayerTest() { delete blob_bottom_; delete blob_top_; } 
+ Blob* const blob_bottom_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(FlattenLayerTest, TestDtypesAndDevices); + +TYPED_TEST(FlattenLayerTest, TestSetup) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + FlattenLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 2); + EXPECT_EQ(this->blob_top_->channels(), 3 * 6 * 5); + EXPECT_EQ(this->blob_top_->height(), 1); + EXPECT_EQ(this->blob_top_->width(), 1); +} + +TYPED_TEST(FlattenLayerTest, Test) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + FlattenLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int c = 0; c < 3 * 6 * 5; ++c) { + EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0), + this->blob_bottom_->data_at(0, c / (6 * 5), (c / 5) % 6, c % 5)); + EXPECT_EQ(this->blob_top_->data_at(1, c, 0, 0), + this->blob_bottom_->data_at(1, c / (6 * 5), (c / 5) % 6, c % 5)); + } +} + +TYPED_TEST(FlattenLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + FlattenLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_gradient_based_solver.cpp b/caffe-crfrnn/src/caffe/test/test_gradient_based_solver.cpp new file mode 100644 index 00000000..65de52aa --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_gradient_based_solver.cpp @@ -0,0 +1,485 @@ +#include +#include +#include +#include + +#include "google/protobuf/text_format.h" + +#include "gtest/gtest.h" + +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +using std::ostringstream; + +namespace 
caffe { + +template <typename TypeParam> +class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + GradientBasedSolverTest() : + seed_(1701), num_(5), channels_(3), height_(10), width_(10) {} + + shared_ptr<SGDSolver<Dtype> > solver_; + int seed_; + int num_, channels_, height_, width_; + Dtype delta_; // Stability constant for AdaGrad. + + virtual SolverParameter_SolverType solver_type() = 0; + virtual void InitSolver(const SolverParameter& param) = 0; + + virtual void InitSolverFromProtoString(const string& proto) { + SolverParameter param; + CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param)); + // Disable saving a final snapshot so the tests don't pollute the user's + // working directory with useless snapshots. + param.set_snapshot_after_train(false); + // Set the solver_mode according to current Caffe::mode. + switch (Caffe::mode()) { + case Caffe::CPU: + param.set_solver_mode(SolverParameter_SolverMode_CPU); + break; + case Caffe::GPU: + param.set_solver_mode(SolverParameter_SolverMode_GPU); + break; + default: + LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode(); + } + InitSolver(param); + delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD) ?
+ param.delta() : 0; + } + + void RunLeastSquaresSolver(const Dtype learning_rate, + const Dtype weight_decay, const Dtype momentum, const int num_iters) { + ostringstream proto; + proto << + "max_iter: " << num_iters << " " + "base_lr: " << learning_rate << " " + "lr_policy: 'fixed' " + "net_param { " + " name: 'TestNetwork' " + " layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: " << num_ << " " + " channels: " << channels_ << " " + " height: " << height_ << " " + " width: " << width_ << " " + " channels: 1 " + " height: 1 " + " width: 1 " + " data_filler { " + " type: 'gaussian' " + " std: 1.0 " + " } " + " } " + " top: 'data' " + " top: 'targets' " + " } " + " layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 1.0 " + " } " + " bias_filler { " + " type: 'gaussian' " + " std: 1.0 " + " } " + " } " + " bottom: 'data' " + " top: 'innerprod' " + " } " + " layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod' " + " bottom: 'targets' " + " } " + "} "; + if (weight_decay != 0) { + proto << "weight_decay: " << weight_decay << " "; + } + if (momentum != 0) { + proto << "momentum: " << momentum << " "; + } + Caffe::set_random_seed(this->seed_); + this->InitSolverFromProtoString(proto.str()); + this->solver_->Solve(); + } + + // Compute an update value given the current state of the train net, + // using the analytical formula for the least squares gradient. + // updated_params will store the updated weight and bias results, + // using the blobs' diffs to hold the update values themselves. + void ComputeLeastSquaresUpdate(const Dtype learning_rate, + const Dtype weight_decay, const Dtype momentum, + vector<shared_ptr<Blob<Dtype> > >* updated_params) { + const int N = num_; + const int D = channels_ * height_ * width_; + + // Run a forward pass, and manually compute the update values from the + // result.
+ Net<Dtype>& net = *this->solver_->net(); + vector<Blob<Dtype>*> empty_bottom_vec; + net.Forward(empty_bottom_vec); + ASSERT_TRUE(net.has_blob("data")); + const Blob<Dtype>& data = *net.blob_by_name("data"); + ASSERT_TRUE(net.has_blob("targets")); + const Blob<Dtype>& targets = *net.blob_by_name("targets"); + ASSERT_TRUE(net.has_layer("innerprod")); + const vector<shared_ptr<Blob<Dtype> > >& param_blobs = + net.layer_by_name("innerprod")->blobs(); + const int num_param_blobs = 2; + ASSERT_EQ(num_param_blobs, param_blobs.size()); + const Blob<Dtype>& weights = *param_blobs[0]; + const Blob<Dtype>& bias = *param_blobs[1]; + ASSERT_EQ(D * N, data.count()); + ASSERT_EQ(N, targets.count()); + ASSERT_EQ(D, weights.count()); + ASSERT_EQ(1, bias.count()); + + updated_params->clear(); + updated_params->resize(num_param_blobs); + for (int i = 0; i < num_param_blobs; ++i) { + (*updated_params)[i].reset(new Blob<Dtype>()); + } + Blob<Dtype>& updated_weights = *(*updated_params)[0]; + updated_weights.ReshapeLike(weights); + Blob<Dtype>& updated_bias = *(*updated_params)[1]; + updated_bias.ReshapeLike(bias); + + for (int i = 0; i <= D; ++i) { + // Compute the derivative with respect to the ith weight (i.e., the ith + // element of the gradient). + Dtype grad = 0; + for (int j = 0; j <= D; ++j) { + // Compute element (i, j) of X^T * X. + Dtype element = 0; + for (int k = 0; k < N; ++k) { + // (i, k) in X^T (== (k, i) in X) times (k, j) in X. + const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; + const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j]; + element += element_i * element_j; + } + if (j == D) { + grad += element * bias.cpu_data()[0]; + } else { + grad += element * weights.cpu_data()[j]; + } + } + for (int k = 0; k < N; ++k) { + const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; + grad -= element_i * targets.cpu_data()[k]; + } + // Scale the gradient over the N samples. + grad /= N; + // Add the weight decay to the gradient. + grad += weight_decay * + ((i == D) ?
bias.cpu_data()[0] : weights.cpu_data()[i]); + // Finally, compute update. + const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history(); + ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias + Dtype update_value = learning_rate * grad; + const Dtype history_value = (i == D) ? + history[1]->cpu_data()[0] : history[0]->cpu_data()[i]; + const Dtype temp = momentum * history_value; + switch (solver_type()) { + case SolverParameter_SolverType_SGD: + update_value += temp; + break; + case SolverParameter_SolverType_NESTEROV: + update_value += temp; + // step back then over-step + update_value = (1 + momentum) * update_value - temp; + break; + case SolverParameter_SolverType_ADAGRAD: + update_value /= std::sqrt(history_value + grad * grad) + delta_; + break; + default: + LOG(FATAL) << "Unknown solver type: " << solver_type(); + } + if (i == D) { + updated_bias.mutable_cpu_diff()[0] = update_value; + updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value; + } else { + updated_weights.mutable_cpu_diff()[i] = update_value; + updated_weights.mutable_cpu_data()[i] = + weights.cpu_data()[i] - update_value; + } + } + } + + void CheckLeastSquaresUpdate( + const vector<shared_ptr<Blob<Dtype> > >& updated_params) { + const int D = channels_ * height_ * width_; + + const Blob<Dtype>& updated_weights = *updated_params[0]; + const Blob<Dtype>& updated_bias = *updated_params[1]; + + Net<Dtype>& net = *this->solver_->net(); + ASSERT_TRUE(net.has_layer("innerprod")); + const vector<shared_ptr<Blob<Dtype> > >& param_blobs = + net.layer_by_name("innerprod")->blobs(); + ASSERT_EQ(2, param_blobs.size()); + const Blob<Dtype>& solver_updated_weights = *param_blobs[0]; + ASSERT_EQ(D, solver_updated_weights.count()); + const double kPrecision = 1e-2; + const double kMinPrecision = 1e-7; + for (int i = 0; i < D; ++i) { + const Dtype expected_updated_weight = updated_weights.cpu_data()[i]; + const Dtype solver_updated_weight = solver_updated_weights.cpu_data()[i]; + const Dtype error_margin = std::max(kMinPrecision, kPrecision *
std::min(fabs(expected_updated_weight), fabs(solver_updated_weight))); + EXPECT_NEAR(expected_updated_weight, solver_updated_weight, error_margin); + } + const Blob<Dtype>& solver_updated_bias_blob = *param_blobs[1]; + ASSERT_EQ(1, solver_updated_bias_blob.count()); + const Dtype expected_updated_bias = updated_bias.cpu_data()[0]; + const Dtype solver_updated_bias = solver_updated_bias_blob.cpu_data()[0]; + const Dtype error_margin = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_updated_bias), fabs(solver_updated_bias))); + EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin); + + // Check the solver's history -- should contain the previous update value. + if (solver_type() == SolverParameter_SolverType_SGD) { + const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history(); + ASSERT_EQ(2, history.size()); + for (int i = 0; i < D; ++i) { + const Dtype expected_history = updated_weights.cpu_diff()[i]; + const Dtype solver_history = history[0]->cpu_data()[i]; + const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_history), fabs(solver_history))); + EXPECT_NEAR(expected_history, solver_history, error_margin_hist); + } + const Dtype expected_history = updated_bias.cpu_diff()[0]; + const Dtype solver_history = history[1]->cpu_data()[0]; + const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_history), fabs(solver_history))); + EXPECT_NEAR(expected_history, solver_history, error_margin_hist); + } + } + + // Test that the correct update is computed for a regularized least squares + // problem: + // + // E = (1/(2n)) || X w - y ||^2 + (lambda / 2) || w ||^2 + // \nabla_w E = (1/n) (X^T X w - X^T y) + lambda * w + // + // X \in R^{n x (d+1)} (each example is a row, (d+1)th element is always 1) + // w \in R^{(d+1) x 1} ((d+1)th element is the bias) + // y \in R^{n x 1} + // lambda is weight_decay + // + // TestLeastSquaresUpdate works "inductively", assuming that the solver + //
correctly updates the net K (= iter_to_check) times, then given the history + // from the Kth update, we compute the (K+1)th update and check that it + // matches the solver's (K+1)th update. + void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0, + const Dtype weight_decay = 0.0, const Dtype momentum = 0.0, + const int iter_to_check = 0) { + // Initialize the solver and run K (= iter_to_check) solver iterations. + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check); + + // Compute the (K+1)th update using the analytic least squares gradient. + vector<shared_ptr<Blob<Dtype> > > updated_params; + ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, + &updated_params); + + // Reinitialize the solver and run K+1 solver iterations. + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, + iter_to_check + 1); + + // Check that the solver's solution matches ours. + CheckLeastSquaresUpdate(updated_params); + } +}; + + +template <typename TypeParam> +class SGDSolverTest : public GradientBasedSolverTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + virtual void InitSolver(const SolverParameter& param) { + this->solver_.reset(new SGDSolver<Dtype>(param)); + } + + virtual SolverParameter_SolverType solver_type() { + return SolverParameter_SolverType_SGD; + } +}; + +TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices); + +TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdate) { + this->TestLeastSquaresUpdate(); +} + +TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateLROneTenth) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.1; + this->TestLeastSquaresUpdate(kLearningRate); +} + +TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithWeightDecay) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.5; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay); +} + +TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype
kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 1; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 4; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + + +template <typename TypeParam> +class AdaGradSolverTest : public GradientBasedSolverTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + virtual void InitSolver(const SolverParameter& param) { + this->solver_.reset(new AdaGradSolver<Dtype>(param)); + } + virtual SolverParameter_SolverType solver_type() { + return SolverParameter_SolverType_ADAGRAD; + } +}; + +TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices); + +TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdate) { + this->TestLeastSquaresUpdate(); +} + +TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateLROneTenth) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.1; + this->TestLeastSquaresUpdate(kLearningRate); +} + +TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithWeightDecay) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.5; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay); +} + +TYPED_TEST(AdaGradSolverTest,
TestAdaGradLeastSquaresUpdateWithEverything) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.0; + const int kNumIters = 4; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + + +template <typename TypeParam> +class NesterovSolverTest : public GradientBasedSolverTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + virtual void InitSolver(const SolverParameter& param) { + this->solver_.reset(new NesterovSolver<Dtype>(param)); + } + virtual SolverParameter_SolverType solver_type() { + return SolverParameter_SolverType_NESTEROV; + } +}; + +TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices); + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdate) { + this->TestLeastSquaresUpdate(); +} + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateLROneTenth) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.1; + this->TestLeastSquaresUpdate(kLearningRate); +} + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithWeightDecay) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.5; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay); +} + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 1; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 4; + for (int i = 0; i <= kNumIters; ++i) { +
this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithEverything) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_hdf5_output_layer.cpp b/caffe-crfrnn/src/caffe/test/test_hdf5_output_layer.cpp new file mode 100644 index 00000000..2e8f0969 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_hdf5_output_layer.cpp @@ -0,0 +1,120 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/io.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template <typename TypeParam> +class HDF5OutputLayerTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + HDF5OutputLayerTest() + : input_file_name_( + CMAKE_SOURCE_DIR "caffe/test/test_data/sample_data.h5"), + blob_data_(new Blob<Dtype>()), + blob_label_(new Blob<Dtype>()), + num_(5), + channels_(8), + height_(5), + width_(5) { + MakeTempFilename(&output_file_name_); + } + + virtual ~HDF5OutputLayerTest() { + delete blob_data_; + delete blob_label_; + } + + void CheckBlobEqual(const Blob<Dtype>& b1, const Blob<Dtype>& b2); + + string output_file_name_; + string input_file_name_; + Blob<Dtype>* const blob_data_; + Blob<Dtype>* const blob_label_; + vector<Blob<Dtype>*> blob_bottom_vec_; + vector<Blob<Dtype>*> blob_top_vec_; + int num_; + int channels_; + int height_; + int width_; +}; + +template <typename TypeParam> +void HDF5OutputLayerTest<TypeParam>::CheckBlobEqual(const Blob<Dtype>& b1, + const Blob<Dtype>& b2) { + EXPECT_EQ(b1.num(), b2.num()); + EXPECT_EQ(b1.channels(), b2.channels()); + EXPECT_EQ(b1.height(), b2.height()); + EXPECT_EQ(b1.width(),
b2.width()); + for (int n = 0; n < b1.num(); ++n) { + for (int c = 0; c < b1.channels(); ++c) { + for (int h = 0; h < b1.height(); ++h) { + for (int w = 0; w < b1.width(); ++w) { + EXPECT_EQ(b1.data_at(n, c, h, w), b2.data_at(n, c, h, w)); + } + } + } + } +} + +TYPED_TEST_CASE(HDF5OutputLayerTest, TestDtypesAndDevices); + +TYPED_TEST(HDF5OutputLayerTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + LOG(INFO) << "Loading HDF5 file " << this->input_file_name_; + hid_t file_id = H5Fopen(this->input_file_name_.c_str(), H5F_ACC_RDONLY, + H5P_DEFAULT); + ASSERT_GE(file_id, 0) << "Failed to open HDF5 file " << + this->input_file_name_; + hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4, + this->blob_data_); + hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4, + this->blob_label_); + herr_t status = H5Fclose(file_id); + EXPECT_GE(status, 0) << "Failed to close HDF5 file " << + this->input_file_name_; + this->blob_bottom_vec_.push_back(this->blob_data_); + this->blob_bottom_vec_.push_back(this->blob_label_); + + LayerParameter param; + param.mutable_hdf5_output_param()->set_file_name(this->output_file_name_); + // This code block ensures that the layer is destroyed and + // the output hdf5 file is closed.
+ { + HDF5OutputLayer<Dtype> layer(param); + EXPECT_EQ(layer.file_name(), this->output_file_name_); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + } + file_id = H5Fopen(this->output_file_name_.c_str(), H5F_ACC_RDONLY, + H5P_DEFAULT); + ASSERT_GE( + file_id, 0) << "Failed to open HDF5 file " << + this->output_file_name_; + + Blob<Dtype>* blob_data = new Blob<Dtype>(); + hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4, + blob_data); + this->CheckBlobEqual(*(this->blob_data_), *blob_data); + + Blob<Dtype>* blob_label = new Blob<Dtype>(); + hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4, + blob_label); + this->CheckBlobEqual(*(this->blob_label_), *blob_label); + + status = H5Fclose(file_id); + EXPECT_GE(status, 0) << "Failed to close HDF5 file " << + this->output_file_name_; +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_hdf5data_layer.cpp b/caffe-crfrnn/src/caffe/test/test_hdf5data_layer.cpp new file mode 100644 index 00000000..8d3b3d1e --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_hdf5data_layer.cpp @@ -0,0 +1,137 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template <typename TypeParam> +class HDF5DataLayerTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + HDF5DataLayerTest() + : filename(NULL), + blob_top_data_(new Blob<Dtype>()), + blob_top_label_(new Blob<Dtype>()), + blob_top_label2_(new Blob<Dtype>()) {} + virtual void SetUp() { + blob_top_vec_.push_back(blob_top_data_); + blob_top_vec_.push_back(blob_top_label_); + blob_top_vec_.push_back(blob_top_label2_); + + // Check out generate_sample_data.py in the same directory.
+ filename = new string( + CMAKE_SOURCE_DIR "caffe/test/test_data/sample_data_list.txt" CMAKE_EXT); + LOG(INFO) << "Using sample HDF5 data file " << *filename; + } + + virtual ~HDF5DataLayerTest() { + delete blob_top_data_; + delete blob_top_label_; + delete blob_top_label2_; + delete filename; + } + + string* filename; + Blob<Dtype>* const blob_top_data_; + Blob<Dtype>* const blob_top_label_; + Blob<Dtype>* const blob_top_label2_; + vector<Blob<Dtype>*> blob_bottom_vec_; + vector<Blob<Dtype>*> blob_top_vec_; +}; + +TYPED_TEST_CASE(HDF5DataLayerTest, TestDtypesAndDevices); + +TYPED_TEST(HDF5DataLayerTest, TestRead) { + typedef typename TypeParam::Dtype Dtype; + // Create LayerParameter with the known parameters. + // The data file we are reading has 10 rows and 8 columns, + // with values from 0 to 10*8 reshaped in row-major order. + LayerParameter param; + param.add_top("data"); + param.add_top("label"); + param.add_top("label2"); + + HDF5DataParameter* hdf5_data_param = param.mutable_hdf5_data_param(); + int batch_size = 5; + hdf5_data_param->set_batch_size(batch_size); + hdf5_data_param->set_source(*(this->filename)); + int num_cols = 8; + int height = 6; + int width = 5; + + // Test that the layer setup got the correct parameters.
+ HDF5DataLayer<Dtype> layer(param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_data_->num(), batch_size); + EXPECT_EQ(this->blob_top_data_->channels(), num_cols); + EXPECT_EQ(this->blob_top_data_->height(), height); + EXPECT_EQ(this->blob_top_data_->width(), width); + + EXPECT_EQ(this->blob_top_label_->num(), batch_size); + EXPECT_EQ(this->blob_top_label_->channels(), 1); + EXPECT_EQ(this->blob_top_label_->height(), 1); + EXPECT_EQ(this->blob_top_label_->width(), 1); + + EXPECT_EQ(this->blob_top_label2_->num(), batch_size); + EXPECT_EQ(this->blob_top_label2_->channels(), 1); + EXPECT_EQ(this->blob_top_label2_->height(), 1); + EXPECT_EQ(this->blob_top_label2_->width(), 1); + + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + + // Go through the data 10 times (5 batches). + const int data_size = num_cols * height * width; + for (int iter = 0; iter < 10; ++iter) { + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + // On even iterations, we're reading the first half of the data. + // On odd iterations, we're reading the second half of the data. + // NB: label is 1-indexed + int label_offset = 1 + ((iter % 2 == 0) ? 0 : batch_size); + int label2_offset = 1 + label_offset; + int data_offset = (iter % 2 == 0) ? 0 : batch_size * data_size; + + // Every two iterations we are reading the second file, + // which has the same labels, but data is offset by total data size, + // which is 2400 (see generate_sample_data). + int file_offset = (iter % 4 < 2) ?
0 : 2400; + + for (int i = 0; i < batch_size; ++i) { + EXPECT_EQ( + label_offset + i, + this->blob_top_label_->cpu_data()[i]); + EXPECT_EQ( + label2_offset + i, + this->blob_top_label2_->cpu_data()[i]); + } + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < num_cols; ++j) { + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + int idx = ( + i * num_cols * height * width + + j * height * width + + h * width + w); + EXPECT_EQ( + file_offset + data_offset + idx, + this->blob_top_data_->cpu_data()[idx]) + << "debug: i " << i << " j " << j + << " iter " << iter; + } + } + } + } + } +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_hinge_loss_layer.cpp b/caffe-crfrnn/src/caffe/test/test_hinge_loss_layer.cpp new file mode 100644 index 00000000..b6a99022 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_hinge_loss_layer.cpp @@ -0,0 +1,76 @@ +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template <typename TypeParam> +class HingeLossLayerTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + HingeLossLayerTest() + : blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)), + blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)), + blob_top_loss_(new Blob<Dtype>()) { + // fill the values + Caffe::set_random_seed(1701); + FillerParameter filler_param; + filler_param.set_std(10); + GaussianFiller<Dtype> filler(filler_param); + filler.Fill(this->blob_bottom_data_); + blob_bottom_vec_.push_back(blob_bottom_data_); + for (int i = 0; i < blob_bottom_label_->count(); ++i) { + blob_bottom_label_->mutable_cpu_data()[i] = caffe_rng_rand() % 5; + } + blob_bottom_vec_.push_back(blob_bottom_label_); + blob_top_vec_.push_back(blob_top_loss_); + } + virtual ~HingeLossLayerTest() { + delete
blob_bottom_data_; + delete blob_bottom_label_; + delete blob_top_loss_; + } + Blob<Dtype>* const blob_bottom_data_; + Blob<Dtype>* const blob_bottom_label_; + Blob<Dtype>* const blob_top_loss_; + vector<Blob<Dtype>*> blob_bottom_vec_; + vector<Blob<Dtype>*> blob_top_vec_; +}; + +TYPED_TEST_CASE(HingeLossLayerTest, TestDtypesAndDevices); + + +TYPED_TEST(HingeLossLayerTest, TestGradientL1) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + HingeLossLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 2e-3, 1701, 1, 0.01); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +TYPED_TEST(HingeLossLayerTest, TestGradientL2) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + // Set norm to L2 + HingeLossParameter* hinge_loss_param = layer_param.mutable_hinge_loss_param(); + hinge_loss_param->set_norm(HingeLossParameter_Norm_L2); + HingeLossLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-2, 1701); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_im2col_kernel.cu b/caffe-crfrnn/src/caffe/test/test_im2col_kernel.cu new file mode 100644 index 00000000..ee684c00 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_im2col_kernel.cu @@ -0,0 +1,127 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +// Forward declare kernel functions +template <typename Dtype> +__global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int height_col, const int width_col, + Dtype* data_col); + +extern cudaDeviceProp
CAFFE_TEST_CUDA_PROP; + +template <typename Dtype> +class Im2colKernelTest : public ::testing::Test { + protected: + Im2colKernelTest() + // big so launches > 1024 threads + : blob_bottom_(new Blob<Dtype>(5, 500, 10, 10)), + blob_top_(new Blob<Dtype>()), + blob_top_cpu_(new Blob<Dtype>()) { + FillerParameter filler_param; + GaussianFiller<Dtype> filler(filler_param); + filler.Fill(this->blob_bottom_); + + height_ = blob_bottom_->height(); + width_ = blob_bottom_->width(); + channels_ = blob_bottom_->channels(); + pad_ = 0; + stride_ = 2; + kernel_size_ = 3; + height_col_ = (height_ + 2 * pad_ - kernel_size_) / stride_ + 1; + width_col_ = (width_ + 2 * pad_ - kernel_size_) / stride_ + 1; + } + + virtual ~Im2colKernelTest() { + delete blob_bottom_; + delete blob_top_; + delete blob_top_cpu_; + } + + Blob<Dtype>* const blob_bottom_; + Blob<Dtype>* const blob_top_; + Blob<Dtype>* const blob_top_cpu_; + int height_; + int width_; + int channels_; + int pad_; + int stride_; + int kernel_size_; + int height_col_; + int width_col_; +}; + +TYPED_TEST_CASE(Im2colKernelTest, TestDtypes); + +TYPED_TEST(Im2colKernelTest, TestGPU) { + Caffe::set_mode(Caffe::GPU); + + // Reshape the blobs to correct size for im2col output + this->blob_top_->Reshape(this->blob_bottom_->num(), + this->channels_ * this->kernel_size_ * this->kernel_size_, + this->height_col_, + this->width_col_); + + this->blob_top_cpu_->Reshape(this->blob_bottom_->num(), + this->channels_ * this->kernel_size_ * this->kernel_size_, + this->height_col_, + this->width_col_); + + const TypeParam* bottom_data = this->blob_bottom_->gpu_data(); + TypeParam* top_data = this->blob_top_->mutable_gpu_data(); + TypeParam* cpu_data = this->blob_top_cpu_->mutable_cpu_data(); + + // CPU Version + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + im2col_cpu(this->blob_bottom_->cpu_data() + this->blob_bottom_->offset(n), + this->channels_, this->height_, this->width_, + this->kernel_size_, this->kernel_size_, this->pad_, this->pad_, + this->stride_, this->stride_, + cpu_data +
this->blob_top_cpu_->offset(n)); + } + + // GPU version + int num_kernels = this->channels_ * this->height_col_ * this->width_col_; + int default_grid_dim = CAFFE_GET_BLOCKS(num_kernels); + + // Launch with different grid sizes + for (int grid_div = 2; grid_div <= 8; grid_div++) { + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + int grid_dim = default_grid_dim/grid_div; + // NOLINT_NEXT_LINE(whitespace/operators) + im2col_gpu_kernel<TypeParam><<<grid_dim, CAFFE_CUDA_NUM_THREADS>>>( + num_kernels, bottom_data + this->blob_bottom_->offset(n), + this->height_, this->width_, this->kernel_size_, this->kernel_size_, + this->pad_, this->pad_, this->stride_, this->stride_, + this->height_col_, this->width_col_, + top_data + this->blob_top_->offset(n)); + CUDA_POST_KERNEL_CHECK; + } + + // Compare results against CPU version + for (int i = 0; i < this->blob_top_->count(); ++i) { + TypeParam cpuval = cpu_data[i]; + TypeParam gpuval = this->blob_top_->cpu_data()[i]; + EXPECT_EQ(cpuval, gpuval); + if (cpuval != gpuval) { + break; + } + } + } +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_im2col_layer.cpp b/caffe-crfrnn/src/caffe/test/test_im2col_layer.cpp new file mode 100644 index 00000000..f50abe10 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_im2col_layer.cpp @@ -0,0 +1,118 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template <typename TypeParam> +class Im2colLayerTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + protected: + Im2colLayerTest() + : blob_bottom_(new Blob<Dtype>(2, 3, 6, 5)), + blob_top_(new Blob<Dtype>()) { + // fill the values + FillerParameter filler_param; + GaussianFiller<Dtype> filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } +
virtual ~Im2colLayerTest() { delete blob_bottom_; delete blob_top_; } + Blob<Dtype>* const blob_bottom_; + Blob<Dtype>* const blob_top_; + vector<Blob<Dtype>*> blob_bottom_vec_; + vector<Blob<Dtype>*> blob_top_vec_; +}; + +TYPED_TEST_CASE(Im2colLayerTest, TestDtypesAndDevices); + +TYPED_TEST(Im2colLayerTest, TestSetup) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + Im2colLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 2); + EXPECT_EQ(this->blob_top_->channels(), 27); + EXPECT_EQ(this->blob_top_->height(), 2); + EXPECT_EQ(this->blob_top_->width(), 2); +} + +TYPED_TEST(Im2colLayerTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + Im2colLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // We are lazy and will only check the top left block + for (int c = 0; c < 27; ++c) { + EXPECT_EQ(this->blob_bottom_->data_at(0, (c / 9), (c / 3) % 3, c % 3), + this->blob_top_->data_at(0, c, 0, 0)); + } +} + +TYPED_TEST(Im2colLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_size(3); + convolution_param->set_stride(2); + Im2colLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-2); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + + +TYPED_TEST(Im2colLayerTest, TestRect) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter
layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_h(5); + convolution_param->set_kernel_w(3); + convolution_param->set_stride(2); + Im2colLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // We are lazy and will only check the top left block + for (int c = 0; c < 45; ++c) { + EXPECT_EQ(this->blob_top_->data_at(0, c, 0, 0), + this->blob_bottom_->data_at(0, (c / 15), (c / 3) % 5, c % 3)); + } +} + + +TYPED_TEST(Im2colLayerTest, TestRectGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ConvolutionParameter* convolution_param = + layer_param.mutable_convolution_param(); + convolution_param->set_kernel_h(5); + convolution_param->set_kernel_w(3); + convolution_param->set_stride(2); + Im2colLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-2); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_image_data_layer.cpp b/caffe-crfrnn/src/caffe/test/test_image_data_layer.cpp new file mode 100644 index 00000000..77523ef8 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_image_data_layer.cpp @@ -0,0 +1,144 @@ +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/io.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template <typename TypeParam> +class ImageDataLayerTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + ImageDataLayerTest() + : seed_(1701), + blob_top_data_(new Blob<Dtype>()), + blob_top_label_(new Blob<Dtype>()) {} + virtual void SetUp() { + MakeTempFilename(&filename_); + blob_top_vec_.push_back(blob_top_data_); + 
blob_top_vec_.push_back(blob_top_label_); + Caffe::set_random_seed(seed_); + // Create a Vector of files with labels + std::ofstream outfile(filename_.c_str(), std::ofstream::out); + LOG(INFO) << "Using temporary file " << filename_; + for (int i = 0; i < 5; ++i) { + outfile << EXAMPLES_SOURCE_DIR "images/cat.jpg " << i; + } + outfile.close(); + } + + virtual ~ImageDataLayerTest() { + delete blob_top_data_; + delete blob_top_label_; + } + + int seed_; + string filename_; + Blob<Dtype>* const blob_top_data_; + Blob<Dtype>* const blob_top_label_; + vector<Blob<Dtype>*> blob_bottom_vec_; + vector<Blob<Dtype>*> blob_top_vec_; +}; + +TYPED_TEST_CASE(ImageDataLayerTest, TestDtypesAndDevices); + +TYPED_TEST(ImageDataLayerTest, TestRead) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter param; + ImageDataParameter* image_data_param = param.mutable_image_data_param(); + image_data_param->set_batch_size(5); + image_data_param->set_source(this->filename_.c_str()); + image_data_param->set_shuffle(false); + ImageDataLayer<Dtype> layer(param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_data_->num(), 5); + EXPECT_EQ(this->blob_top_data_->channels(), 3); + EXPECT_EQ(this->blob_top_data_->height(), 360); + EXPECT_EQ(this->blob_top_data_->width(), 480); + EXPECT_EQ(this->blob_top_label_->num(), 5); + EXPECT_EQ(this->blob_top_label_->channels(), 1); + EXPECT_EQ(this->blob_top_label_->height(), 1); + EXPECT_EQ(this->blob_top_label_->width(), 1); + // Go through the data twice + for (int iter = 0; iter < 2; ++iter) { + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(i, this->blob_top_label_->cpu_data()[i]); + } + } +} + +TYPED_TEST(ImageDataLayerTest, TestResize) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter param; + ImageDataParameter* image_data_param = param.mutable_image_data_param(); + image_data_param->set_batch_size(5); + image_data_param->set_source(this->filename_.c_str()); + 
image_data_param->set_new_height(256); + image_data_param->set_new_width(256); + image_data_param->set_shuffle(false); + ImageDataLayer<Dtype> layer(param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_data_->num(), 5); + EXPECT_EQ(this->blob_top_data_->channels(), 3); + EXPECT_EQ(this->blob_top_data_->height(), 256); + EXPECT_EQ(this->blob_top_data_->width(), 256); + EXPECT_EQ(this->blob_top_label_->num(), 5); + EXPECT_EQ(this->blob_top_label_->channels(), 1); + EXPECT_EQ(this->blob_top_label_->height(), 1); + EXPECT_EQ(this->blob_top_label_->width(), 1); + // Go through the data twice + for (int iter = 0; iter < 2; ++iter) { + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(i, this->blob_top_label_->cpu_data()[i]); + } + } +} + +TYPED_TEST(ImageDataLayerTest, TestShuffle) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter param; + ImageDataParameter* image_data_param = param.mutable_image_data_param(); + image_data_param->set_batch_size(5); + image_data_param->set_source(this->filename_.c_str()); + image_data_param->set_shuffle(true); + ImageDataLayer<Dtype> layer(param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_data_->num(), 5); + EXPECT_EQ(this->blob_top_data_->channels(), 3); + EXPECT_EQ(this->blob_top_data_->height(), 360); + EXPECT_EQ(this->blob_top_data_->width(), 480); + EXPECT_EQ(this->blob_top_label_->num(), 5); + EXPECT_EQ(this->blob_top_label_->channels(), 1); + EXPECT_EQ(this->blob_top_label_->height(), 1); + EXPECT_EQ(this->blob_top_label_->width(), 1); + // Go through the data twice + for (int iter = 0; iter < 2; ++iter) { + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + map<Dtype, int> values_to_indices; + int num_in_order = 0; + for (int i = 0; i < 5; ++i) { + Dtype value = this->blob_top_label_->cpu_data()[i]; + // Check that the value has not been seen already (no duplicates). 
+ EXPECT_EQ(values_to_indices.find(value), values_to_indices.end()); + values_to_indices[value] = i; + num_in_order += (value == Dtype(i)); + } + EXPECT_EQ(5, values_to_indices.size()); + EXPECT_GT(5, num_in_order); + } +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_infogain_loss_layer.cpp b/caffe-crfrnn/src/caffe/test/test_infogain_loss_layer.cpp new file mode 100644 index 00000000..7ec2f807 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_infogain_loss_layer.cpp @@ -0,0 +1,70 @@ +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/loss_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template <typename TypeParam> +class InfogainLossLayerTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + InfogainLossLayerTest() + : blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)), + blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)), + blob_bottom_infogain_(new Blob<Dtype>(1, 1, 5, 5)), + blob_top_loss_(new Blob<Dtype>()) { + Caffe::set_random_seed(1701); + FillerParameter filler_param; + PositiveUnitballFiller<Dtype> filler(filler_param); + filler.Fill(this->blob_bottom_data_); + blob_bottom_vec_.push_back(blob_bottom_data_); + for (int i = 0; i < blob_bottom_label_->count(); ++i) { + blob_bottom_label_->mutable_cpu_data()[i] = caffe_rng_rand() % 5; + } + blob_bottom_vec_.push_back(blob_bottom_label_); + filler_param.set_min(0.1); + filler_param.set_max(2.0); + UniformFiller<Dtype> infogain_filler(filler_param); + infogain_filler.Fill(this->blob_bottom_infogain_); + blob_bottom_vec_.push_back(blob_bottom_infogain_); + blob_top_vec_.push_back(blob_top_loss_); + } + virtual ~InfogainLossLayerTest() { + delete blob_bottom_data_; + delete blob_bottom_label_; + delete blob_bottom_infogain_; + delete blob_top_loss_; + } + Blob<Dtype>* const blob_bottom_data_; + Blob<Dtype>* const 
blob_bottom_label_; + Blob<Dtype>* const blob_bottom_infogain_; + Blob<Dtype>* const blob_top_loss_; + vector<Blob<Dtype>*> blob_bottom_vec_; + vector<Blob<Dtype>*> blob_top_vec_; +}; + +TYPED_TEST_CASE(InfogainLossLayerTest, TestDtypesAndDevices); + + +TYPED_TEST(InfogainLossLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + InfogainLossLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-4, 2e-2, 1701, 1, 0.01); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_inner_product_layer.cpp b/caffe-crfrnn/src/caffe/test/test_inner_product_layer.cpp new file mode 100644 index 00000000..c03df173 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_inner_product_layer.cpp @@ -0,0 +1,113 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +#ifndef CPU_ONLY +extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; +#endif + +template <typename TypeParam> +class InnerProductLayerTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + protected: + InnerProductLayerTest() + : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)), + blob_top_(new Blob<Dtype>()) { + // fill the values + FillerParameter filler_param; + UniformFiller<Dtype> filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~InnerProductLayerTest() { delete blob_bottom_; delete blob_top_; } + Blob<Dtype>* const blob_bottom_; + Blob<Dtype>* const blob_top_; + vector<Blob<Dtype>*> blob_bottom_vec_; + vector<Blob<Dtype>*> blob_top_vec_; +}; + +TYPED_TEST_CASE(InnerProductLayerTest, TestDtypesAndDevices); + +TYPED_TEST(InnerProductLayerTest, TestSetUp) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + 
InnerProductParameter* inner_product_param = + layer_param.mutable_inner_product_param(); + inner_product_param->set_num_output(10); + shared_ptr<InnerProductLayer<Dtype> > layer( + new InnerProductLayer<Dtype>(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 2); + EXPECT_EQ(this->blob_top_->height(), 1); + EXPECT_EQ(this->blob_top_->width(), 1); + EXPECT_EQ(this->blob_top_->channels(), 10); +} + +TYPED_TEST(InnerProductLayerTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + bool IS_VALID_CUDA = false; +#ifndef CPU_ONLY + IS_VALID_CUDA = CAFFE_TEST_CUDA_PROP.major >= 2; +#endif + if (Caffe::mode() == Caffe::CPU || + sizeof(Dtype) == 4 || IS_VALID_CUDA) { + LayerParameter layer_param; + InnerProductParameter* inner_product_param = + layer_param.mutable_inner_product_param(); + inner_product_param->set_num_output(10); + inner_product_param->mutable_weight_filler()->set_type("uniform"); + inner_product_param->mutable_bias_filler()->set_type("uniform"); + inner_product_param->mutable_bias_filler()->set_min(1); + inner_product_param->mutable_bias_filler()->set_max(2); + shared_ptr<InnerProductLayer<Dtype> > layer( + new InnerProductLayer<Dtype>(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->blob_top_->cpu_data(); + const int count = this->blob_top_->count(); + for (int i = 0; i < count; ++i) { + EXPECT_GE(data[i], 1.); + } + } else { + LOG(ERROR) << "Skipping test due to old architecture."; + } +} + +TYPED_TEST(InnerProductLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + bool IS_VALID_CUDA = false; +#ifndef CPU_ONLY + IS_VALID_CUDA = CAFFE_TEST_CUDA_PROP.major >= 2; +#endif + if (Caffe::mode() == Caffe::CPU || + sizeof(Dtype) == 4 || IS_VALID_CUDA) { + LayerParameter layer_param; + InnerProductParameter* inner_product_param = + layer_param.mutable_inner_product_param(); + inner_product_param->set_num_output(10); + 
inner_product_param->mutable_weight_filler()->set_type("gaussian"); + inner_product_param->mutable_bias_filler()->set_type("gaussian"); + inner_product_param->mutable_bias_filler()->set_min(1); + inner_product_param->mutable_bias_filler()->set_max(2); + InnerProductLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); + } else { + LOG(ERROR) << "Skipping test due to old architecture."; + } +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_internal_thread.cpp b/caffe-crfrnn/src/caffe/test/test_internal_thread.cpp new file mode 100644 index 00000000..31882b6d --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_internal_thread.cpp @@ -0,0 +1,23 @@ +#include "glog/logging.h" +#include "gtest/gtest.h" + +#include "caffe/internal_thread.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + + +class InternalThreadTest : public ::testing::Test {}; + +TEST_F(InternalThreadTest, TestStartAndExit) { + InternalThread thread; + EXPECT_FALSE(thread.is_started()); + EXPECT_TRUE(thread.StartInternalThread()); + EXPECT_TRUE(thread.is_started()); + EXPECT_TRUE(thread.WaitForInternalThreadToExit()); + EXPECT_FALSE(thread.is_started()); +} + +} // namespace caffe + diff --git a/caffe-crfrnn/src/caffe/test/test_io.cpp b/caffe-crfrnn/src/caffe/test/test_io.cpp new file mode 100644 index 00000000..4d941fa8 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_io.cpp @@ -0,0 +1,359 @@ +#include +#include +#include +#include + +#include + +#include "gtest/gtest.h" + +#include "caffe/common.hpp" +#include "caffe/util/io.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +class IOTest : public ::testing::Test {}; + +bool ReadImageToDatumReference(const string& filename, const int label, + const int height, const int width, const bool is_color, Datum* datum) { + cv::Mat cv_img; + int cv_read_flag = (is_color ? 
CV_LOAD_IMAGE_COLOR : + CV_LOAD_IMAGE_GRAYSCALE); + + cv::Mat cv_img_origin = cv::imread(filename, cv_read_flag); + if (!cv_img_origin.data) { + LOG(ERROR) << "Could not open or find file " << filename; + return false; + } + if (height > 0 && width > 0) { + cv::resize(cv_img_origin, cv_img, cv::Size(width, height)); + } else { + cv_img = cv_img_origin; + } + + int num_channels = (is_color ? 3 : 1); + datum->set_channels(num_channels); + datum->set_height(cv_img.rows); + datum->set_width(cv_img.cols); + datum->set_label(label); + datum->clear_data(); + datum->clear_float_data(); + string* datum_string = datum->mutable_data(); + if (is_color) { + for (int c = 0; c < num_channels; ++c) { + for (int h = 0; h < cv_img.rows; ++h) { + for (int w = 0; w < cv_img.cols; ++w) { + datum_string->push_back( + static_cast<char>(cv_img.at<cv::Vec3b>(h, w)[c])); + } + } + } + } else { // Faster than repeatedly testing is_color for each pixel w/i loop + for (int h = 0; h < cv_img.rows; ++h) { + for (int w = 0; w < cv_img.cols; ++w) { + datum_string->push_back( + static_cast<char>(cv_img.at<uchar>(h, w))); + } + } + } + return true; +} + +TEST_F(IOTest, TestReadImageToDatum) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum; + ReadImageToDatum(filename, 0, &datum); + EXPECT_EQ(datum.channels(), 3); + EXPECT_EQ(datum.height(), 360); + EXPECT_EQ(datum.width(), 480); +} + +TEST_F(IOTest, TestReadImageToDatumReference) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum, datum_ref; + ReadImageToDatum(filename, 0, 0, 0, true, &datum); + ReadImageToDatumReference(filename, 0, 0, 0, true, &datum_ref); + EXPECT_EQ(datum.channels(), datum_ref.channels()); + EXPECT_EQ(datum.height(), datum_ref.height()); + EXPECT_EQ(datum.width(), datum_ref.width()); + EXPECT_EQ(datum.data().size(), datum_ref.data().size()); + + const string& data = datum.data(); + const string& data_ref = datum_ref.data(); + + for (int i = 0; i < datum.data().size(); ++i) { + EXPECT_TRUE(data[i] == data_ref[i]); + 
} +} + + +TEST_F(IOTest, TestReadImageToDatumReferenceResized) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum, datum_ref; + ReadImageToDatum(filename, 0, 100, 200, true, &datum); + ReadImageToDatumReference(filename, 0, 100, 200, true, &datum_ref); + EXPECT_EQ(datum.channels(), datum_ref.channels()); + EXPECT_EQ(datum.height(), datum_ref.height()); + EXPECT_EQ(datum.width(), datum_ref.width()); + EXPECT_EQ(datum.data().size(), datum_ref.data().size()); + + const string& data = datum.data(); + const string& data_ref = datum_ref.data(); + + for (int i = 0; i < datum.data().size(); ++i) { + EXPECT_TRUE(data[i] == data_ref[i]); + } +} + +TEST_F(IOTest, TestReadImageToDatumContent) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum; + ReadImageToDatum(filename, 0, &datum); + cv::Mat cv_img = ReadImageToCVMat(filename); + EXPECT_EQ(datum.channels(), cv_img.channels()); + EXPECT_EQ(datum.height(), cv_img.rows); + EXPECT_EQ(datum.width(), cv_img.cols); + + const string& data = datum.data(); + int index = 0; + for (int c = 0; c < datum.channels(); ++c) { + for (int h = 0; h < datum.height(); ++h) { + for (int w = 0; w < datum.width(); ++w) { + EXPECT_TRUE(data[index++] == + static_cast<char>(cv_img.at<cv::Vec3b>(h, w)[c])); + } + } + } +} + +TEST_F(IOTest, TestReadImageToDatumContentGray) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum; + const bool is_color = false; + ReadImageToDatum(filename, 0, is_color, &datum); + cv::Mat cv_img = ReadImageToCVMat(filename, is_color); + EXPECT_EQ(datum.channels(), cv_img.channels()); + EXPECT_EQ(datum.height(), cv_img.rows); + EXPECT_EQ(datum.width(), cv_img.cols); + + const string& data = datum.data(); + int index = 0; + for (int h = 0; h < datum.height(); ++h) { + for (int w = 0; w < datum.width(); ++w) { + EXPECT_TRUE(data[index++] == static_cast<char>(cv_img.at<uchar>(h, w))); + } + } +} + +TEST_F(IOTest, TestReadImageToDatumResized) { + string filename = EXAMPLES_SOURCE_DIR 
"images/cat.jpg"; + Datum datum; + ReadImageToDatum(filename, 0, 100, 200, &datum); + EXPECT_EQ(datum.channels(), 3); + EXPECT_EQ(datum.height(), 100); + EXPECT_EQ(datum.width(), 200); +} + + +TEST_F(IOTest, TestReadImageToDatumResizedSquare) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum; + ReadImageToDatum(filename, 0, 256, 256, &datum); + EXPECT_EQ(datum.channels(), 3); + EXPECT_EQ(datum.height(), 256); + EXPECT_EQ(datum.width(), 256); +} + +TEST_F(IOTest, TestReadImageToDatumGray) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum; + const bool is_color = false; + ReadImageToDatum(filename, 0, is_color, &datum); + EXPECT_EQ(datum.channels(), 1); + EXPECT_EQ(datum.height(), 360); + EXPECT_EQ(datum.width(), 480); +} + +TEST_F(IOTest, TestReadImageToDatumResizedGray) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum; + const bool is_color = false; + ReadImageToDatum(filename, 0, 256, 256, is_color, &datum); + EXPECT_EQ(datum.channels(), 1); + EXPECT_EQ(datum.height(), 256); + EXPECT_EQ(datum.width(), 256); +} + +TEST_F(IOTest, TestReadImageToCVMat) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + cv::Mat cv_img = ReadImageToCVMat(filename); + EXPECT_EQ(cv_img.channels(), 3); + EXPECT_EQ(cv_img.rows, 360); + EXPECT_EQ(cv_img.cols, 480); +} + +TEST_F(IOTest, TestReadImageToCVMatResized) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + cv::Mat cv_img = ReadImageToCVMat(filename, 100, 200); + EXPECT_EQ(cv_img.channels(), 3); + EXPECT_EQ(cv_img.rows, 100); + EXPECT_EQ(cv_img.cols, 200); +} + +TEST_F(IOTest, TestReadImageToCVMatResizedSquare) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + cv::Mat cv_img = ReadImageToCVMat(filename, 256, 256); + EXPECT_EQ(cv_img.channels(), 3); + EXPECT_EQ(cv_img.rows, 256); + EXPECT_EQ(cv_img.cols, 256); +} + +TEST_F(IOTest, TestReadImageToCVMatGray) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + const bool 
is_color = false; + cv::Mat cv_img = ReadImageToCVMat(filename, is_color); + EXPECT_EQ(cv_img.channels(), 1); + EXPECT_EQ(cv_img.rows, 360); + EXPECT_EQ(cv_img.cols, 480); +} + +TEST_F(IOTest, TestReadImageToCVMatResizedGray) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + const bool is_color = false; + cv::Mat cv_img = ReadImageToCVMat(filename, 256, 256, is_color); + EXPECT_EQ(cv_img.channels(), 1); + EXPECT_EQ(cv_img.rows, 256); + EXPECT_EQ(cv_img.cols, 256); +} + +TEST_F(IOTest, TestCVMatToDatum) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + cv::Mat cv_img = ReadImageToCVMat(filename); + Datum datum; + CVMatToDatum(cv_img, &datum); + EXPECT_EQ(datum.channels(), 3); + EXPECT_EQ(datum.height(), 360); + EXPECT_EQ(datum.width(), 480); +} + +TEST_F(IOTest, TestCVMatToDatumContent) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + cv::Mat cv_img = ReadImageToCVMat(filename); + Datum datum; + CVMatToDatum(cv_img, &datum); + Datum datum_ref; + ReadImageToDatum(filename, 0, &datum_ref); + EXPECT_EQ(datum.channels(), datum_ref.channels()); + EXPECT_EQ(datum.height(), datum_ref.height()); + EXPECT_EQ(datum.width(), datum_ref.width()); + EXPECT_EQ(datum.data().size(), datum_ref.data().size()); + + const string& data = datum.data(); + const string& data_ref = datum_ref.data(); + for (int i = 0; i < datum.data().size(); ++i) { + EXPECT_TRUE(data[i] == data_ref[i]); + } +} + +TEST_F(IOTest, TestCVMatToDatumReference) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + cv::Mat cv_img = ReadImageToCVMat(filename); + Datum datum; + CVMatToDatum(cv_img, &datum); + Datum datum_ref; + ReadImageToDatumReference(filename, 0, 0, 0, true, &datum_ref); + EXPECT_EQ(datum.channels(), datum_ref.channels()); + EXPECT_EQ(datum.height(), datum_ref.height()); + EXPECT_EQ(datum.width(), datum_ref.width()); + EXPECT_EQ(datum.data().size(), datum_ref.data().size()); + + const string& data = datum.data(); + const string& data_ref = 
datum_ref.data(); + for (int i = 0; i < datum.data().size(); ++i) { + EXPECT_TRUE(data[i] == data_ref[i]); + } +} + +TEST_F(IOTest, TestReadFileToDatum) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum; + EXPECT_TRUE(ReadFileToDatum(filename, &datum)); + EXPECT_TRUE(datum.encoded()); + EXPECT_EQ(datum.label(), -1); + EXPECT_EQ(datum.data().size(), 140391); +} + +TEST_F(IOTest, TestDecodeDatum) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum; + EXPECT_TRUE(ReadFileToDatum(filename, &datum)); + EXPECT_TRUE(DecodeDatum(&datum)); + EXPECT_FALSE(DecodeDatum(&datum)); + Datum datum_ref; + ReadImageToDatumReference(filename, 0, 0, 0, true, &datum_ref); + EXPECT_EQ(datum.channels(), datum_ref.channels()); + EXPECT_EQ(datum.height(), datum_ref.height()); + EXPECT_EQ(datum.width(), datum_ref.width()); + EXPECT_EQ(datum.data().size(), datum_ref.data().size()); + + const string& data = datum.data(); + const string& data_ref = datum_ref.data(); + for (int i = 0; i < datum.data().size(); ++i) { + EXPECT_TRUE(data[i] == data_ref[i]); + } +} + +TEST_F(IOTest, TestDecodeDatumToCVMat) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum; + EXPECT_TRUE(ReadFileToDatum(filename, &datum)); + cv::Mat cv_img = DecodeDatumToCVMat(datum); + EXPECT_EQ(cv_img.channels(), 3); + EXPECT_EQ(cv_img.rows, 360); + EXPECT_EQ(cv_img.cols, 480); +} + +TEST_F(IOTest, TestDecodeDatumToCVMatResized) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum; + EXPECT_TRUE(ReadFileToDatum(filename, &datum)); + cv::Mat cv_img = DecodeDatumToCVMat(datum, 100, 200); + EXPECT_EQ(cv_img.channels(), 3); + EXPECT_EQ(cv_img.rows, 100); + EXPECT_EQ(cv_img.cols, 200); +} + +TEST_F(IOTest, TestDecodeDatumToCVMatResizedGray) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum; + EXPECT_TRUE(ReadFileToDatum(filename, &datum)); + const bool is_color = false; + cv::Mat cv_img = DecodeDatumToCVMat(datum, 200, 100, 
is_color); + EXPECT_EQ(cv_img.channels(), 1); + EXPECT_EQ(cv_img.rows, 200); + EXPECT_EQ(cv_img.cols, 100); +} + +TEST_F(IOTest, TestDecodeDatumToCVMatContent) { + string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg"; + Datum datum; + EXPECT_TRUE(ReadFileToDatum(filename, &datum)); + cv::Mat cv_img = DecodeDatumToCVMat(datum); + cv::Mat cv_img_ref = ReadImageToCVMat(filename); + EXPECT_EQ(cv_img_ref.channels(), cv_img.channels()); + EXPECT_EQ(cv_img_ref.rows, cv_img.rows); + EXPECT_EQ(cv_img_ref.cols, cv_img.cols); + + for (int c = 0; c < datum.channels(); ++c) { + for (int h = 0; h < datum.height(); ++h) { + for (int w = 0; w < datum.width(); ++w) { + EXPECT_TRUE(cv_img.at<cv::Vec3b>(h, w)[c]== + cv_img_ref.at<cv::Vec3b>(h, w)[c]); + } + } + } +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_lrn_layer.cpp b/caffe-crfrnn/src/caffe/test/test_lrn_layer.cpp new file mode 100644 index 00000000..07425df9 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_lrn_layer.cpp @@ -0,0 +1,212 @@ +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +using std::min; +using std::max; + +namespace caffe { + +template <typename TypeParam> +class LRNLayerTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + LRNLayerTest() + : epsilon_(Dtype(1e-5)), + blob_bottom_(new Blob<Dtype>()), + blob_top_(new Blob<Dtype>()) {} + virtual void SetUp() { + Caffe::set_random_seed(1701); + blob_bottom_->Reshape(2, 7, 3, 3); + // fill the values + FillerParameter filler_param; + GaussianFiller<Dtype> filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~LRNLayerTest() { delete blob_bottom_; delete blob_top_; } + void ReferenceLRNForward(const Blob<Dtype>& blob_bottom, + const 
LayerParameter& layer_param, Blob<Dtype>* blob_top); + + Dtype epsilon_; + Blob<Dtype>* const blob_bottom_; + Blob<Dtype>* const blob_top_; + vector<Blob<Dtype>*> blob_bottom_vec_; + vector<Blob<Dtype>*> blob_top_vec_; +}; + +template <typename TypeParam> +void LRNLayerTest<TypeParam>::ReferenceLRNForward( + const Blob<Dtype>& blob_bottom, const LayerParameter& layer_param, + Blob<Dtype>* blob_top) { + typedef typename TypeParam::Dtype Dtype; + blob_top->Reshape(blob_bottom.num(), blob_bottom.channels(), + blob_bottom.height(), blob_bottom.width()); + Dtype* top_data = blob_top->mutable_cpu_data(); + LRNParameter lrn_param = layer_param.lrn_param(); + Dtype alpha = lrn_param.alpha(); + Dtype beta = lrn_param.beta(); + int size = lrn_param.local_size(); + switch (lrn_param.norm_region()) { + case LRNParameter_NormRegion_ACROSS_CHANNELS: + for (int n = 0; n < blob_bottom.num(); ++n) { + for (int c = 0; c < blob_bottom.channels(); ++c) { + for (int h = 0; h < blob_bottom.height(); ++h) { + for (int w = 0; w < blob_bottom.width(); ++w) { + int c_start = c - (size - 1) / 2; + int c_end = min(c_start + size, blob_bottom.channels()); + c_start = max(c_start, 0); + Dtype scale = 1.; + for (int i = c_start; i < c_end; ++i) { + Dtype value = blob_bottom.data_at(n, i, h, w); + scale += value * value * alpha / size; + } + *(top_data + blob_top->offset(n, c, h, w)) = + blob_bottom.data_at(n, c, h, w) / pow(scale, beta); + } + } + } + } + break; + case LRNParameter_NormRegion_WITHIN_CHANNEL: + for (int n = 0; n < blob_bottom.num(); ++n) { + for (int c = 0; c < blob_bottom.channels(); ++c) { + for (int h = 0; h < blob_bottom.height(); ++h) { + int h_start = h - (size - 1) / 2; + int h_end = min(h_start + size, blob_bottom.height()); + h_start = max(h_start, 0); + for (int w = 0; w < blob_bottom.width(); ++w) { + Dtype scale = 1.; + int w_start = w - (size - 1) / 2; + int w_end = min(w_start + size, blob_bottom.width()); + w_start = max(w_start, 0); + for (int nh = h_start; nh < h_end; ++nh) { + for (int nw = w_start; nw < w_end; ++nw) { + Dtype value = 
blob_bottom.data_at(n, c, nh, nw); + scale += value * value * alpha / (size * size); + } + } + *(top_data + blob_top->offset(n, c, h, w)) = + blob_bottom.data_at(n, c, h, w) / pow(scale, beta); + } + } + } + } + break; + default: + LOG(FATAL) << "Unknown normalization region."; + } +} + +TYPED_TEST_CASE(LRNLayerTest, TestDtypesAndDevices); + +TYPED_TEST(LRNLayerTest, TestSetupAcrossChannels) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + LRNLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 2); + EXPECT_EQ(this->blob_top_->channels(), 7); + EXPECT_EQ(this->blob_top_->height(), 3); + EXPECT_EQ(this->blob_top_->width(), 3); +} + +TYPED_TEST(LRNLayerTest, TestForwardAcrossChannels) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + LRNLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + Blob<Dtype> top_reference; + this->ReferenceLRNForward(*(this->blob_bottom_), layer_param, + &top_reference); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i], + this->epsilon_); + } +} + +TYPED_TEST(LRNLayerTest, TestGradientAcrossChannels) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + LRNLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-2); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < this->blob_top_->count(); ++i) { + this->blob_top_->mutable_cpu_diff()[i] = 1.; + } + vector<bool> propagate_down(this->blob_bottom_vec_.size(), true); + layer.Backward(this->blob_top_vec_, propagate_down, + this->blob_bottom_vec_); + // for (int i = 0; i < this->blob_bottom_->count(); ++i) { + // std::cout << "CPU diff " << this->blob_bottom_->cpu_diff()[i] + // << std::endl; 
+ // } + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(LRNLayerTest, TestSetupWithinChannel) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.mutable_lrn_param()->set_norm_region( + LRNParameter_NormRegion_WITHIN_CHANNEL); + layer_param.mutable_lrn_param()->set_local_size(3); + LRNLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 2); + EXPECT_EQ(this->blob_top_->channels(), 7); + EXPECT_EQ(this->blob_top_->height(), 3); + EXPECT_EQ(this->blob_top_->width(), 3); +} + +TYPED_TEST(LRNLayerTest, TestForwardWithinChannel) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.mutable_lrn_param()->set_norm_region( + LRNParameter_NormRegion_WITHIN_CHANNEL); + layer_param.mutable_lrn_param()->set_local_size(3); + LRNLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + Blob<Dtype> top_reference; + this->ReferenceLRNForward(*(this->blob_bottom_), layer_param, + &top_reference); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i], + this->epsilon_); + } +} + +TYPED_TEST(LRNLayerTest, TestGradientWithinChannel) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.mutable_lrn_param()->set_norm_region( + LRNParameter_NormRegion_WITHIN_CHANNEL); + layer_param.mutable_lrn_param()->set_local_size(3); + LRNLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-2); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < this->blob_top_->count(); ++i) { + this->blob_top_->mutable_cpu_diff()[i] = 1.; + } + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + 
this->blob_top_vec_); +} + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_math_functions.cpp b/caffe-crfrnn/src/caffe/test/test_math_functions.cpp new file mode 100644 index 00000000..667f744b --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_math_functions.cpp @@ -0,0 +1,235 @@ +#include // for uint32_t & uint64_t +#include +#include +#include // for std::fabs +#include // for rand_r + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/util/math_functions.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template <typename Dtype> +class MathFunctionsTest : public ::testing::Test { + protected: + MathFunctionsTest() + : blob_bottom_(new Blob<Dtype>()), + blob_top_(new Blob<Dtype>()) { + } + + virtual void SetUp() { + Caffe::set_random_seed(1701); + this->blob_bottom_->Reshape(11, 17, 19, 23); + this->blob_top_->Reshape(11, 17, 19, 23); + // fill the values + FillerParameter filler_param; + GaussianFiller<Dtype> filler(filler_param); + filler.Fill(this->blob_bottom_); + filler.Fill(this->blob_top_); + } + + virtual ~MathFunctionsTest() { + delete blob_bottom_; + delete blob_top_; + } + + // http://en.wikipedia.org/wiki/Hamming_distance + int ReferenceHammingDistance(const int n, const Dtype* x, const Dtype* y) { + int dist = 0; + uint64_t val; + for (int i = 0; i < n; ++i) { + if (sizeof(Dtype) == 8) { + val = static_cast<uint64_t>(x[i]) ^ static_cast<uint64_t>(y[i]); + } else if (sizeof(Dtype) == 4) { + val = static_cast<uint32_t>(x[i]) ^ static_cast<uint32_t>(y[i]); + } else { + LOG(FATAL) << "Unrecognized Dtype size: " << sizeof(Dtype); + } + // Count the number of set bits + while (val) { + ++dist; + val &= val - 1; + } + } + return dist; + } + + Blob<Dtype>* const blob_bottom_; + Blob<Dtype>* const blob_top_; +}; + +TYPED_TEST_CASE(MathFunctionsTest, TestDtypes); + +TYPED_TEST(MathFunctionsTest, TestNothing) { + // The first test case of a test suite takes the longest time + // due to the set up overhead. 
+} + +TYPED_TEST(MathFunctionsTest, TestHammingDistanceCPU) { + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + const TypeParam* y = this->blob_top_->cpu_data(); + EXPECT_EQ(this->ReferenceHammingDistance(n, x, y), + caffe_cpu_hamming_distance(n, x, y)); +} + +TYPED_TEST(MathFunctionsTest, TestAsumCPU) { + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + TypeParam std_asum = 0; + for (int i = 0; i < n; ++i) { + std_asum += std::fabs(x[i]); + } + TypeParam cpu_asum = caffe_cpu_asum(n, x); + EXPECT_LT((cpu_asum - std_asum) / std_asum, 1e-2); +} + +TYPED_TEST(MathFunctionsTest, TestSignCPU) { + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + caffe_cpu_sign(n, x, this->blob_bottom_->mutable_cpu_diff()); + const TypeParam* signs = this->blob_bottom_->cpu_diff(); + for (int i = 0; i < n; ++i) { + EXPECT_EQ(signs[i], x[i] > 0 ? 1 : (x[i] < 0 ? -1 : 0)); + } +} + +TYPED_TEST(MathFunctionsTest, TestSgnbitCPU) { + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + caffe_cpu_sgnbit(n, x, this->blob_bottom_->mutable_cpu_diff()); + const TypeParam* signbits = this->blob_bottom_->cpu_diff(); + for (int i = 0; i < n; ++i) { + EXPECT_EQ(signbits[i], x[i] < 0 ? 1 : 0); + } +} + +TYPED_TEST(MathFunctionsTest, TestFabsCPU) { + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + caffe_abs(n, x, this->blob_bottom_->mutable_cpu_diff()); + const TypeParam* abs_val = this->blob_bottom_->cpu_diff(); + for (int i = 0; i < n; ++i) { + EXPECT_EQ(abs_val[i], x[i] > 0 ? 
x[i] : -x[i]); + } +} + +TYPED_TEST(MathFunctionsTest, TestScaleCPU) { + int n = this->blob_bottom_->count(); + TypeParam alpha = this->blob_bottom_->cpu_diff()[caffe_rng_rand() % + this->blob_bottom_->count()]; + caffe_cpu_scale(n, alpha, this->blob_bottom_->cpu_data(), + this->blob_bottom_->mutable_cpu_diff()); + const TypeParam* scaled = this->blob_bottom_->cpu_diff(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + for (int i = 0; i < n; ++i) { + EXPECT_EQ(scaled[i], x[i] * alpha); + } +} + +TYPED_TEST(MathFunctionsTest, TestCopyCPU) { + const int n = this->blob_bottom_->count(); + const TypeParam* bottom_data = this->blob_bottom_->cpu_data(); + TypeParam* top_data = this->blob_top_->mutable_cpu_data(); + Caffe::set_mode(Caffe::CPU); + caffe_copy(n, bottom_data, top_data); + for (int i = 0; i < n; ++i) { + EXPECT_EQ(bottom_data[i], top_data[i]); + } +} + +#ifndef CPU_ONLY + +// TODO: Fix caffe_gpu_hamming_distance and re-enable this test. +TYPED_TEST(MathFunctionsTest, DISABLED_TestHammingDistanceGPU) { + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + const TypeParam* y = this->blob_top_->cpu_data(); + int reference_distance = this->ReferenceHammingDistance(n, x, y); + x = this->blob_bottom_->gpu_data(); + y = this->blob_top_->gpu_data(); + int computed_distance = caffe_gpu_hamming_distance(n, x, y); + EXPECT_EQ(reference_distance, computed_distance); +} + +TYPED_TEST(MathFunctionsTest, TestAsumGPU) { + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + TypeParam std_asum = 0; + for (int i = 0; i < n; ++i) { + std_asum += std::fabs(x[i]); + } + TypeParam gpu_asum; + caffe_gpu_asum(n, this->blob_bottom_->gpu_data(), &gpu_asum); + EXPECT_LT((gpu_asum - std_asum) / std_asum, 1e-2); +} + +TYPED_TEST(MathFunctionsTest, TestSignGPU) { + int n = this->blob_bottom_->count(); + caffe_gpu_sign(n, this->blob_bottom_->gpu_data(), + this->blob_bottom_->mutable_gpu_diff()); + 
const TypeParam* signs = this->blob_bottom_->cpu_diff(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + for (int i = 0; i < n; ++i) { + EXPECT_EQ(signs[i], x[i] > 0 ? 1 : (x[i] < 0 ? -1 : 0)); + } +} + +TYPED_TEST(MathFunctionsTest, TestSgnbitGPU) { + int n = this->blob_bottom_->count(); + caffe_gpu_sgnbit(n, this->blob_bottom_->gpu_data(), + this->blob_bottom_->mutable_gpu_diff()); + const TypeParam* signbits = this->blob_bottom_->cpu_diff(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + for (int i = 0; i < n; ++i) { + EXPECT_EQ(signbits[i], x[i] < 0 ? 1 : 0); + } +} + +TYPED_TEST(MathFunctionsTest, TestFabsGPU) { + int n = this->blob_bottom_->count(); + caffe_gpu_abs(n, this->blob_bottom_->gpu_data(), + this->blob_bottom_->mutable_gpu_diff()); + const TypeParam* abs_val = this->blob_bottom_->cpu_diff(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + for (int i = 0; i < n; ++i) { + EXPECT_EQ(abs_val[i], x[i] > 0 ? x[i] : -x[i]); + } +} + +TYPED_TEST(MathFunctionsTest, TestScaleGPU) { + int n = this->blob_bottom_->count(); + TypeParam alpha = this->blob_bottom_->cpu_diff()[caffe_rng_rand() % + this->blob_bottom_->count()]; + caffe_gpu_scale(n, alpha, this->blob_bottom_->gpu_data(), + this->blob_bottom_->mutable_gpu_diff()); + const TypeParam* scaled = this->blob_bottom_->cpu_diff(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + for (int i = 0; i < n; ++i) { + EXPECT_EQ(scaled[i], x[i] * alpha); + } +} + +TYPED_TEST(MathFunctionsTest, TestCopyGPU) { + const int n = this->blob_bottom_->count(); + const TypeParam* bottom_data = this->blob_bottom_->gpu_data(); + TypeParam* top_data = this->blob_top_->mutable_gpu_data(); + Caffe::set_mode(Caffe::GPU); + caffe_copy(n, bottom_data, top_data); + bottom_data = this->blob_bottom_->cpu_data(); + top_data = this->blob_top_->mutable_cpu_data(); + for (int i = 0; i < n; ++i) { + EXPECT_EQ(bottom_data[i], top_data[i]); + } +} + +#endif + + +} // namespace caffe diff --git 
a/caffe-crfrnn/src/caffe/test/test_maxpool_dropout_layers.cpp b/caffe-crfrnn/src/caffe/test/test_maxpool_dropout_layers.cpp new file mode 100644 index 00000000..b1f4e4ea --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_maxpool_dropout_layers.cpp @@ -0,0 +1,127 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class MaxPoolingDropoutTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + protected: + MaxPoolingDropoutTest() + : blob_bottom_(new Blob()), + blob_top_(new Blob()) {} + virtual void SetUp() { + Caffe::set_random_seed(1703); + blob_bottom_->Reshape(2, 3, 6, 5); + // fill the values + FillerParameter filler_param; + filler_param.set_value(1.); + ConstantFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~MaxPoolingDropoutTest() { delete blob_bottom_; delete blob_top_; } + Blob* const blob_bottom_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(MaxPoolingDropoutTest, TestDtypesAndDevices); + +TYPED_TEST(MaxPoolingDropoutTest, TestSetup) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(2); + PoolingLayer max_layer(layer_param); + max_layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + DropoutLayer dropout_layer(layer_param); + dropout_layer.SetUp(this->blob_top_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num()); + EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels()); + 
EXPECT_EQ(this->blob_top_->height(), 3); + EXPECT_EQ(this->blob_top_->width(), 2); +} + + +TYPED_TEST(MaxPoolingDropoutTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(2); + PoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* top_data = this->blob_top_->cpu_data(); + Dtype sum = 0.; + for (int i = 0; i < this->blob_top_->count(); ++i) { + sum += top_data[i]; + } + EXPECT_EQ(sum, this->blob_top_->count()); + // Dropout in-place + DropoutLayer dropout_layer(layer_param); + dropout_layer.SetUp(this->blob_top_vec_, this->blob_top_vec_); + dropout_layer.Forward(this->blob_top_vec_, this->blob_top_vec_); + sum = 0.; + Dtype scale = 1. / (1. - layer_param.dropout_param().dropout_ratio()); + top_data = this->blob_top_->cpu_data(); + for (int i = 0; i < this->blob_top_->count(); ++i) { + sum += top_data[i]; + } + EXPECT_GE(sum, 0); + EXPECT_LE(sum, this->blob_top_->count()*scale); +} + +TYPED_TEST(MaxPoolingDropoutTest, TestBackward) { + typedef typename TypeParam::Dtype Dtype; + Caffe::set_phase(Caffe::TRAIN); + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(2); + PoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < this->blob_top_->count(); ++i) { + this->blob_top_->mutable_cpu_diff()[i] = 1.; + } + vector propagate_down(this->blob_bottom_vec_.size(), true); + layer.Backward(this->blob_top_vec_, propagate_down, + this->blob_bottom_vec_); + const Dtype* bottom_diff = this->blob_bottom_->cpu_diff(); + Dtype sum = 0.; + for (int i = 0; i < 
this->blob_bottom_->count(); ++i) { + sum += bottom_diff[i]; + } + EXPECT_EQ(sum, this->blob_top_->count()); + // Dropout in-place + DropoutLayer dropout_layer(layer_param); + dropout_layer.SetUp(this->blob_top_vec_, this->blob_top_vec_); + dropout_layer.Forward(this->blob_top_vec_, this->blob_top_vec_); + dropout_layer.Backward(this->blob_top_vec_, propagate_down, + this->blob_top_vec_); + layer.Backward(this->blob_top_vec_, propagate_down, + this->blob_bottom_vec_); + Dtype sum_with_dropout = 0.; + bottom_diff = this->blob_bottom_->cpu_diff(); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + sum_with_dropout += bottom_diff[i]; + } + EXPECT_GE(sum_with_dropout, sum); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_memory_data_layer.cpp b/caffe-crfrnn/src/caffe/test/test_memory_data_layer.cpp new file mode 100644 index 00000000..497ab0d1 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_memory_data_layer.cpp @@ -0,0 +1,167 @@ +#include +#include + +#include "caffe/data_layers.hpp" +#include "caffe/filler.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class MemoryDataLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + MemoryDataLayerTest() + : data_(new Blob()), + labels_(new Blob()), + data_blob_(new Blob()), + label_blob_(new Blob()) {} + virtual void SetUp() { + batch_size_ = 8; + batches_ = 12; + channels_ = 4; + height_ = 7; + width_ = 11; + blob_top_vec_.push_back(data_blob_); + blob_top_vec_.push_back(label_blob_); + // pick random input data + FillerParameter filler_param; + GaussianFiller filler(filler_param); + data_->Reshape(batches_ * batch_size_, channels_, height_, width_); + labels_->Reshape(batches_ * batch_size_, 1, 1, 1); + filler.Fill(this->data_); + filler.Fill(this->labels_); + } + + virtual ~MemoryDataLayerTest() { + delete data_blob_; + delete label_blob_; + delete data_; + delete labels_; + } + int batch_size_; + int 
batches_; + int channels_; + int height_; + int width_; + // we don't really need blobs for the input data, but it makes it + // easier to call Filler + Blob* const data_; + Blob* const labels_; + // blobs for the top of MemoryDataLayer + Blob* const data_blob_; + Blob* const label_blob_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(MemoryDataLayerTest, TestDtypesAndDevices); + +TYPED_TEST(MemoryDataLayerTest, TestSetup) { + typedef typename TypeParam::Dtype Dtype; + + LayerParameter layer_param; + MemoryDataParameter* md_param = layer_param.mutable_memory_data_param(); + md_param->set_batch_size(this->batch_size_); + md_param->set_channels(this->channels_); + md_param->set_height(this->height_); + md_param->set_width(this->width_); + shared_ptr > layer( + new MemoryDataLayer(layer_param)); + layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->data_blob_->num(), this->batch_size_); + EXPECT_EQ(this->data_blob_->channels(), this->channels_); + EXPECT_EQ(this->data_blob_->height(), this->height_); + EXPECT_EQ(this->data_blob_->width(), this->width_); + EXPECT_EQ(this->label_blob_->num(), this->batch_size_); + EXPECT_EQ(this->label_blob_->channels(), 1); + EXPECT_EQ(this->label_blob_->height(), 1); + EXPECT_EQ(this->label_blob_->width(), 1); +} + +// run through a few batches and check that the right data appears +TYPED_TEST(MemoryDataLayerTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + + LayerParameter layer_param; + MemoryDataParameter* md_param = layer_param.mutable_memory_data_param(); + md_param->set_batch_size(this->batch_size_); + md_param->set_channels(this->channels_); + md_param->set_height(this->height_); + md_param->set_width(this->width_); + shared_ptr > layer( + new MemoryDataLayer(layer_param)); + layer->DataLayerSetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer->Reset(this->data_->mutable_cpu_data(), + this->labels_->mutable_cpu_data(), this->data_->num()); + for (int 
i = 0; i < this->batches_ * 6; ++i) { + int batch_num = i % this->batches_; + layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int j = 0; j < this->data_blob_->count(); ++j) { + EXPECT_EQ(this->data_blob_->cpu_data()[j], + this->data_->cpu_data()[ + this->data_->offset(1) * this->batch_size_ * batch_num + j]); + } + for (int j = 0; j < this->label_blob_->count(); ++j) { + EXPECT_EQ(this->label_blob_->cpu_data()[j], + this->labels_->cpu_data()[this->batch_size_ * batch_num + j]); + } + } +} + +TYPED_TEST(MemoryDataLayerTest, AddDatumVectorDefaultTransform) { + typedef typename TypeParam::Dtype Dtype; + + LayerParameter param; + MemoryDataParameter* memory_data_param = param.mutable_memory_data_param(); + memory_data_param->set_batch_size(this->batch_size_); + memory_data_param->set_channels(this->channels_); + memory_data_param->set_height(this->height_); + memory_data_param->set_width(this->width_); + MemoryDataLayer layer(param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + + vector datum_vector(this->batch_size_); + const size_t count = this->channels_ * this->height_ * this->width_; + size_t pixel_index = 0; + for (int i = 0; i < this->batch_size_; ++i) { + LOG(ERROR) << "i " << i; + datum_vector[i].set_channels(this->channels_); + datum_vector[i].set_height(this->height_); + datum_vector[i].set_width(this->width_); + datum_vector[i].set_label(i); + vector pixels(count); + for (int j = 0; j < count; ++j) { + pixels[j] = pixel_index++ % 256; + } + datum_vector[i].set_data(&(pixels[0]), count); + } + + layer.AddDatumVector(datum_vector); + + int data_index; + // Go through the data 5 times + for (int iter = 0; iter < 5; ++iter) { + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* data = this->data_blob_->cpu_data(); + size_t index = 0; + for (int i = 0; i < this->batch_size_; ++i) { + const string& data_string = datum_vector[i].data(); + EXPECT_EQ(i, this->label_blob_->cpu_data()[i]); + for (int c = 0; 
c < this->channels_; ++c) { + for (int h = 0; h < this->height_; ++h) { + for (int w = 0; w < this->width_; ++w) { + data_index = (c * this->height_ + h) * this->width_ + w; + EXPECT_EQ(static_cast( + static_cast(data_string[data_index])), + data[index++]); + } + } + } + } + } +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_multi_stage_meanfield_layer.cpp b/caffe-crfrnn/src/caffe/test/test_multi_stage_meanfield_layer.cpp new file mode 100755 index 00000000..372effe6 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_multi_stage_meanfield_layer.cpp @@ -0,0 +1,79 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/util/tvg_util.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class MultiStageMeanfieldLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + MultiStageMeanfieldLayerTest() {} + + virtual void SetUp() { + + } + + virtual ~MultiStageMeanfieldLayerTest() { + + } +}; + +TYPED_TEST_CASE(MultiStageMeanfieldLayerTest, TestDtypesAndDevices); + + TYPED_TEST(MultiStageMeanfieldLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + const int n = 5, c = 3, H = 5, W = 5; + + if (sizeof(Dtype) != sizeof(float)) + return; + + Blob unary_terms_blob(n, c, H, W); + Blob previous_output_blob(n, c, H, W); + Blob rgb_blob(n, 3, H, W); + + caffe::FillAsLogProb(&unary_terms_blob); + caffe::FillAsLogProb(&previous_output_blob); + caffe::FillAsRGB(&rgb_blob); + + vector*> bottom_vec, top_vec; + bottom_vec.push_back(&unary_terms_blob); + bottom_vec.push_back(&previous_output_blob); + bottom_vec.push_back(&rgb_blob); + + Blob top_blob; + top_vec.push_back(&top_blob); + + LayerParameter layer_param; + MultiStageMeanfieldParameter* ms_mf_param = layer_param.mutable_multi_stage_meanfield_param(); + ms_mf_param->set_num_iterations(2); 
+ ms_mf_param->set_bilateral_filter_weight(1.0); + ms_mf_param->set_spatial_filter_weight(1.0); + ms_mf_param->set_compatibility_mode(MultiStageMeanfieldParameter_Mode_POTTS); + ms_mf_param->set_theta_alpha(5); + ms_mf_param->set_theta_beta(2); + ms_mf_param->set_theta_gamma(3); + + MultiStageMeanfieldLayer<Dtype> layer(layer_param); + layer.SetUp(bottom_vec, top_vec); + layer.Forward(bottom_vec, top_vec); + + GradientChecker<Dtype> checker(1e-2, 1e-3); + + // Check gradients w.r.t. unary terms + checker.CheckGradientExhaustive(&layer, bottom_vec, top_vec, 0); + + // Check gradients w.r.t. previous outputs + checker.CheckGradientExhaustive(&layer, bottom_vec, top_vec, 1); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_multinomial_logistic_loss_layer.cpp b/caffe-crfrnn/src/caffe/test/test_multinomial_logistic_loss_layer.cpp new file mode 100644 index 00000000..9038017e --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_multinomial_logistic_loss_layer.cpp @@ -0,0 +1,62 @@ +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template <typename Dtype> +class MultinomialLogisticLossLayerTest : public ::testing::Test { + protected: + MultinomialLogisticLossLayerTest() + : blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)), + blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)), + blob_top_loss_(new Blob<Dtype>()) { + Caffe::set_random_seed(1701); + // fill the values + FillerParameter filler_param; + PositiveUnitballFiller<Dtype> filler(filler_param); + filler.Fill(this->blob_bottom_data_); + blob_bottom_vec_.push_back(blob_bottom_data_); + for (int i = 0; i < blob_bottom_label_->count(); ++i) { + blob_bottom_label_->mutable_cpu_data()[i] = caffe_rng_rand() % 5; + } + blob_bottom_vec_.push_back(blob_bottom_label_); + 
blob_top_vec_.push_back(blob_top_loss_); + } + virtual ~MultinomialLogisticLossLayerTest() { + delete blob_bottom_data_; + delete blob_bottom_label_; + delete blob_top_loss_; + } + Blob* const blob_bottom_data_; + Blob* const blob_bottom_label_; + Blob* const blob_top_loss_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(MultinomialLogisticLossLayerTest, TestDtypes); + + +TYPED_TEST(MultinomialLogisticLossLayerTest, TestGradientCPU) { + LayerParameter layer_param; + Caffe::set_mode(Caffe::CPU); + MultinomialLogisticLossLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + GradientChecker checker(1e-2, 2*1e-2, 1701, 0, 0.05); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_mvn_layer.cpp b/caffe-crfrnn/src/caffe/test/test_mvn_layer.cpp new file mode 100644 index 00000000..933b4326 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_mvn_layer.cpp @@ -0,0 +1,169 @@ +#include +#include +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/common_layers.hpp" +#include "caffe/filler.hpp" +#include "gtest/gtest.h" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class MVNLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + protected: + MVNLayerTest() + : blob_bottom_(new Blob(2, 3, 4, 5)), + blob_top_(new Blob()) { + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~MVNLayerTest() { delete blob_bottom_; delete blob_top_; } + Blob* const blob_bottom_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(MVNLayerTest, 
TestDtypesAndDevices); + +TYPED_TEST(MVNLayerTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + MVNLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Test mean + int num = this->blob_bottom_->num(); + int channels = this->blob_bottom_->channels(); + int height = this->blob_bottom_->height(); + int width = this->blob_bottom_->width(); + + for (int i = 0; i < num; ++i) { + for (int j = 0; j < channels; ++j) { + Dtype sum = 0, var = 0; + for (int k = 0; k < height; ++k) { + for (int l = 0; l < width; ++l) { + Dtype data = this->blob_top_->data_at(i, j, k, l); + sum += data; + var += data * data; + } + } + sum /= height * width; + var /= height * width; + + const Dtype kErrorBound = 0.001; + // expect zero mean + EXPECT_NEAR(0, sum, kErrorBound); + // expect unit variance + EXPECT_NEAR(1, var, kErrorBound); + } + } +} + +TYPED_TEST(MVNLayerTest, TestForwardMeanOnly) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.ParseFromString("mvn_param{normalize_variance: false}"); + MVNLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Test mean + int num = this->blob_bottom_->num(); + int channels = this->blob_bottom_->channels(); + int height = this->blob_bottom_->height(); + int width = this->blob_bottom_->width(); + + for (int i = 0; i < num; ++i) { + for (int j = 0; j < channels; ++j) { + Dtype sum = 0, var = 0; + for (int k = 0; k < height; ++k) { + for (int l = 0; l < width; ++l) { + Dtype data = this->blob_top_->data_at(i, j, k, l); + sum += data; + var += data * data; + } + } + sum /= height * width; + + const Dtype kErrorBound = 0.001; + // expect zero mean + EXPECT_NEAR(0, sum, kErrorBound); + } + } +} + +TYPED_TEST(MVNLayerTest, TestForwardAcrossChannels) { + typedef typename 
TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.ParseFromString("mvn_param{across_channels: true}"); + MVNLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Test mean + int num = this->blob_bottom_->num(); + int channels = this->blob_bottom_->channels(); + int height = this->blob_bottom_->height(); + int width = this->blob_bottom_->width(); + + for (int i = 0; i < num; ++i) { + Dtype sum = 0, var = 0; + for (int j = 0; j < channels; ++j) { + for (int k = 0; k < height; ++k) { + for (int l = 0; l < width; ++l) { + Dtype data = this->blob_top_->data_at(i, j, k, l); + sum += data; + var += data * data; + } + } + } + sum /= height * width * channels; + var /= height * width * channels; + + const Dtype kErrorBound = 0.001; + // expect zero mean + EXPECT_NEAR(0, sum, kErrorBound); + // expect unit variance + EXPECT_NEAR(1, var, kErrorBound); + } +} + +TYPED_TEST(MVNLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + MVNLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(MVNLayerTest, TestGradientMeanOnly) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.ParseFromString("mvn_param{normalize_variance: false}"); + MVNLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(MVNLayerTest, TestGradientAcrossChannels) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.ParseFromString("mvn_param{across_channels: true}"); + MVNLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +} // namespace 
caffe diff --git a/caffe-crfrnn/src/caffe/test/test_net.cpp b/caffe-crfrnn/src/caffe/test/test_net.cpp new file mode 100644 index 00000000..319958fe --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_net.cpp @@ -0,0 +1,2204 @@ +#include +#include +#include + +#include "google/protobuf/text_format.h" + +#include "gtest/gtest.h" + +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/net.hpp" +#include "caffe/util/math_functions.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template <typename TypeParam> +class NetTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + NetTest() : seed_(1701) {} + + virtual void InitNetFromProtoString(const string& proto) { + NetParameter param; + CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param)); + net_.reset(new Net<Dtype>(param)); + } + + virtual void CopyNetBlobs(const bool copy_diff, + vector<shared_ptr<Blob<Dtype> > >* blobs_copy) { + CHECK(net_); + const vector<shared_ptr<Blob<Dtype> > >& net_blobs = net_->blobs(); + blobs_copy->clear(); + blobs_copy->resize(net_blobs.size()); + const bool kReshape = true; + for (int i = 0; i < net_blobs.size(); ++i) { + (*blobs_copy)[i].reset(new Blob<Dtype>()); + (*blobs_copy)[i]->CopyFrom(*net_blobs[i], copy_diff, kReshape); + } + } + + virtual void CopyNetParams(const bool copy_diff, + vector<shared_ptr<Blob<Dtype> > >* params_copy) { + CHECK(net_); + const vector<shared_ptr<Blob<Dtype> > >& net_params = net_->params(); + params_copy->clear(); + params_copy->resize(net_params.size()); + const bool kReshape = true; + for (int i = 0; i < net_params.size(); ++i) { + (*params_copy)[i].reset(new Blob<Dtype>()); + (*params_copy)[i]->CopyFrom(*net_params[i], copy_diff, kReshape); + } + } + + virtual void InitTinyNet(const bool force_backward = false, + const bool accuracy_layer = false) { + string proto = + "name: 'TinyTestNetwork' " + "layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: 5 " + " channels: 2 " + " height: 3 " + " width: 4 " + " 
num: 5 " + " channels: 1 " + " height: 1 " + " width: 1 " + " data_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " data_filler { " + " type: 'constant' " + " value: 0 " + " } " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerproduct' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'data' " + " top: 'innerproduct' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerproduct' " + " bottom: 'label' " + " top: 'top_loss' " + "} "; + if (accuracy_layer) { + proto += + "layers: { " + " name: 'loss' " + " type: ACCURACY " + " bottom: 'innerproduct' " + " bottom: 'label' " + " top: 'accuracy' " + "} "; + } + if (force_backward) { + proto += "force_backward: true "; + } + InitNetFromProtoString(proto); + } + + virtual void InitTinyNetEuclidean(const bool force_backward = false) { + string proto = + "name: 'TinyTestEuclidLossNetwork' " + "layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: 5 " + " channels: 2 " + " height: 3 " + " width: 4 " + " num: 5 " + " channels: 1 " + " height: 1 " + " width: 1 " + " data_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerproduct' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " bottom: 'data' " + " top: 'innerproduct' " + "} " + "layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerproduct' " + " bottom: 'label' " + "} "; + if (force_backward) { + proto += "force_backward: true "; + } + InitNetFromProtoString(proto); + } + + virtual void InitTrickyNet(Dtype* loss_weight = NULL) { + ostringstream loss_weight_stream; + if (loss_weight) { + loss_weight_stream << " loss_weight: " << *loss_weight << " "; + } + const string& proto = + "name: 'TrickyTestNetwork' " + "layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: 5 " + " channels: 2 " + " height: 3 " + " width: 4 " + " num: 5 " + " channels: 1 " + " height: 1 " + " width: 1 " + " data_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerproduct' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'data' " + " top: 'transformed_data' " + "} " + "layers: { " + " name: 'innerproduct' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'label' " + " top: 'transformed_label' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + + loss_weight_stream.str() + + " bottom: 'transformed_data' " + " bottom: 'transformed_label' " + "} "; + InitNetFromProtoString(proto); + } + + // loss_weight is the loss weight for the EUCLIDEAN_LOSS layer output. 
+ // midnet_loss_weight is the loss weight for the first INNER_PRODUCT layer + // output. Should both default to 0.0 if unspecified (i.e., if NULL is + // passed to this function). + virtual void InitUnsharedWeightsNet(const Dtype* loss_weight = NULL, + const Dtype* midnet_loss_weight = NULL, + const bool force_backward = false, const bool bias_term = false, + const Dtype blobs_lr_w1 = 1, const Dtype blobs_lr_b1 = 2, + const Dtype blobs_lr_w2 = 1, const Dtype blobs_lr_b2 = 2) { + ostringstream proto; + proto << "name: 'UnsharedWeightsNetwork' "; + if (force_backward) { + proto << "force_backward: true "; + } + proto << + "layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: 5 " + " channels: 2 " + " height: 3 " + " width: 4 " + " data_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " } " + " top: 'data' " + "} " + "layers: { " + " name: 'innerproduct1' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 10 " + " bias_term: " << bias_term << + " weight_filler { " + " type: 'gaussian' " + " std: 10 " + " } " + " } " + " param: 'unsharedweights1' "; + if (bias_term) { + proto << " param: '' "; + } + proto << + " blobs_lr: " << blobs_lr_w1; + if (bias_term) { + proto << " blobs_lr: " << blobs_lr_b1; + } + proto << + " bottom: 'data' " + " top: 'innerproduct1' "; + if (midnet_loss_weight) { + proto << " loss_weight: " << *midnet_loss_weight << " "; + } + proto << + "} " + "layers: { " + " name: 'innerproduct2' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 10 " + " bias_term: " << bias_term << + " weight_filler { " + " type: 'gaussian' " + " std: 10 " + " } " + " } " + " param: 'unsharedweights2' "; + if (bias_term) { + proto << " param: '' "; + } + proto << + " bottom: 'data' " + " blobs_lr: " << blobs_lr_w2; + if (bias_term) { + proto << " blobs_lr: " << blobs_lr_b2; + } + proto << + " top: 'innerproduct2' " + "} " + "layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS "; + 
if (loss_weight) { + proto << " loss_weight: " << *loss_weight << " "; + } + proto << + " bottom: 'innerproduct1' " + " bottom: 'innerproduct2' " + "} "; + InitNetFromProtoString(proto.str()); + } + + virtual void InitSharedWeightsNet() { + const string& proto = + "name: 'SharedWeightsNetwork' " + "layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: 5 " + " channels: 2 " + " height: 3 " + " width: 4 " + " data_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " } " + " top: 'data' " + "} " + "layers: { " + " name: 'innerproduct1' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 10 " + " bias_term: false " + " weight_filler { " + " type: 'gaussian' " + " std: 10 " + " } " + " } " + " param: 'sharedweights' " + " bottom: 'data' " + " top: 'innerproduct1' " + "} " + "layers: { " + " name: 'innerproduct2' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 10 " + " bias_term: false " + " weight_filler { " + " type: 'gaussian' " + " std: 10 " + " } " + " } " + " param: 'sharedweights' " + " bottom: 'data' " + " top: 'innerproduct2' " + "} " + "layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerproduct1' " + " bottom: 'innerproduct2' " + "} "; + InitNetFromProtoString(proto); + } + + virtual void InitDiffDataUnsharedWeightsNet() { + const string& proto = + "name: 'DiffDataUnsharedWeightsNetwork' " + "layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: 10 " + " channels: 10 " + " height: 1 " + " width: 1 " + " num: 10 " + " channels: 10 " + " height: 1 " + " width: 1 " + " data_filler { " + " type: 'gaussian' " + " std: 10 " + " } " + " } " + " top: 'data1' " + " top: 'data2' " + "} " + "layers: { " + " name: 'innerproduct1' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 10 " + " bias_term: false " + " weight_filler { " + " type: 'constant' " + " value: 0.5 " + " } " + " } " + " param: 
'unsharedweights1' " + " bottom: 'data1' " + " top: 'innerproduct1' " + "} " + "layers: { " + " name: 'innerproduct2' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 10 " + " bias_term: false " + " weight_filler { " + " type: 'constant' " + " value: 0.5 " + " } " + " } " + " param: 'unsharedweights2' " + " bottom: 'innerproduct1' " + " top: 'innerproduct2' " + "} " + "layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS " + " bottom: 'data2' " + " bottom: 'innerproduct2' " + "} "; + InitNetFromProtoString(proto); + } + + virtual void InitDiffDataSharedWeightsNet() { + const string& proto = + "name: 'DiffDataSharedWeightsNetwork' " + "layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: 10 " + " channels: 10 " + " height: 1 " + " width: 1 " + " num: 10 " + " channels: 10 " + " height: 1 " + " width: 1 " + " data_filler { " + " type: 'gaussian' " + " std: 10 " + " } " + " } " + " top: 'data1' " + " top: 'data2' " + "} " + "layers: { " + " name: 'innerproduct1' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 10 " + " bias_term: false " + " weight_filler { " + " type: 'constant' " + " value: 0.5 " + " } " + " } " + " param: 'sharedweights' " + " bottom: 'data1' " + " top: 'innerproduct1' " + "} " + "layers: { " + " name: 'innerproduct2' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 10 " + " bias_term: false " + " weight_filler { " + " type: 'constant' " + " value: 0.5 " + " } " + " } " + " param: 'sharedweights' " + " bottom: 'innerproduct1' " + " top: 'innerproduct2' " + "} " + "layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS " + " bottom: 'data2' " + " bottom: 'innerproduct2' " + "} "; + InitNetFromProtoString(proto); + } + + virtual void InitReshapableNet() { + const string& proto = + "name: 'ReshapableNetwork' " + "input: 'data' " + "input_dim: 1 " + "input_dim: 3 " + "input_dim: 100 " + "input_dim: 100 " + "layers: { " + " name: 'conv1' " + " 
type: CONVOLUTION " + " bottom: 'data' " + " top: 'conv1' " + " convolution_param { " + " num_output: 5 " + " kernel_size: 3 " + " stride: 2 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0.2 " + " } " + " } " + "} " + "layers: { " + " name: 'relu1' " + " type: RELU " + " bottom: 'conv1' " + " top: 'conv1' " + "} " + "layers: { " + " name: 'pool1' " + " type: POOLING " + " bottom: 'conv1' " + " top: 'pool1' " + " pooling_param { " + " pool: MAX " + " kernel_size: 2 " + " stride: 2 " + " } " + "} " + "layers: { " + " name: 'norm1' " + " type: LRN " + " bottom: 'pool1' " + " top: 'norm1' " + " lrn_param { " + " local_size: 3 " + " } " + "} " + "layers: { " + " name: 'softmax' " + " type: SOFTMAX " + " bottom: 'norm1' " + " top: 'softmax' " + "} "; + InitNetFromProtoString(proto); + } + + int seed_; + shared_ptr<Net<Dtype> > net_; +}; + +TYPED_TEST_CASE(NetTest, TestDtypesAndDevices); + +TYPED_TEST(NetTest, TestHasBlob) { + this->InitTinyNet(); + EXPECT_TRUE(this->net_->has_blob("data")); + EXPECT_TRUE(this->net_->has_blob("label")); + EXPECT_TRUE(this->net_->has_blob("innerproduct")); + EXPECT_FALSE(this->net_->has_blob("loss")); + EXPECT_TRUE(this->net_->has_blob("top_loss")); +} + +TYPED_TEST(NetTest, TestGetBlob) { + this->InitTinyNet(); + EXPECT_EQ(this->net_->blob_by_name("data"), this->net_->blobs()[0]); + EXPECT_EQ(this->net_->blob_by_name("label"), this->net_->blobs()[1]); + EXPECT_EQ(this->net_->blob_by_name("innerproduct"), this->net_->blobs()[2]); + EXPECT_FALSE(this->net_->blob_by_name("loss")); + EXPECT_EQ(this->net_->blob_by_name("top_loss"), this->net_->blobs()[3]); +} + +TYPED_TEST(NetTest, TestHasLayer) { + this->InitTinyNet(); + EXPECT_TRUE(this->net_->has_layer("data")); + EXPECT_TRUE(this->net_->has_layer("innerproduct")); + EXPECT_TRUE(this->net_->has_layer("loss")); + EXPECT_FALSE(this->net_->has_layer("label")); +} + +TYPED_TEST(NetTest, TestGetLayerByName) { + 
this->InitTinyNet(); + EXPECT_EQ(this->net_->layer_by_name("data"), this->net_->layers()[0]); + EXPECT_EQ(this->net_->layer_by_name("innerproduct"), this->net_->layers()[1]); + EXPECT_EQ(this->net_->layer_by_name("loss"), this->net_->layers()[2]); + EXPECT_FALSE(this->net_->layer_by_name("label")); +} + +TYPED_TEST(NetTest, TestBottomNeedBackward) { + this->InitTinyNet(); + const vector<vector<bool> >& bottom_need_backward = + this->net_->bottom_need_backward(); + EXPECT_EQ(3, bottom_need_backward.size()); + EXPECT_EQ(0, bottom_need_backward[0].size()); + EXPECT_EQ(1, bottom_need_backward[1].size()); + EXPECT_EQ(false, bottom_need_backward[1][0]); + EXPECT_EQ(2, bottom_need_backward[2].size()); + EXPECT_EQ(true, bottom_need_backward[2][0]); + EXPECT_EQ(false, bottom_need_backward[2][1]); +} + +TYPED_TEST(NetTest, TestBottomNeedBackwardForce) { + const bool force_backward = true; + this->InitTinyNet(force_backward); + const vector<vector<bool> >& bottom_need_backward = + this->net_->bottom_need_backward(); + EXPECT_EQ(3, bottom_need_backward.size()); + EXPECT_EQ(0, bottom_need_backward[0].size()); + EXPECT_EQ(1, bottom_need_backward[1].size()); + EXPECT_EQ(true, bottom_need_backward[1][0]); + EXPECT_EQ(2, bottom_need_backward[2].size()); + EXPECT_EQ(true, bottom_need_backward[2][0]); + EXPECT_EQ(false, bottom_need_backward[2][1]); +} + +TYPED_TEST(NetTest, TestBottomNeedBackwardEuclideanForce) { + const bool force_backward = true; + this->InitTinyNetEuclidean(force_backward); + const vector<vector<bool> >& bottom_need_backward = + this->net_->bottom_need_backward(); + EXPECT_EQ(3, bottom_need_backward.size()); + EXPECT_EQ(0, bottom_need_backward[0].size()); + EXPECT_EQ(1, bottom_need_backward[1].size()); + EXPECT_EQ(true, bottom_need_backward[1][0]); + EXPECT_EQ(2, bottom_need_backward[2].size()); + EXPECT_EQ(true, bottom_need_backward[2][0]); + EXPECT_EQ(true, bottom_need_backward[2][1]); +} + +TYPED_TEST(NetTest, TestBottomNeedBackwardTricky) { + this->InitTrickyNet(); + const vector<vector<bool> >& 
bottom_need_backward = + this->net_->bottom_need_backward(); + EXPECT_EQ(4, bottom_need_backward.size()); + EXPECT_EQ(0, bottom_need_backward[0].size()); + EXPECT_EQ(1, bottom_need_backward[1].size()); + EXPECT_EQ(false, bottom_need_backward[1][0]); + EXPECT_EQ(1, bottom_need_backward[2].size()); + EXPECT_EQ(false, bottom_need_backward[2][0]); + EXPECT_EQ(2, bottom_need_backward[3].size()); + EXPECT_EQ(true, bottom_need_backward[3][0]); + // The label input to the SoftmaxLossLayer should say it "needs backward" + // since it has weights under it, even though we expect this to cause a crash + // at training/test time. + EXPECT_EQ(true, bottom_need_backward[3][1]); +} + +TYPED_TEST(NetTest, TestLossWeight) { + typedef typename TypeParam::Dtype Dtype; + // First, compute the loss and gradients with no loss_weight specified. + // In this case, the loss weight for the EUCLIDEAN_LOSS layer should default + // to 1. + vector<Blob<Dtype>*> bottom; + Caffe::set_random_seed(this->seed_); + const bool kForceBackward = true; + this->InitUnsharedWeightsNet(NULL, NULL, kForceBackward); + const Dtype loss = this->net_->ForwardBackward(bottom); + const bool kCopyDiff = true; + vector<shared_ptr<Blob<Dtype> > > blob_grads; + this->CopyNetBlobs(kCopyDiff, &blob_grads); + vector<shared_ptr<Blob<Dtype> > > param_grads; + this->CopyNetParams(kCopyDiff, &param_grads); + // Check that the loss is non-trivial, otherwise the test doesn't prove much. 
+ const Dtype kMinLossAbsValue = 1e-2; + ASSERT_GE(fabs(loss), kMinLossAbsValue); + const Dtype kErrorMargin = 1e-4; + const int kNumLossWeights = 6; + Dtype kLossWeights[kNumLossWeights] = {2, 0, 1, -1, -2.5, 3.7}; + for (int i = 0; i < kNumLossWeights; ++i) { + Caffe::set_random_seed(this->seed_); + this->InitUnsharedWeightsNet(&kLossWeights[i], NULL, kForceBackward); + const Dtype weighted_loss = this->net_->ForwardBackward(bottom); + const Dtype error_margin = kErrorMargin * fabs(kLossWeights[i]); + EXPECT_NEAR(loss * kLossWeights[i], weighted_loss, error_margin) + << "loss weight = " << kLossWeights[i]; + const vector<shared_ptr<Blob<Dtype> > >& weighted_blobs = + this->net_->blobs(); + ASSERT_EQ(blob_grads.size(), weighted_blobs.size()); + for (int j = 0; j < blob_grads.size(); ++j) { + ASSERT_EQ(blob_grads[j]->count(), weighted_blobs[j]->count()); + for (int k = 0; k < blob_grads[j]->count(); ++k) { + EXPECT_NEAR(blob_grads[j]->cpu_diff()[k] * kLossWeights[i], + weighted_blobs[j]->cpu_diff()[k], error_margin); + } + } + const vector<shared_ptr<Blob<Dtype> > >& weighted_params = + this->net_->params(); + ASSERT_EQ(param_grads.size(), weighted_params.size()); + for (int j = 0; j < param_grads.size(); ++j) { + ASSERT_EQ(param_grads[j]->count(), weighted_params[j]->count()); + for (int k = 0; k < param_grads[j]->count(); ++k) { + EXPECT_NEAR(param_grads[j]->cpu_diff()[k] * kLossWeights[i], + weighted_params[j]->cpu_diff()[k], error_margin); + } + } + } +} + +TYPED_TEST(NetTest, TestLossWeightMidNet) { + typedef typename TypeParam::Dtype Dtype; + vector<Blob<Dtype>*> bottom; + Caffe::set_random_seed(this->seed_); + const bool kForceBackward = true; + Dtype loss_weight = 0; + Dtype midnet_loss_weight = 1; + this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight, + kForceBackward); + const Dtype loss = this->net_->ForwardBackward(bottom); + const bool kCopyDiff = true; + const bool kReshape = true; + Blob<Dtype> data_grad; + data_grad.CopyFrom(*this->net_->blob_by_name("data"), kCopyDiff, kReshape); + // Check that the 
loss is non-trivial, otherwise the test doesn't prove much. + const Dtype kMinLossAbsValue = 1e-2; + ASSERT_GE(fabs(loss), kMinLossAbsValue); + const Dtype kErrorMargin = 1e-4; + const int kNumLossWeights = 6; + Dtype kLossWeights[kNumLossWeights] = {2, 0, 1, -1, -2.5, 3.7}; + for (int i = 0; i < kNumLossWeights; ++i) { + Caffe::set_random_seed(this->seed_); + this->InitUnsharedWeightsNet(&loss_weight, &kLossWeights[i], + kForceBackward); + const Dtype weighted_loss = this->net_->ForwardBackward(bottom); + const Dtype error_margin = kErrorMargin * fabs(kLossWeights[i]); + EXPECT_NEAR(loss * kLossWeights[i], weighted_loss, error_margin) + << "loss weight = " << kLossWeights[i]; + const shared_ptr<Blob<Dtype> >& weighted_blob = + this->net_->blob_by_name("data"); + ASSERT_EQ(data_grad.count(), weighted_blob->count()); + for (int j = 0; j < data_grad.count(); ++j) { + EXPECT_NEAR(data_grad.cpu_diff()[j] * kLossWeights[i], + weighted_blob->cpu_diff()[j], error_margin); + } + } +} + +TYPED_TEST(NetTest, TestComboLossWeight) { + typedef typename TypeParam::Dtype Dtype; + vector<Blob<Dtype>*> bottom; + Dtype loss_weight; + Dtype midnet_loss_weight; + const bool kForceBackward = true; + const Dtype kErrorMargin = 1e-4; + + // Get the loss and gradients with EUCLIDEAN_LOSS weight 1, + // INNER_PRODUCT weight 1. 
+ loss_weight = 1; + midnet_loss_weight = 1; + Caffe::set_random_seed(this->seed_); + this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight, + kForceBackward); + const Dtype loss = this->net_->ForwardBackward(bottom); + const bool kCopyDiff = true; + vector<shared_ptr<Blob<Dtype> > > blob_grads; + this->CopyNetBlobs(kCopyDiff, &blob_grads); + vector<shared_ptr<Blob<Dtype> > > param_grads; + this->CopyNetParams(kCopyDiff, &param_grads); + + loss_weight = 2; + midnet_loss_weight = 1; + Caffe::set_random_seed(this->seed_); + this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight, + kForceBackward); + const Dtype loss_main_2 = this->net_->ForwardBackward(bottom); + vector<shared_ptr<Blob<Dtype> > > blob_grads_loss_2; + this->CopyNetBlobs(kCopyDiff, &blob_grads_loss_2); + vector<shared_ptr<Blob<Dtype> > > param_grads_loss_2; + this->CopyNetParams(kCopyDiff, &param_grads_loss_2); + + loss_weight = 3; + midnet_loss_weight = 1; + Caffe::set_random_seed(this->seed_); + this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight, + kForceBackward); + const Dtype loss_main_3 = this->net_->ForwardBackward(bottom); + const vector<shared_ptr<Blob<Dtype> > >& blob_grads_loss_3 = + this->net_->blobs(); + ASSERT_EQ(blob_grads.size(), blob_grads_loss_3.size()); + ASSERT_EQ(blob_grads_loss_2.size(), blob_grads_loss_3.size()); + for (int j = 0; j < blob_grads.size(); ++j) { + const string& blob_name = this->net_->blob_names()[j]; + bool grad_should_change = true; + if (blob_name == "innerproduct1_innerproduct1_0_split_0") { + grad_should_change = false; + } + ASSERT_EQ(blob_grads[j]->count(), blob_grads_loss_3[j]->count()); + ASSERT_EQ(blob_grads_loss_2[j]->count(), blob_grads_loss_3[j]->count()); + for (int k = 0; k < blob_grads[j]->count(); ++k) { + const Dtype grad_diff_2 = blob_grads_loss_2[j]->cpu_diff()[k] - + blob_grads[j]->cpu_diff()[k]; + const Dtype grad_diff_3 = blob_grads_loss_3[j]->cpu_diff()[k] - + blob_grads[j]->cpu_diff()[k]; + if (grad_should_change) { + // Test non-triviality. 
+ const Dtype kMinGradDiffAbsValue = 1e-4; + EXPECT_GT(fabs(grad_diff_2), kMinGradDiffAbsValue) << blob_name; + EXPECT_NEAR(2 * grad_diff_2, grad_diff_3, kErrorMargin) << blob_name; + } else { + EXPECT_EQ(0, grad_diff_2) << blob_name; + EXPECT_EQ(0, grad_diff_3) << blob_name; + } + } + } + + loss_weight = 1; + midnet_loss_weight = 2; + Caffe::set_random_seed(this->seed_); + this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight, + kForceBackward); + const Dtype loss_midnet_2 = this->net_->ForwardBackward(bottom); + this->CopyNetBlobs(kCopyDiff, &blob_grads_loss_2); + this->CopyNetParams(kCopyDiff, &param_grads_loss_2); + + loss_weight = 1; + midnet_loss_weight = 3; + Caffe::set_random_seed(this->seed_); + this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight, + kForceBackward); + const Dtype loss_midnet_3 = this->net_->ForwardBackward(bottom); + const vector<shared_ptr<Blob<Dtype> > >& blob_grads_midnet_loss_3 = + this->net_->blobs(); + ASSERT_EQ(blob_grads.size(), blob_grads_midnet_loss_3.size()); + ASSERT_EQ(blob_grads_loss_2.size(), blob_grads_midnet_loss_3.size()); + const vector<string>& blob_names = this->net_->blob_names(); + for (int j = 0; j < blob_grads.size(); ++j) { + const string& blob_name = blob_names[j]; + bool grad_should_change = false; + if (blob_name == "innerproduct1" || + blob_name == "innerproduct1_innerproduct1_0_split_0" || + blob_name == "data_data_0_split_0" || blob_name == "data") { + grad_should_change = true; + } + ASSERT_EQ(blob_grads[j]->count(), blob_grads_midnet_loss_3[j]->count()); + ASSERT_EQ(blob_grads[j]->count(), blob_grads_loss_2[j]->count()); + for (int k = 0; k < blob_grads[j]->count(); ++k) { + const Dtype grad_diff_2 = blob_grads_loss_2[j]->cpu_diff()[k] - + blob_grads[j]->cpu_diff()[k]; + const Dtype grad_diff_3 = blob_grads_midnet_loss_3[j]->cpu_diff()[k] - + blob_grads[j]->cpu_diff()[k]; + if (grad_should_change) { + // Test non-triviality. 
+ const Dtype kMinGradDiffAbsValue = 1e-4; + EXPECT_GT(fabs(grad_diff_2), kMinGradDiffAbsValue) << blob_name; + EXPECT_NEAR(2 * grad_diff_2, grad_diff_3, kErrorMargin) << blob_name; + } else { + EXPECT_EQ(0, grad_diff_2) << blob_name; + EXPECT_EQ(0, grad_diff_3) << blob_name; + } + } + } + + const Dtype kMinLossDiffAbsValue = 1e-4; + + Dtype loss_diff_2 = loss_main_2 - loss; + // Test non-triviality. + EXPECT_GT(fabs(loss_diff_2), kMinLossDiffAbsValue); + Dtype loss_diff_3 = loss_main_3 - loss; + EXPECT_NEAR(2 * loss_diff_2, loss_diff_3, kErrorMargin); + + loss_diff_2 = loss_midnet_2 - loss; + // Test non-triviality. + EXPECT_GT(fabs(loss_diff_2), kMinLossDiffAbsValue); + loss_diff_3 = loss_midnet_3 - loss; + EXPECT_NEAR(2 * loss_diff_2, loss_diff_3, kErrorMargin); +} + +TYPED_TEST(NetTest, TestBackwardWithAccuracyLayer) { + typedef typename TypeParam::Dtype Dtype; + const bool kForceBackward = false; + const bool kAccuracyLayer = true; + this->InitTinyNet(kForceBackward, kAccuracyLayer); + EXPECT_TRUE(this->net_->has_blob("accuracy")); + vector<Blob<Dtype>*> bottom; + // Test that we can do Backward even though we have an ACCURACY layer. 
+ this->net_->ForwardBackward(bottom); +} + +TYPED_TEST(NetTest, TestUnsharedWeightsDataNet) { + typedef typename TypeParam::Dtype Dtype; + this->InitUnsharedWeightsNet(); + vector<Blob<Dtype>*> bottom; + Dtype loss; + this->net_->Forward(bottom, &loss); + EXPECT_GT(loss, 0); +} + +TYPED_TEST(NetTest, TestSharedWeightsDataNet) { + typedef typename TypeParam::Dtype Dtype; + this->InitSharedWeightsNet(); + vector<Blob<Dtype>*> bottom; + Dtype loss; + this->net_->Forward(bottom, &loss); + EXPECT_FLOAT_EQ(loss, 0); +} + +TYPED_TEST(NetTest, TestUnsharedWeightsDiffNet) { + typedef typename TypeParam::Dtype Dtype; + this->InitUnsharedWeightsNet(); + vector<Blob<Dtype>*> bottom; + Net<Dtype>* net = this->net_.get(); + net->Forward(bottom); + net->Backward(); + Layer<Dtype>* ip1_layer = net->layer_by_name("innerproduct1").get(); + Layer<Dtype>* ip2_layer = net->layer_by_name("innerproduct2").get(); + const int count = ip1_layer->blobs()[0]->count(); + const Dtype* grad1 = ip1_layer->blobs()[0]->cpu_diff(); + const Dtype* grad2 = ip2_layer->blobs()[0]->cpu_diff(); + for (int i = 0; i < count; ++i) { + EXPECT_GT(fabs(grad1[i]), 0); + EXPECT_FLOAT_EQ(-1 * grad1[i], grad2[i]); + } +} + +TYPED_TEST(NetTest, TestSharedWeightsDiffNet) { + typedef typename TypeParam::Dtype Dtype; + this->InitSharedWeightsNet(); + vector<Blob<Dtype>*> bottom; + Net<Dtype>* net = this->net_.get(); + Dtype loss; + net->Forward(bottom, &loss); + net->Backward(); + EXPECT_FLOAT_EQ(loss, 0); + Layer<Dtype>* ip1_layer = net->layer_by_name("innerproduct1").get(); + Layer<Dtype>* ip2_layer = net->layer_by_name("innerproduct2").get(); + const int count = ip1_layer->blobs()[0]->count(); + const Dtype* grad1 = ip1_layer->blobs()[0]->cpu_diff(); + const Dtype* grad2 = ip2_layer->blobs()[0]->cpu_diff(); + for (int i = 0; i < count; ++i) { + EXPECT_FLOAT_EQ(0, grad1[i]); + EXPECT_FLOAT_EQ(0, grad2[i]); + } +} + +TYPED_TEST(NetTest, TestSharedWeightsUpdate) { + typedef typename TypeParam::Dtype Dtype; + Caffe::set_random_seed(this->seed_); + this->InitDiffDataSharedWeightsNet(); + vector<Blob<Dtype>*> bottom; + 
EXPECT_EQ(this->net_->layer_names()[1], "innerproduct1"); + EXPECT_EQ(this->net_->layer_names()[2], "innerproduct2"); + Blob<Dtype>* ip1_weights = this->net_->layers()[1]->blobs()[0].get(); + Blob<Dtype>* ip2_weights = this->net_->layers()[2]->blobs()[0].get(); + // Check that data blobs of shared weights share the same location in memory. + EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data()); + // Check that diff blobs of shared weights are at different locations in + // memory. (The diffs should be accumulated at update time.) + EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff()); + this->net_->Forward(bottom); + this->net_->Backward(); + // Compute the expected update as the data minus the two diffs. + Blob<Dtype> shared_params; + const bool reshape = true; + const bool copy_diff = false; + shared_params.CopyFrom(*ip1_weights, copy_diff, reshape); + shared_params.CopyFrom(*ip1_weights, !copy_diff, reshape); + const int count = ip1_weights->count(); + // Make sure the diffs are non-trivial. + for (int i = 0; i < count; ++i) { + EXPECT_NE(0, ip1_weights->cpu_diff()[i]); + EXPECT_NE(0, ip2_weights->cpu_diff()[i]); + EXPECT_NE(ip1_weights->cpu_diff()[i], ip2_weights->cpu_diff()[i]); + } + caffe_axpy(count, Dtype(1), ip2_weights->cpu_diff(), + shared_params.mutable_cpu_diff()); + caffe_axpy(count, Dtype(-1), shared_params.cpu_diff(), + shared_params.mutable_cpu_data()); + const Dtype* expected_updated_params = shared_params.cpu_data(); + this->net_->Update(); + const Dtype* actual_updated_params = ip1_weights->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_EQ(expected_updated_params[i], actual_updated_params[i]); + } + // Check that data blobs of shared weights STILL point to the same memory + // location (because ... who knows). 
+ EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data()); + + Caffe::set_random_seed(this->seed_); + this->InitDiffDataUnsharedWeightsNet(); + EXPECT_EQ(this->net_->layer_names()[1], "innerproduct1"); + EXPECT_EQ(this->net_->layer_names()[2], "innerproduct2"); + ip1_weights = this->net_->layers()[1]->blobs()[0].get(); + ip2_weights = this->net_->layers()[2]->blobs()[0].get(); + // Check that data and diff blobs of unshared weights are at different + // locations in memory. + EXPECT_NE(ip1_weights->cpu_data(), ip2_weights->cpu_data()); + EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff()); + this->net_->Forward(bottom); + this->net_->Backward(); + // Compute the expected update. + Blob<Dtype> unshared_params1; + unshared_params1.CopyFrom(*ip1_weights, copy_diff, reshape); + unshared_params1.CopyFrom(*ip1_weights, !copy_diff, reshape); + Blob<Dtype> unshared_params2; + unshared_params2.CopyFrom(*ip2_weights, copy_diff, reshape); + unshared_params2.CopyFrom(*ip2_weights, !copy_diff, reshape); + // Make sure the diffs are non-trivial and sum to the diff in the shared net. 
+ for (int i = 0; i < count; ++i) { + EXPECT_NE(0, ip1_weights->cpu_diff()[i]); + EXPECT_NE(0, ip2_weights->cpu_diff()[i]); + EXPECT_NE(ip1_weights->cpu_diff()[i], ip2_weights->cpu_diff()[i]); + EXPECT_EQ(ip1_weights->cpu_diff()[i] + ip2_weights->cpu_diff()[i], + shared_params.cpu_diff()[i]); + } + caffe_axpy(count, Dtype(-1), ip1_weights->cpu_diff(), + unshared_params1.mutable_cpu_data()); + caffe_axpy(count, Dtype(-1), ip2_weights->cpu_diff(), + unshared_params2.mutable_cpu_data()); + const Dtype* expected_updated_params1 = unshared_params1.cpu_data(); + const Dtype* expected_updated_params2 = unshared_params2.cpu_data(); + this->net_->Update(); + const Dtype* actual_updated_params1 = ip1_weights->cpu_data(); + const Dtype* actual_updated_params2 = ip2_weights->cpu_data(); + for (int i = 0; i < count; ++i) { + EXPECT_EQ(expected_updated_params1[i], actual_updated_params1[i]); + EXPECT_EQ(expected_updated_params2[i], actual_updated_params2[i]); + EXPECT_NE(actual_updated_params1[i], actual_updated_params2[i]); + EXPECT_NE(expected_updated_params, expected_updated_params1); + } +} + +TYPED_TEST(NetTest, TestSharedWeightsResume) { + typedef typename TypeParam::Dtype Dtype; + + // Create a net with weight sharing; Update it once. + Caffe::set_random_seed(this->seed_); + this->InitDiffDataSharedWeightsNet(); + vector<Blob<Dtype>*> bottom; + EXPECT_EQ(this->net_->layer_names()[1], "innerproduct1"); + EXPECT_EQ(this->net_->layer_names()[2], "innerproduct2"); + Blob<Dtype>* ip1_weights = this->net_->layers()[1]->blobs()[0].get(); + Blob<Dtype>* ip2_weights = this->net_->layers()[2]->blobs()[0].get(); + // Check that data blobs of shared weights share the same location in memory. + EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data()); + // Check that diff blobs of shared weights are at different locations in + // memory. (The diffs should be accumulated at update time.) 
+ EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff()); + this->net_->ForwardBackward(bottom); + this->net_->Update(); + Blob<Dtype> shared_params; + const bool kReshape = true; + const bool kCopyDiff = false; + shared_params.CopyFrom(*ip1_weights, kCopyDiff, kReshape); + const int count = ip1_weights->count(); + + // Write the net to a NetParameter, as in Solver::Snapshot. + NetParameter net_param; + this->net_->ToProto(&net_param); + + // Reinitialize the net and copy parameters from net_param, as in + // Solver::Restore. + Caffe::set_random_seed(this->seed_); + this->InitDiffDataSharedWeightsNet(); + this->net_->CopyTrainedLayersFrom(net_param); + ip1_weights = this->net_->layers()[1]->blobs()[0].get(); + ip2_weights = this->net_->layers()[2]->blobs()[0].get(); + ASSERT_FALSE(NULL == ip1_weights); + ASSERT_FALSE(NULL == ip2_weights); + EXPECT_NE(ip1_weights, ip2_weights); + // Check that data blobs of shared weights share the same location in memory. + EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data()); + for (int i = 0; i < count; ++i) { + EXPECT_FLOAT_EQ(shared_params.cpu_data()[i], ip1_weights->cpu_data()[i]); + } + // Check that diff blobs of shared weights are at different locations in + // memory. (The diffs should be accumulated at update time.) + EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff()); +} + +TYPED_TEST(NetTest, TestParamPropagateDown) { + typedef typename TypeParam::Dtype Dtype; + vector<Blob<Dtype>*> bottom; + const bool kBiasTerm = true, kForceBackward = false; + const Dtype* kLossWeight1 = NULL; + const Dtype* kLossWeight2 = NULL; + + // Run the net with all params learned; check that gradients are non-zero. 
+ Caffe::set_random_seed(this->seed_); + Dtype blobs_lr_w1 = 1, blobs_lr_w2 = 1, blobs_lr_b1 = 2, blobs_lr_b2 = 2; + this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward, + kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2); + this->net_->Forward(bottom); + this->net_->Backward(); + const vector<shared_ptr<Blob<Dtype> > >& params = this->net_->params(); + const int num_params = params.size(); + ASSERT_EQ(4, num_params); + const Dtype kNonZeroTestMin = 1e-3; + vector<Dtype> param_asums(params.size()); + for (int i = 0; i < num_params; ++i) { + const Dtype param_asum = + caffe_cpu_asum(params[i]->count(), params[i]->cpu_diff()); + param_asums[i] = param_asum; + EXPECT_GT(param_asum, kNonZeroTestMin); + } + + // Change the learning rates to different non-zero values; should see same + // gradients. + Caffe::set_random_seed(this->seed_); + blobs_lr_w1 *= 2, blobs_lr_w2 *= 2, blobs_lr_b1 *= 2, blobs_lr_b2 *= 2; + this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward, + kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2); + this->net_->Forward(bottom); + this->net_->Backward(); + const vector<shared_ptr<Blob<Dtype> > >& params2 = this->net_->params(); + ASSERT_EQ(num_params, params2.size()); + for (int i = 0; i < num_params; ++i) { + const Dtype param_asum = + caffe_cpu_asum(params2[i]->count(), params2[i]->cpu_diff()); + EXPECT_FLOAT_EQ(param_asum, param_asums[i]); + } + + // Change a subset of the learning rates to zero; check that we see zero + // gradients for those. 
+ Caffe::set_random_seed(this->seed_); + blobs_lr_w1 = 1, blobs_lr_w2 = 0, blobs_lr_b1 = 0, blobs_lr_b2 = 1; + this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward, + kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2); + this->net_->Forward(bottom); + this->net_->Backward(); + const vector<shared_ptr<Blob<Dtype> > >& params3 = this->net_->params(); + ASSERT_EQ(num_params, params3.size()); + for (int i = 0; i < num_params; ++i) { + const Dtype param_asum = + caffe_cpu_asum(params3[i]->count(), params3[i]->cpu_diff()); + if (i == 1 || i == 2) { + EXPECT_FLOAT_EQ(0, param_asum); + } else { + EXPECT_FLOAT_EQ(param_asum, param_asums[i]); + } + } + + // Change the opposite subset of the learning rates to zero. + Caffe::set_random_seed(this->seed_); + blobs_lr_w1 = 0, blobs_lr_w2 = 1, blobs_lr_b1 = 1, blobs_lr_b2 = 0; + this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward, + kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2); + this->net_->Forward(bottom); + this->net_->Backward(); + const vector<shared_ptr<Blob<Dtype> > >& params4 = this->net_->params(); + ASSERT_EQ(num_params, params4.size()); + for (int i = 0; i < num_params; ++i) { + const Dtype param_asum = + caffe_cpu_asum(params4[i]->count(), params4[i]->cpu_diff()); + if (i == 0 || i == 3) { + EXPECT_FLOAT_EQ(0, param_asum); + } else { + EXPECT_FLOAT_EQ(param_asum, param_asums[i]); + } + } +} + +TYPED_TEST(NetTest, TestFromTo) { + typedef typename TypeParam::Dtype Dtype; + this->InitTinyNet(); + + // Run Forward and Backward, recording the data diff and loss. + Blob<Dtype> data; + data.ReshapeLike(*this->net_->blob_by_name("data")); + this->net_->ForwardPrefilled(); + this->net_->Backward(); + data.CopyFrom(*this->net_->blob_by_name("data"), true, true); + const Dtype *loss_ptr = this->net_->output_blobs()[0]->cpu_data(); + Dtype loss = *loss_ptr; + + // Check that combining partial Forwards gives the same loss. 
+ for (int i = 1; i < this->net_->layers().size(); ++i) { + // Note that we skip layer zero to keep the same data. + this->net_->ForwardFromTo(1, 1); + if (i < this->net_->layers().size() - 1) { + this->net_->ForwardFrom(i + 1); + } + EXPECT_EQ(loss, *loss_ptr); + } + + // Check that combining partial Backwards gives the same data diff. + for (int i = 1; i < this->net_->layers().size(); ++i) { + this->net_->BackwardTo(i); + this->net_->BackwardFrom(i - 1); + for (int j = 0; j < data.count(); ++j) { + EXPECT_EQ(data.cpu_diff()[j], + this->net_->blob_by_name("data")->cpu_diff()[j]); + } + } +} + +class FilterNetTest : public ::testing::Test { + protected: + void RunFilterNetTest( + const string& input_param_string, const string& filtered_param_string) { + NetParameter input_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + input_param_string, &input_param)); + NetParameter expected_filtered_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + filtered_param_string, &expected_filtered_param)); + NetParameter actual_filtered_param; + Net<float>::FilterNet(input_param, &actual_filtered_param); + EXPECT_EQ(expected_filtered_param.DebugString(), + actual_filtered_param.DebugString()); + // Also test idempotence. 
+ NetParameter double_filtered_param; + Net<float>::FilterNet(actual_filtered_param, &double_filtered_param); + EXPECT_EQ(actual_filtered_param.DebugString(), + double_filtered_param.DebugString()); + } +}; + +TEST_F(FilterNetTest, TestNoFilter) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunFilterNetTest(input_proto, input_proto); +} + +TEST_F(FilterNetTest, TestFilterLeNetTrainTest) { + const string& input_proto = + "name: 'LeNet' " + "layers { " + " name: 'mnist' " + " type: DATA " + " top: 'data' " + " top: 'label' " + " data_param { " + " source: 'mnist-train-leveldb' " + " batch_size: 64 " + " } " + " transform_param { " + " scale: 0.00390625 " + " } " + " include: { phase: TRAIN } " + "} " + "layers { " + " name: 'mnist' " + " type: DATA " + " top: 'data' " + " top: 'label' " + " data_param { " + " source: 'mnist-test-leveldb' " + " batch_size: 100 " + " } " + " transform_param { " + " scale: 0.00390625 " + " } " + " include: { phase: TEST } " + "} " + "layers { " + " name: 'conv1' " + " type: CONVOLUTION " + " bottom: 'data' " + " top: 'conv1' " + " blobs_lr: 1 " + " blobs_lr: 2 " + " convolution_param { " + " num_output: 20 " + " kernel_size: 5 " + " stride: 1 " + " weight_filler { " + " type: 'xavier' " + " } " + " bias_filler { " + " type: 'constant' " + " } " + " } " + "} " + "layers { " + " name: 'ip1' " + " type: INNER_PRODUCT " + " bottom: 'conv1' " + " top: 'ip1' " + " blobs_lr: 1 " + " blobs_lr: 2 " + " inner_product_param { " + " num_output: 10 " + " weight_filler { " + " type: 'xavier' " + " } " + " bias_filler { " + " type: 'constant' " + " } " + " } " + "} " + "layers { " + " name: 'accuracy' " + " type: 
ACCURACY " + " bottom: 'ip1' " + " bottom: 'label' " + " top: 'accuracy' " + " include: { phase: TEST } " + "} " + "layers { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'ip2' " + " bottom: 'label' " + " top: 'loss' " + "} "; + const string input_proto_train = "state: { phase: TRAIN } " + input_proto; + const string input_proto_test = "state: { phase: TEST } " + input_proto; + const string output_proto_train = + "name: 'LeNet' " + "layers { " + " name: 'mnist' " + " type: DATA " + " top: 'data' " + " top: 'label' " + " data_param { " + " source: 'mnist-train-leveldb' " + " batch_size: 64 " + " } " + " transform_param { " + " scale: 0.00390625 " + " } " + " include: { phase: TRAIN } " + "} " + "layers { " + " name: 'conv1' " + " type: CONVOLUTION " + " bottom: 'data' " + " top: 'conv1' " + " blobs_lr: 1 " + " blobs_lr: 2 " + " convolution_param { " + " num_output: 20 " + " kernel_size: 5 " + " stride: 1 " + " weight_filler { " + " type: 'xavier' " + " } " + " bias_filler { " + " type: 'constant' " + " } " + " } " + "} " + "layers { " + " name: 'ip1' " + " type: INNER_PRODUCT " + " bottom: 'conv1' " + " top: 'ip1' " + " blobs_lr: 1 " + " blobs_lr: 2 " + " inner_product_param { " + " num_output: 10 " + " weight_filler { " + " type: 'xavier' " + " } " + " bias_filler { " + " type: 'constant' " + " } " + " } " + "} " + "layers { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'ip2' " + " bottom: 'label' " + " top: 'loss' " + "} "; + const string& output_proto_test = + "name: 'LeNet' " + "layers { " + " name: 'mnist' " + " type: DATA " + " top: 'data' " + " top: 'label' " + " data_param { " + " source: 'mnist-test-leveldb' " + " batch_size: 100 " + " } " + " transform_param { " + " scale: 0.00390625 " + " } " + " include: { phase: TEST } " + "} " + "layers { " + " name: 'conv1' " + " type: CONVOLUTION " + " bottom: 'data' " + " top: 'conv1' " + " blobs_lr: 1 " + " blobs_lr: 2 " + " convolution_param { " + " num_output: 20 " + " kernel_size: 5 " + " 
stride: 1 " + " weight_filler { " + " type: 'xavier' " + " } " + " bias_filler { " + " type: 'constant' " + " } " + " } " + "} " + "layers { " + " name: 'ip1' " + " type: INNER_PRODUCT " + " bottom: 'conv1' " + " top: 'ip1' " + " blobs_lr: 1 " + " blobs_lr: 2 " + " inner_product_param { " + " num_output: 10 " + " weight_filler { " + " type: 'xavier' " + " } " + " bias_filler { " + " type: 'constant' " + " } " + " } " + "} " + "layers { " + " name: 'accuracy' " + " type: ACCURACY " + " bottom: 'ip1' " + " bottom: 'label' " + " top: 'accuracy' " + " include: { phase: TEST } " + "} " + "layers { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'ip2' " + " bottom: 'label' " + " top: 'loss' " + "} "; + const string output_proto_train_explicit = + output_proto_train + " state: { phase: TRAIN } "; + const string output_proto_test_explicit = + output_proto_test + " state: { phase: TEST } "; + this->RunFilterNetTest(input_proto_train, output_proto_train_explicit); + this->RunFilterNetTest(input_proto_test, output_proto_test_explicit); + + // Also check that nets are filtered according to the Caffe singleton phase, + // if not explicitly specified in the input proto. + Caffe::set_phase(Caffe::TRAIN); + this->RunFilterNetTest(input_proto, output_proto_train); + Caffe::set_phase(Caffe::TEST); + this->RunFilterNetTest(input_proto, output_proto_test); + + // Finally, check that the current Caffe singleton phase is ignored if the + // phase is explicitly specified in the input proto. 
+ Caffe::set_phase(Caffe::TEST); + this->RunFilterNetTest(input_proto_train, output_proto_train_explicit); + Caffe::set_phase(Caffe::TRAIN); + this->RunFilterNetTest(input_proto_test, output_proto_test_explicit); +} + +TEST_F(FilterNetTest, TestFilterOutByStage) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + " include: { stage: 'mystage' } " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + const string& output_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunFilterNetTest(input_proto, output_proto); +} + +TEST_F(FilterNetTest, TestFilterOutByStage2) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { stage: 'mystage' } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + const string& output_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunFilterNetTest(input_proto, output_proto); +} + +TEST_F(FilterNetTest, TestFilterInByStage) { + const string& input_proto = + "state: { stage: 'mystage' } " + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA 
" + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { stage: 'mystage' } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunFilterNetTest(input_proto, input_proto); +} + +TEST_F(FilterNetTest, TestFilterInByStage2) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " exclude: { stage: 'mystage' } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunFilterNetTest(input_proto, input_proto); +} + +TEST_F(FilterNetTest, TestFilterOutByMultipleStage) { + const string& input_proto = + "state: { stage: 'mystage' } " + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { stage: 'mystage' stage: 'myotherstage' } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + " include: { stage: 'mystage' } " + "} "; + const string& output_proto = + "state: { stage: 'mystage' } " + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + " include: { stage: 'mystage' } " + "} "; + this->RunFilterNetTest(input_proto, output_proto); +} + +TEST_F(FilterNetTest, TestFilterInByMultipleStage) { + const string& input_proto = + "state: { stage: 'mystage' } " + "name: 'TestNetwork' " + "layers: 
{ " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { stage: 'myotherstage' } " + " include: { stage: 'mystage' } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + " include: { stage: 'mystage' } " + "} "; + this->RunFilterNetTest(input_proto, input_proto); +} + +TEST_F(FilterNetTest, TestFilterInByMultipleStage2) { + const string& input_proto = + "state: { stage: 'mystage' stage: 'myotherstage' } " + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { stage: 'mystage' stage: 'myotherstage' } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + " include: { stage: 'mystage' } " + "} "; + this->RunFilterNetTest(input_proto, input_proto); +} + +TEST_F(FilterNetTest, TestFilterInByNotStage) { + const string& input_proto = + "state: { stage: 'mystage' } " + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { not_stage: 'myotherstage' } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + " include: { not_stage: 'myotherstage' } " + "} "; + this->RunFilterNetTest(input_proto, input_proto); +} + +TEST_F(FilterNetTest, TestFilterOutByNotStage) { + const string& input_proto = + "state: { stage: 'mystage' } " + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 
'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { not_stage: 'mystage' } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + " include: { not_stage: 'mystage' } " + "} "; + const string& output_proto = + "state: { stage: 'mystage' } " + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} "; + this->RunFilterNetTest(input_proto, output_proto); +} + +TEST_F(FilterNetTest, TestFilterOutByMinLevel) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { min_level: 3 } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + const string& output_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunFilterNetTest(input_proto, output_proto); +} + +TEST_F(FilterNetTest, TestFilterOutByMaxLevel) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { max_level: -3 } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + const string& output_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " 
bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunFilterNetTest(input_proto, output_proto); +} + +TEST_F(FilterNetTest, TestFilterInByMinLevel) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { min_level: 0 } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunFilterNetTest(input_proto, input_proto); +} + +TEST_F(FilterNetTest, TestFilterInByMinLevel2) { + const string& input_proto = + "state: { level: 7 } " + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { min_level: 3 } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunFilterNetTest(input_proto, input_proto); +} + +TEST_F(FilterNetTest, TestFilterInByMaxLevel) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { max_level: 0 } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunFilterNetTest(input_proto, input_proto); +} + +TEST_F(FilterNetTest, TestFilterInByMaxLevel2) { + const string& input_proto = + "state: { level: -7 } " + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 
'data' " + " top: 'innerprod' " + " include: { max_level: -3 } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunFilterNetTest(input_proto, input_proto); +} + +TEST_F(FilterNetTest, TestFilterInOutByIncludeMultiRule) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { min_level: 2 phase: TRAIN } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + " include: { min_level: 2 phase: TEST } " + "} "; + const string& input_proto_train = + "state: { level: 4 phase: TRAIN } " + input_proto; + const string& input_proto_test = + "state: { level: 4 phase: TEST } " + input_proto; + const string& output_proto_train = + "state: { level: 4 phase: TRAIN } " + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { min_level: 2 phase: TRAIN } " + "} "; + const string& output_proto_test = + "state: { level: 4 phase: TEST } " + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + " include: { min_level: 2 phase: TEST } " + "} "; + this->RunFilterNetTest(input_proto_train, output_proto_train); + this->RunFilterNetTest(input_proto_test, output_proto_test); +} + +TEST_F(FilterNetTest, TestFilterInByIncludeMultiRule) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " 
+ "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " include: { min_level: 2 phase: TRAIN } " + " include: { phase: TEST } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + " include: { min_level: 2 phase: TEST } " + " include: { phase: TRAIN } " + "} "; + const string& input_proto_train = + "state: { level: 2 phase: TRAIN } " + input_proto; + const string& input_proto_test = + "state: { level: 2 phase: TEST } " + input_proto; + this->RunFilterNetTest(input_proto_train, input_proto_train); + this->RunFilterNetTest(input_proto_test, input_proto_test); +} + +TEST_F(FilterNetTest, TestFilterInOutByExcludeMultiRule) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " exclude: { min_level: 2 phase: TRAIN } " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + " exclude: { min_level: 2 phase: TEST } " + "} "; + const string& input_proto_train = + "state: { level: 4 phase: TRAIN } " + input_proto; + const string& input_proto_test = + "state: { level: 4 phase: TEST } " + input_proto; + const string& output_proto_train = + "state: { level: 4 phase: TRAIN } " + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + " exclude: { min_level: 2 phase: TEST } " + "} "; + const string& output_proto_test = + "state: { level: 4 phase: TEST } " + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: 
INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + " exclude: { min_level: 2 phase: TRAIN } " + "} "; + this->RunFilterNetTest(input_proto_train, output_proto_train); + this->RunFilterNetTest(input_proto_test, output_proto_test); +} + +TYPED_TEST(NetTest, TestReshape) { + typedef typename TypeParam::Dtype Dtype; + // We set up bottom blobs of two different sizes, switch between + // them, and check that forward and backward both run and the results + // are the same. + Caffe::set_random_seed(this->seed_); + Caffe::set_mode(Caffe::CPU); + FillerParameter filler_param; + filler_param.set_std(1); + GaussianFiller<Dtype> filler(filler_param); + Blob<Dtype> blob1(4, 3, 9, 11); + Blob<Dtype> blob2(2, 3, 12, 10); + filler.Fill(&blob1); + filler.Fill(&blob2); + + this->InitReshapableNet(); + Blob<Dtype>* input_blob = this->net_->input_blobs()[0]; + Blob<Dtype>* output_blob = this->net_->output_blobs()[0]; + input_blob->Reshape(blob1.num(), blob1.channels(), blob1.height(), + blob1.width()); + caffe_copy(blob1.count(), blob1.cpu_data(), input_blob->mutable_cpu_data()); + this->net_->ForwardPrefilled(); + // call backward just to make sure it runs + this->net_->Backward(); + Blob<Dtype> output1(output_blob->num(), output_blob->channels(), + output_blob->height(), output_blob->width()); + caffe_copy(output1.count(), output_blob->cpu_data(), + output1.mutable_cpu_data()); + + input_blob->Reshape(blob2.num(), blob2.channels(), blob2.height(), + blob2.width()); + caffe_copy(blob2.count(), blob2.cpu_data(), input_blob->mutable_cpu_data()); + this->net_->ForwardPrefilled(); + this->net_->Backward(); + Blob<Dtype> output2(output_blob->num(), output_blob->channels(), + output_blob->height(), output_blob->width()); + caffe_copy(output2.count(), output_blob->cpu_data(), + output2.mutable_cpu_data()); + + input_blob->Reshape(blob1.num(), blob1.channels(), blob1.height(), + blob1.width()); + caffe_copy(blob1.count(), blob1.cpu_data(), input_blob->mutable_cpu_data()); + this->net_->ForwardPrefilled(); + 
this->net_->Backward(); + for (int i = 0; i < output1.count(); ++i) { + CHECK_EQ(*(output1.cpu_data() + i), *(output_blob->cpu_data() + i)); + } + + input_blob->Reshape(blob2.num(), blob2.channels(), blob2.height(), + blob2.width()); + caffe_copy(blob2.count(), blob2.cpu_data(), input_blob->mutable_cpu_data()); + this->net_->ForwardPrefilled(); + this->net_->Backward(); + for (int i = 0; i < output2.count(); ++i) { + CHECK_EQ(*(output2.cpu_data() + i), *(output_blob->cpu_data() + i)); + } +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_neuron_layer.cpp b/caffe-crfrnn/src/caffe/test/test_neuron_layer.cpp new file mode 100644 index 00000000..b19a5abd --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_neuron_layer.cpp @@ -0,0 +1,533 @@ +#include +#include + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template <typename TypeParam> +class NeuronLayerTest : public MultiDeviceTest<TypeParam> { + typedef typename TypeParam::Dtype Dtype; + + protected: + NeuronLayerTest() + : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)), + blob_top_(new Blob<Dtype>()) { + Caffe::set_random_seed(1701); + // fill the values + FillerParameter filler_param; + GaussianFiller<Dtype> filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~NeuronLayerTest() { delete blob_bottom_; delete blob_top_; } + Blob<Dtype>* const blob_bottom_; + Blob<Dtype>* const blob_top_; + vector<Blob<Dtype>*> blob_bottom_vec_; + vector<Blob<Dtype>*> blob_top_vec_; + + void TestDropoutForward(const float dropout_ratio) { + LayerParameter layer_param; + // Fill in the given dropout_ratio, unless it's 0.5, in which case we don't + // set it explicitly to test that 0.5 is the default. 
+ if (dropout_ratio != 0.5) { + layer_param.mutable_dropout_param()->set_dropout_ratio(dropout_ratio); + } + Caffe::set_phase(Caffe::TRAIN); + DropoutLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + const Dtype* bottom_data = this->blob_bottom_->cpu_data(); + const Dtype* top_data = this->blob_top_->cpu_data(); + float scale = 1. / (1. - layer_param.dropout_param().dropout_ratio()); + const int count = this->blob_bottom_->count(); + // Initialize num_kept to count the number of inputs NOT dropped out. + int num_kept = 0; + for (int i = 0; i < count; ++i) { + if (top_data[i] != 0) { + ++num_kept; + EXPECT_EQ(top_data[i], bottom_data[i] * scale); + } + } + const Dtype std_error = sqrt(dropout_ratio * (1 - dropout_ratio) / count); + // Fail if the number dropped was more than 1.96 * std_error away from the + // expected number -- requires 95% confidence that the dropout layer is not + // obeying the given dropout_ratio for test failure. 
+ const Dtype empirical_dropout_ratio = 1 - num_kept / Dtype(count); + EXPECT_NEAR(empirical_dropout_ratio, dropout_ratio, 1.96 * std_error); + } + + void TestExpForward(const float base, const float scale, const float shift) { + LayerParameter layer_param; + layer_param.mutable_exp_param()->set_base(base); + layer_param.mutable_exp_param()->set_scale(scale); + layer_param.mutable_exp_param()->set_shift(shift); + ExpLayer<Dtype> layer(layer_param); + layer.SetUp(blob_bottom_vec_, blob_top_vec_); + layer.Forward(blob_bottom_vec_, blob_top_vec_); + const Dtype kDelta = 2e-4; + const Dtype* bottom_data = blob_bottom_->cpu_data(); + const Dtype* top_data = blob_top_->cpu_data(); + for (int i = 0; i < blob_bottom_->count(); ++i) { + const Dtype bottom_val = bottom_data[i]; + const Dtype top_val = top_data[i]; + if (base == -1) { + EXPECT_NEAR(top_val, exp(shift + scale * bottom_val), kDelta); + } else { + EXPECT_NEAR(top_val, pow(base, shift + scale * bottom_val), kDelta); + } + } + } + + void TestExpGradient(const float base, const float scale, const float shift) { + LayerParameter layer_param; + layer_param.mutable_exp_param()->set_base(base); + layer_param.mutable_exp_param()->set_scale(scale); + layer_param.mutable_exp_param()->set_shift(shift); + ExpLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-3); + checker.CheckGradientEltwise(&layer, blob_bottom_vec_, blob_top_vec_); + } +}; + +TYPED_TEST_CASE(NeuronLayerTest, TestDtypesAndDevices); + +TYPED_TEST(NeuronLayerTest, TestAbsVal) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + AbsValLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const Dtype* bottom_data = this->blob_bottom_->cpu_data(); + const Dtype* top_data = this->blob_top_->cpu_data(); + const int count = this->blob_bottom_->count(); + for (int i = 0; i < count; ++i) { + EXPECT_EQ(top_data[i], fabs(bottom_data[i])); + } +} + 
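Reviewer note on `TestDropoutForward`: the check accepts an empirical drop rate within 1.96 binomial standard errors of `dropout_ratio`. Below is a minimal standalone sketch of that tolerance, assuming each input is dropped independently. `DropoutStdError` and `WithinTolerance` are hypothetical helper names used only for illustration; they are not part of Caffe.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper (not a Caffe function): standard error of the observed
// drop fraction when each of `count` inputs is dropped independently with
// probability `dropout_ratio`.
double DropoutStdError(double dropout_ratio, int count) {
  assert(count > 0);
  return std::sqrt(dropout_ratio * (1.0 - dropout_ratio) / count);
}

// Hypothetical helper: mirrors the EXPECT_NEAR bound used by the test,
// i.e. |empirical - expected| <= 1.96 * std_error.
bool WithinTolerance(int num_kept, int count, double dropout_ratio) {
  const double empirical = 1.0 - static_cast<double>(num_kept) / count;
  const double bound = 1.96 * DropoutStdError(dropout_ratio, count);
  return std::fabs(empirical - dropout_ratio) <= bound;
}
```

With the fixture's 2 x 3 x 4 x 5 bottom blob (120 values) and a ratio of 0.5, the bound is roughly 0.089, so the test tolerates about 50 to 70 kept values. Because the seed is fixed at 1701, the outcome is deterministic for a given build rather than failing intermittently.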
+TYPED_TEST(NeuronLayerTest, TestAbsGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + AbsValLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(NeuronLayerTest, TestReLU) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ReLULayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + const Dtype* bottom_data = this->blob_bottom_->cpu_data(); + const Dtype* top_data = this->blob_top_->cpu_data(); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_GE(top_data[i], 0.); + EXPECT_TRUE(top_data[i] == 0 || top_data[i] == bottom_data[i]); + } +} + +TYPED_TEST(NeuronLayerTest, TestReLUGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + ReLULayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(NeuronLayerTest, TestReLUWithNegativeSlope) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + "relu_param { negative_slope: 0.01 }", &layer_param)); + ReLULayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + const Dtype* bottom_data = this->blob_bottom_->cpu_data(); + const Dtype* top_data = this->blob_top_->cpu_data(); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + if (top_data[i] >= 0) { + EXPECT_FLOAT_EQ(top_data[i], bottom_data[i]); + } else { + EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] * 0.01); + } + } +} + +TYPED_TEST(NeuronLayerTest, TestReLUGradientWithNegativeSlope) { + typedef 
typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + "relu_param { negative_slope: 0.01 }", &layer_param)); + ReLULayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(NeuronLayerTest, TestSigmoid) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + SigmoidLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + const Dtype* bottom_data = this->blob_bottom_->cpu_data(); + const Dtype* top_data = this->blob_top_->cpu_data(); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_FLOAT_EQ(top_data[i], 1. / (1 + exp(-bottom_data[i]))); + // check that we squashed the value between 0 and 1 + EXPECT_GE(top_data[i], 0.); + EXPECT_LE(top_data[i], 1.); + } +} + +TYPED_TEST(NeuronLayerTest, TestSigmoidGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + SigmoidLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(NeuronLayerTest, TestTanH) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + TanHLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Test exact values + for (int i = 0; i < this->blob_bottom_->num(); ++i) { + for (int j = 0; j < this->blob_bottom_->channels(); ++j) { + for (int k = 0; k < this->blob_bottom_->height(); ++k) { + for (int l = 0; l < this->blob_bottom_->width(); ++l) { + EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4, + (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) / + (exp(2*this->blob_bottom_->data_at(i, 
j, k, l)) + 1)); + EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4, + (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) / + (exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1)); + } + } + } + } +} + +TYPED_TEST(NeuronLayerTest, TestTanHGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + TanHLayer<Dtype> layer(layer_param); + GradientChecker<Dtype> checker(1e-2, 1e-3); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(NeuronLayerTest, TestExpLayer) { + typedef typename TypeParam::Dtype Dtype; + // Test default base of "-1" -- should actually set base := e. + const Dtype kBase = -1; + const Dtype kScale = 1; + const Dtype kShift = 0; + this->TestExpForward(kBase, kScale, kShift); +} + +TYPED_TEST(NeuronLayerTest, TestExpGradient) { + typedef typename TypeParam::Dtype Dtype; + // Test default base of "-1" -- should actually set base := e. + const Dtype kBase = -1; + const Dtype kScale = 1; + const Dtype kShift = 0; + this->TestExpGradient(kBase, kScale, kShift); +} + +TYPED_TEST(NeuronLayerTest, TestExpLayerBase2) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kBase = 2; + const Dtype kScale = 1; + const Dtype kShift = 0; + this->TestExpForward(kBase, kScale, kShift); +} + +TYPED_TEST(NeuronLayerTest, TestExpGradientBase2) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kBase = 2; + const Dtype kScale = 1; + const Dtype kShift = 0; + this->TestExpGradient(kBase, kScale, kShift); +} + +TYPED_TEST(NeuronLayerTest, TestExpLayerBase2Shift1) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kBase = 2; + const Dtype kScale = 1; + const Dtype kShift = 1; + this->TestExpForward(kBase, kScale, kShift); +} + +TYPED_TEST(NeuronLayerTest, TestExpGradientBase2Shift1) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kBase = 2; + const Dtype kScale = 1; + const Dtype kShift = 1; + this->TestExpGradient(kBase, kScale, kShift); +} + 
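Reviewer note on the Exp tests: `TestExpForward` compares each output against base^(shift + scale * x), with a base of -1 standing in for e. A minimal reference sketch of that expression follows. `ExpReference` is a hypothetical name, not a Caffe function, and the positivity assertion on `base` is an assumption of this sketch rather than something the diff shows.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical reference (not part of Caffe): expected Exp layer output for a
// single input x, following the expression checked in TestExpForward.
// A base of -1 is treated as the natural base e.
double ExpReference(double base, double scale, double shift, double x) {
  // Assumption of this sketch: any base other than -1 must be positive.
  assert(base == -1 || base > 0);
  const double exponent = shift + scale * x;
  return base == -1 ? std::exp(exponent) : std::pow(base, exponent);
}
```

For the `Base2Shift1Scale3` case, an input of 1 gives 2^(1 + 3 * 1) = 16, which is the kind of value the forward test bounds to within `kDelta`.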
+TYPED_TEST(NeuronLayerTest, TestExpLayerBase2Scale3) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kBase = 2; + const Dtype kScale = 3; + const Dtype kShift = 0; + this->TestExpForward(kBase, kScale, kShift); +} + +TYPED_TEST(NeuronLayerTest, TestExpGradientBase2Scale3) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kBase = 2; + const Dtype kScale = 3; + const Dtype kShift = 0; + this->TestExpGradient(kBase, kScale, kShift); +} + +TYPED_TEST(NeuronLayerTest, TestExpLayerBase2Shift1Scale3) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kBase = 2; + const Dtype kScale = 3; + const Dtype kShift = 1; + this->TestExpForward(kBase, kScale, kShift); +} + +TYPED_TEST(NeuronLayerTest, TestExpGradientBase2Shift1Scale3) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kBase = 2; + const Dtype kScale = 3; + const Dtype kShift = 1; + this->TestExpGradient(kBase, kScale, kShift); +} + +TYPED_TEST(NeuronLayerTest, TestDropoutHalf) { + const float kDropoutRatio = 0.5; + this->TestDropoutForward(kDropoutRatio); +} + +TYPED_TEST(NeuronLayerTest, TestDropoutThreeQuarters) { + const float kDropoutRatio = 0.75; + this->TestDropoutForward(kDropoutRatio); +} + +TYPED_TEST(NeuronLayerTest, TestDropoutTestPhase) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + Caffe::set_phase(Caffe::TEST); + DropoutLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + const Dtype* bottom_data = this->blob_bottom_->cpu_data(); + const Dtype* top_data = this->blob_top_->cpu_data(); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + if (top_data[i] != 0) { + EXPECT_EQ(top_data[i], bottom_data[i]); + } + } +} + +TYPED_TEST(NeuronLayerTest, TestDropoutGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + Caffe::set_phase(Caffe::TRAIN); + DropoutLayer<Dtype> layer(layer_param); + 
GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(NeuronLayerTest, TestDropoutGradientTest) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + Caffe::set_phase(Caffe::TEST); + DropoutLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(NeuronLayerTest, TestBNLL) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + BNLLLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + const Dtype* bottom_data = this->blob_bottom_->cpu_data(); + const Dtype* top_data = this->blob_top_->cpu_data(); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_GE(top_data[i], 0.); + EXPECT_GE(top_data[i], bottom_data[i]); + } +} + +TYPED_TEST(NeuronLayerTest, TestBNLLGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + BNLLLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +#ifdef USE_CUDNN +template +class CuDNNNeuronLayerTest : public ::testing::Test { + protected: + CuDNNNeuronLayerTest() + : blob_bottom_(new Blob(2, 3, 4, 5)), + blob_top_(new Blob()) { + Caffe::set_random_seed(1701); + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~CuDNNNeuronLayerTest() { delete blob_bottom_; delete blob_top_; } + Blob* const blob_bottom_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(CuDNNNeuronLayerTest, TestDtypes); + +TYPED_TEST(CuDNNNeuronLayerTest, 
TestReLUCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + CuDNNReLULayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + const TypeParam* bottom_data = this->blob_bottom_->cpu_data(); + const TypeParam* top_data = this->blob_top_->cpu_data(); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_GE(top_data[i], 0.); + EXPECT_TRUE(top_data[i] == 0 || top_data[i] == bottom_data[i]); + } +} + +TYPED_TEST(CuDNNNeuronLayerTest, TestReLUGradientCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + CuDNNReLULayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3, 1701, 0., 0.01); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(CuDNNNeuronLayerTest, TestReLUWithNegativeSlopeCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + "relu_param { negative_slope: 0.01 }", &layer_param)); + CuDNNReLULayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + const TypeParam* bottom_data = this->blob_bottom_->cpu_data(); + const TypeParam* top_data = this->blob_top_->cpu_data(); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + if (top_data[i] >= 0) { + EXPECT_FLOAT_EQ(top_data[i], bottom_data[i]); + } else { + EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] * 0.01); + } + } +} + +TYPED_TEST(CuDNNNeuronLayerTest, TestReLUGradientWithNegativeSlopeCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + "relu_param { negative_slope: 0.01 }", &layer_param)); + CuDNNReLULayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3, 1701, 0., 0.01); + checker.CheckGradientEltwise(&layer, 
this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(CuDNNNeuronLayerTest, TestSigmoidCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + CuDNNSigmoidLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + const TypeParam* bottom_data = this->blob_bottom_->cpu_data(); + const TypeParam* top_data = this->blob_top_->cpu_data(); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_FLOAT_EQ(top_data[i], 1. / (1 + exp(-bottom_data[i]))); + // check that we squashed the value between 0 and 1 + EXPECT_GE(top_data[i], 0.); + EXPECT_LE(top_data[i], 1.); + } +} + +TYPED_TEST(CuDNNNeuronLayerTest, TestSigmoidGradientCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + CuDNNSigmoidLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3, 1701, 0., 0.01); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(CuDNNNeuronLayerTest, TestTanHCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + CuDNNTanHLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Test exact values + for (int i = 0; i < this->blob_bottom_->num(); ++i) { + for (int j = 0; j < this->blob_bottom_->channels(); ++j) { + for (int k = 0; k < this->blob_bottom_->height(); ++k) { + for (int l = 0; l < this->blob_bottom_->width(); ++l) { + EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4, + (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) / + (exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1)); + EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4, + (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) / + (exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1)); + } + } + } + } +} + +TYPED_TEST(CuDNNNeuronLayerTest, TestTanHGradientCuDNN) { + 
Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + CuDNNTanHLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} +#endif + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_platform.cpp b/caffe-crfrnn/src/caffe/test/test_platform.cpp new file mode 100644 index 00000000..f3513e08 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_platform.cpp @@ -0,0 +1,57 @@ +#ifndef CPU_ONLY + +#include +#include + +#include "glog/logging.h" +#include "gtest/gtest.h" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; + +class PlatformTest : public ::testing::Test {}; + +TEST_F(PlatformTest, TestInitialization) { + printf("Major revision number: %d\n", CAFFE_TEST_CUDA_PROP.major); + printf("Minor revision number: %d\n", CAFFE_TEST_CUDA_PROP.minor); + printf("Name: %s\n", CAFFE_TEST_CUDA_PROP.name); + printf("Total global memory: %lu\n", + CAFFE_TEST_CUDA_PROP.totalGlobalMem); + printf("Total shared memory per block: %lu\n", + CAFFE_TEST_CUDA_PROP.sharedMemPerBlock); + printf("Total registers per block: %d\n", + CAFFE_TEST_CUDA_PROP.regsPerBlock); + printf("Warp size: %d\n", + CAFFE_TEST_CUDA_PROP.warpSize); + printf("Maximum memory pitch: %lu\n", + CAFFE_TEST_CUDA_PROP.memPitch); + printf("Maximum threads per block: %d\n", + CAFFE_TEST_CUDA_PROP.maxThreadsPerBlock); + for (int i = 0; i < 3; ++i) + printf("Maximum dimension %d of block: %d\n", i, + CAFFE_TEST_CUDA_PROP.maxThreadsDim[i]); + for (int i = 0; i < 3; ++i) + printf("Maximum dimension %d of grid: %d\n", i, + CAFFE_TEST_CUDA_PROP.maxGridSize[i]); + printf("Clock rate: %d\n", CAFFE_TEST_CUDA_PROP.clockRate); + printf("Total constant memory: %lu\n", + CAFFE_TEST_CUDA_PROP.totalConstMem); + printf("Texture alignment: %lu\n", + CAFFE_TEST_CUDA_PROP.textureAlignment); + printf("Concurrent copy and execution: %s\n", + 
(CAFFE_TEST_CUDA_PROP.deviceOverlap ? "Yes" : "No")); + printf("Number of multiprocessors: %d\n", + CAFFE_TEST_CUDA_PROP.multiProcessorCount); + printf("Kernel execution timeout: %s\n", + (CAFFE_TEST_CUDA_PROP.kernelExecTimeoutEnabled ? "Yes" : "No")); + printf("Unified virtual addressing: %s\n", + (CAFFE_TEST_CUDA_PROP.unifiedAddressing ? "Yes" : "No")); + EXPECT_TRUE(true); +} + +} // namespace caffe + +#endif // CPU_ONLY diff --git a/caffe-crfrnn/src/caffe/test/test_pooling_layer.cpp b/caffe-crfrnn/src/caffe/test/test_pooling_layer.cpp new file mode 100644 index 00000000..435caa83 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_pooling_layer.cpp @@ -0,0 +1,1201 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class PoolingLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + PoolingLayerTest() + : blob_bottom_(new Blob()), + blob_top_(new Blob()), + blob_top_mask_(new Blob()) {} + virtual void SetUp() { + Caffe::set_random_seed(1701); + blob_bottom_->Reshape(2, 3, 6, 5); + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~PoolingLayerTest() { + delete blob_bottom_; + delete blob_top_; + delete blob_top_mask_; + } + Blob* const blob_bottom_; + Blob* const blob_top_; + Blob* const blob_top_mask_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; + // Test for 2x 2 square pooling layer + void TestForwardSquare() { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(2); + 
pooling_param->set_pool(PoolingParameter_PoolMethod_MAX); + const int num = 2; + const int channels = 2; + blob_bottom_->Reshape(num, channels, 3, 5); + // Input: 2x 2 channels of: + // [1 2 5 2 3] + // [9 4 1 4 8] + // [1 2 5 2 3] + for (int i = 0; i < 15 * num * channels; i += 15) { + blob_bottom_->mutable_cpu_data()[i + 0] = 1; + blob_bottom_->mutable_cpu_data()[i + 1] = 2; + blob_bottom_->mutable_cpu_data()[i + 2] = 5; + blob_bottom_->mutable_cpu_data()[i + 3] = 2; + blob_bottom_->mutable_cpu_data()[i + 4] = 3; + blob_bottom_->mutable_cpu_data()[i + 5] = 9; + blob_bottom_->mutable_cpu_data()[i + 6] = 4; + blob_bottom_->mutable_cpu_data()[i + 7] = 1; + blob_bottom_->mutable_cpu_data()[i + 8] = 4; + blob_bottom_->mutable_cpu_data()[i + 9] = 8; + blob_bottom_->mutable_cpu_data()[i + 10] = 1; + blob_bottom_->mutable_cpu_data()[i + 11] = 2; + blob_bottom_->mutable_cpu_data()[i + 12] = 5; + blob_bottom_->mutable_cpu_data()[i + 13] = 2; + blob_bottom_->mutable_cpu_data()[i + 14] = 3; + } + PoolingLayer layer(layer_param); + layer.SetUp(blob_bottom_vec_, blob_top_vec_); + EXPECT_EQ(blob_top_->num(), num); + EXPECT_EQ(blob_top_->channels(), channels); + EXPECT_EQ(blob_top_->height(), 2); + EXPECT_EQ(blob_top_->width(), 4); + if (blob_top_vec_.size() > 1) { + EXPECT_EQ(blob_top_mask_->num(), num); + EXPECT_EQ(blob_top_mask_->channels(), channels); + EXPECT_EQ(blob_top_mask_->height(), 2); + EXPECT_EQ(blob_top_mask_->width(), 4); + } + layer.Forward(blob_bottom_vec_, blob_top_vec_); + // Expected output: 2x 2 channels of: + // [9 5 5 8] + // [9 5 5 8] + for (int i = 0; i < 8 * num * channels; i += 8) { + EXPECT_EQ(blob_top_->cpu_data()[i + 0], 9); + EXPECT_EQ(blob_top_->cpu_data()[i + 1], 5); + EXPECT_EQ(blob_top_->cpu_data()[i + 2], 5); + EXPECT_EQ(blob_top_->cpu_data()[i + 3], 8); + EXPECT_EQ(blob_top_->cpu_data()[i + 4], 9); + EXPECT_EQ(blob_top_->cpu_data()[i + 5], 5); + EXPECT_EQ(blob_top_->cpu_data()[i + 6], 5); + EXPECT_EQ(blob_top_->cpu_data()[i + 7], 8); + } + if 
(blob_top_vec_.size() > 1) { + // Expected mask output: 2x 2 channels of: + // [5 2 2 9] + // [5 12 12 9] + for (int i = 0; i < 8 * num * channels; i += 8) { + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 0], 5); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 1], 2); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 2], 2); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 3], 9); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 4], 5); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 5], 12); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 6], 12); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 7], 9); + } + } + } + // Test for 3x 2 rectangular pooling layer with kernel_h > kernel_w + void TestForwardRectHigh() { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_h(3); + pooling_param->set_kernel_w(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_MAX); + const int num = 2; + const int channels = 2; + blob_bottom_->Reshape(num, channels, 6, 6); + // Input: 2x 2 channels of: + // [35 1 6 26 19 24] + // [ 3 32 7 21 23 25] + // [31 9 2 22 27 20] + // [ 8 28 33 17 10 15] + // [30 5 34 12 14 16] + // [ 4 36 29 13 18 11] + // (this is generated by magic(6) in MATLAB) + for (int i = 0; i < 36 * num * channels; i += 36) { + blob_bottom_->mutable_cpu_data()[i + 0] = 35; + blob_bottom_->mutable_cpu_data()[i + 1] = 1; + blob_bottom_->mutable_cpu_data()[i + 2] = 6; + blob_bottom_->mutable_cpu_data()[i + 3] = 26; + blob_bottom_->mutable_cpu_data()[i + 4] = 19; + blob_bottom_->mutable_cpu_data()[i + 5] = 24; + blob_bottom_->mutable_cpu_data()[i + 6] = 3; + blob_bottom_->mutable_cpu_data()[i + 7] = 32; + blob_bottom_->mutable_cpu_data()[i + 8] = 7; + blob_bottom_->mutable_cpu_data()[i + 9] = 21; + blob_bottom_->mutable_cpu_data()[i + 10] = 23; + blob_bottom_->mutable_cpu_data()[i + 11] = 25; + blob_bottom_->mutable_cpu_data()[i + 12] = 31; + blob_bottom_->mutable_cpu_data()[i + 13] = 9; + blob_bottom_->mutable_cpu_data()[i + 14] = 2; 
+ blob_bottom_->mutable_cpu_data()[i + 15] = 22; + blob_bottom_->mutable_cpu_data()[i + 16] = 27; + blob_bottom_->mutable_cpu_data()[i + 17] = 20; + blob_bottom_->mutable_cpu_data()[i + 18] = 8; + blob_bottom_->mutable_cpu_data()[i + 19] = 28; + blob_bottom_->mutable_cpu_data()[i + 20] = 33; + blob_bottom_->mutable_cpu_data()[i + 21] = 17; + blob_bottom_->mutable_cpu_data()[i + 22] = 10; + blob_bottom_->mutable_cpu_data()[i + 23] = 15; + blob_bottom_->mutable_cpu_data()[i + 24] = 30; + blob_bottom_->mutable_cpu_data()[i + 25] = 5; + blob_bottom_->mutable_cpu_data()[i + 26] = 34; + blob_bottom_->mutable_cpu_data()[i + 27] = 12; + blob_bottom_->mutable_cpu_data()[i + 28] = 14; + blob_bottom_->mutable_cpu_data()[i + 29] = 16; + blob_bottom_->mutable_cpu_data()[i + 30] = 4; + blob_bottom_->mutable_cpu_data()[i + 31] = 36; + blob_bottom_->mutable_cpu_data()[i + 32] = 29; + blob_bottom_->mutable_cpu_data()[i + 33] = 13; + blob_bottom_->mutable_cpu_data()[i + 34] = 18; + blob_bottom_->mutable_cpu_data()[i + 35] = 11; + } + PoolingLayer layer(layer_param); + layer.SetUp(blob_bottom_vec_, blob_top_vec_); + EXPECT_EQ(blob_top_->num(), num); + EXPECT_EQ(blob_top_->channels(), channels); + EXPECT_EQ(blob_top_->height(), 4); + EXPECT_EQ(blob_top_->width(), 5); + if (blob_top_vec_.size() > 1) { + EXPECT_EQ(blob_top_mask_->num(), num); + EXPECT_EQ(blob_top_mask_->channels(), channels); + EXPECT_EQ(blob_top_mask_->height(), 4); + EXPECT_EQ(blob_top_mask_->width(), 5); + } + layer.Forward(blob_bottom_vec_, blob_top_vec_); + // Expected output: 2x 2 channels of: + // [35 32 26 27 27] + // [32 33 33 27 27] + // [31 34 34 27 27] + // [36 36 34 18 18] + for (int i = 0; i < 20 * num * channels; i += 20) { + EXPECT_EQ(blob_top_->cpu_data()[i + 0], 35); + EXPECT_EQ(blob_top_->cpu_data()[i + 1], 32); + EXPECT_EQ(blob_top_->cpu_data()[i + 2], 26); + EXPECT_EQ(blob_top_->cpu_data()[i + 3], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 4], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 5], 32); + 
EXPECT_EQ(blob_top_->cpu_data()[i + 6], 33); + EXPECT_EQ(blob_top_->cpu_data()[i + 7], 33); + EXPECT_EQ(blob_top_->cpu_data()[i + 8], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 9], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 10], 31); + EXPECT_EQ(blob_top_->cpu_data()[i + 11], 34); + EXPECT_EQ(blob_top_->cpu_data()[i + 12], 34); + EXPECT_EQ(blob_top_->cpu_data()[i + 13], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 14], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 15], 36); + EXPECT_EQ(blob_top_->cpu_data()[i + 16], 36); + EXPECT_EQ(blob_top_->cpu_data()[i + 17], 34); + EXPECT_EQ(blob_top_->cpu_data()[i + 18], 18); + EXPECT_EQ(blob_top_->cpu_data()[i + 19], 18); + } + if (blob_top_vec_.size() > 1) { + // [ 0 7 3 16 16] + // [ 7 20 20 16 16] + // [12 26 26 16 16] + // [31 31 26 34 34] + for (int i = 0; i < 20 * num * channels; i += 20) { + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 0], 0); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 1], 7); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 2], 3); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 3], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 4], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 5], 7); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 6], 20); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 7], 20); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 8], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 9], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 10], 12); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 11], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 12], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 13], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 14], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 15], 31); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 16], 31); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 17], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 18], 34); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 19], 34); + } + } + } + // Test for rectangular pooling layer with kernel_w > kernel_h + void TestForwardRectWide() { + 
LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_h(2); + pooling_param->set_kernel_w(3); + pooling_param->set_pool(PoolingParameter_PoolMethod_MAX); + const int num = 2; + const int channels = 2; + blob_bottom_->Reshape(num, channels, 6, 6); + // Input: 2x 2 channels of: + // [35 1 6 26 19 24] + // [ 3 32 7 21 23 25] + // [31 9 2 22 27 20] + // [ 8 28 33 17 10 15] + // [30 5 34 12 14 16] + // [ 4 36 29 13 18 11] + // (this is generated by magic(6) in MATLAB) + for (int i = 0; i < 36 * num * channels; i += 36) { + blob_bottom_->mutable_cpu_data()[i + 0] = 35; + blob_bottom_->mutable_cpu_data()[i + 1] = 1; + blob_bottom_->mutable_cpu_data()[i + 2] = 6; + blob_bottom_->mutable_cpu_data()[i + 3] = 26; + blob_bottom_->mutable_cpu_data()[i + 4] = 19; + blob_bottom_->mutable_cpu_data()[i + 5] = 24; + blob_bottom_->mutable_cpu_data()[i + 6] = 3; + blob_bottom_->mutable_cpu_data()[i + 7] = 32; + blob_bottom_->mutable_cpu_data()[i + 8] = 7; + blob_bottom_->mutable_cpu_data()[i + 9] = 21; + blob_bottom_->mutable_cpu_data()[i + 10] = 23; + blob_bottom_->mutable_cpu_data()[i + 11] = 25; + blob_bottom_->mutable_cpu_data()[i + 12] = 31; + blob_bottom_->mutable_cpu_data()[i + 13] = 9; + blob_bottom_->mutable_cpu_data()[i + 14] = 2; + blob_bottom_->mutable_cpu_data()[i + 15] = 22; + blob_bottom_->mutable_cpu_data()[i + 16] = 27; + blob_bottom_->mutable_cpu_data()[i + 17] = 20; + blob_bottom_->mutable_cpu_data()[i + 18] = 8; + blob_bottom_->mutable_cpu_data()[i + 19] = 28; + blob_bottom_->mutable_cpu_data()[i + 20] = 33; + blob_bottom_->mutable_cpu_data()[i + 21] = 17; + blob_bottom_->mutable_cpu_data()[i + 22] = 10; + blob_bottom_->mutable_cpu_data()[i + 23] = 15; + blob_bottom_->mutable_cpu_data()[i + 24] = 30; + blob_bottom_->mutable_cpu_data()[i + 25] = 5; + blob_bottom_->mutable_cpu_data()[i + 26] = 34; + blob_bottom_->mutable_cpu_data()[i + 27] = 12; + blob_bottom_->mutable_cpu_data()[i + 28] = 14; 
+ blob_bottom_->mutable_cpu_data()[i + 29] = 16; + blob_bottom_->mutable_cpu_data()[i + 30] = 4; + blob_bottom_->mutable_cpu_data()[i + 31] = 36; + blob_bottom_->mutable_cpu_data()[i + 32] = 29; + blob_bottom_->mutable_cpu_data()[i + 33] = 13; + blob_bottom_->mutable_cpu_data()[i + 34] = 18; + blob_bottom_->mutable_cpu_data()[i + 35] = 11; + } + PoolingLayer layer(layer_param); + layer.SetUp(blob_bottom_vec_, blob_top_vec_); + EXPECT_EQ(blob_top_->num(), num); + EXPECT_EQ(blob_top_->channels(), channels); + EXPECT_EQ(blob_top_->height(), 5); + EXPECT_EQ(blob_top_->width(), 4); + if (blob_top_vec_.size() > 1) { + EXPECT_EQ(blob_top_mask_->num(), num); + EXPECT_EQ(blob_top_mask_->channels(), channels); + EXPECT_EQ(blob_top_mask_->height(), 5); + EXPECT_EQ(blob_top_mask_->width(), 4); + } + layer.Forward(blob_bottom_vec_, blob_top_vec_); + // Expected output: 2x 2 channels of: + // [35 32 26 26] + // [32 32 27 27] + // [33 33 33 27] + // [34 34 34 17] + // [36 36 34 18] + for (int i = 0; i < 20 * num * channels; i += 20) { + EXPECT_EQ(blob_top_->cpu_data()[i + 0], 35); + EXPECT_EQ(blob_top_->cpu_data()[i + 1], 32); + EXPECT_EQ(blob_top_->cpu_data()[i + 2], 26); + EXPECT_EQ(blob_top_->cpu_data()[i + 3], 26); + EXPECT_EQ(blob_top_->cpu_data()[i + 4], 32); + EXPECT_EQ(blob_top_->cpu_data()[i + 5], 32); + EXPECT_EQ(blob_top_->cpu_data()[i + 6], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 7], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 8], 33); + EXPECT_EQ(blob_top_->cpu_data()[i + 9], 33); + EXPECT_EQ(blob_top_->cpu_data()[i + 10], 33); + EXPECT_EQ(blob_top_->cpu_data()[i + 11], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 12], 34); + EXPECT_EQ(blob_top_->cpu_data()[i + 13], 34); + EXPECT_EQ(blob_top_->cpu_data()[i + 14], 34); + EXPECT_EQ(blob_top_->cpu_data()[i + 15], 17); + EXPECT_EQ(blob_top_->cpu_data()[i + 16], 36); + EXPECT_EQ(blob_top_->cpu_data()[i + 17], 36); + EXPECT_EQ(blob_top_->cpu_data()[i + 18], 34); + EXPECT_EQ(blob_top_->cpu_data()[i + 19], 18); + } + if 
(blob_top_vec_.size() > 1) { + // [ 0 7 3 3] + // [ 7 7 16 16] + // [20 20 20 16] + // [26 26 26 21] + // [31 31 26 34] + for (int i = 0; i < 20 * num * channels; i += 20) { + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 0], 0); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 1], 7); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 2], 3); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 3], 3); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 4], 7); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 5], 7); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 6], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 7], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 8], 20); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 9], 20); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 10], 20); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 11], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 12], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 13], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 14], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 15], 21); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 16], 31); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 17], 31); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 18], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 19], 34); + } + } + } +}; + +TYPED_TEST_CASE(PoolingLayerTest, TestDtypesAndDevices); + +TYPED_TEST(PoolingLayerTest, TestSetup) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(2); + PoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num()); + EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels()); + EXPECT_EQ(this->blob_top_->height(), 3); + EXPECT_EQ(this->blob_top_->width(), 2); +} + +TYPED_TEST(PoolingLayerTest, TestSetupPadded) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; 
+ PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(2); + pooling_param->set_pad(1); + pooling_param->set_pool(PoolingParameter_PoolMethod_AVE); + PoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num()); + EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels()); + EXPECT_EQ(this->blob_top_->height(), 4); + EXPECT_EQ(this->blob_top_->width(), 3); +} + +TYPED_TEST(PoolingLayerTest, TestSetupGlobalPooling) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_global_pooling(true); + pooling_param->set_pool(PoolingParameter_PoolMethod_AVE); + PoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num()); + EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels()); + EXPECT_EQ(this->blob_top_->height(), 1); + EXPECT_EQ(this->blob_top_->width(), 1); +} + +/* +TYPED_TEST(PoolingLayerTest, PrintBackward) { + LayerParameter layer_param; + layer_param.set_kernelsize(3); + layer_param.set_stride(2); + layer_param.set_pool(LayerParameter_PoolMethod_MAX); + PoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + cout << "bottom data " << i << " " << this->blob_bottom_->cpu_data()[i] << endl; + } + for (int i = 0; i < this->blob_top_->count(); ++i) { + cout << "top data " << i << " " << this->blob_top_->cpu_data()[i] << endl; + } + + for (int i = 0; i < this->blob_top_->count(); ++i) { + this->blob_top_->mutable_cpu_diff()[i] = i; + } + layer.Backward(this->blob_top_vec_, true, this->blob_bottom_vec_); + for 
(int i = 0; i < this->blob_bottom_->count(); ++i) { + cout << "bottom diff " << i << " " << this->blob_bottom_->cpu_diff()[i] << endl; + } +} +*/ + +TYPED_TEST(PoolingLayerTest, TestForwardMax) { + this->TestForwardSquare(); + this->TestForwardRectHigh(); + this->TestForwardRectWide(); +} + +TYPED_TEST(PoolingLayerTest, TestForwardMaxTopMask) { + this->blob_top_vec_.push_back(this->blob_top_mask_); + this->TestForwardSquare(); + this->TestForwardRectHigh(); + this->TestForwardRectWide(); +} + +TYPED_TEST(PoolingLayerTest, TestGradientMax) { + typedef typename TypeParam::Dtype Dtype; + for (int kernel_h = 3; kernel_h <= 4; kernel_h++) { + for (int kernel_w = 3; kernel_w <= 4; kernel_w++) { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_h(kernel_h); + pooling_param->set_kernel_w(kernel_w); + pooling_param->set_stride(2); + pooling_param->set_pad(1); + pooling_param->set_pool(PoolingParameter_PoolMethod_MAX); + PoolingLayer layer(layer_param); + GradientChecker checker(1e-4, 1e-2); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); + } + } +} + +TYPED_TEST(PoolingLayerTest, TestForwardMaxPadded) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(2); + pooling_param->set_pad(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_MAX); + this->blob_bottom_->Reshape(1, 1, 3, 3); + // Input: + // [ 1 2 4 ] + // [ 2 3 2 ] + // [ 4 2 1 ] + this->blob_bottom_->mutable_cpu_data()[0] = 1; + this->blob_bottom_->mutable_cpu_data()[1] = 2; + this->blob_bottom_->mutable_cpu_data()[2] = 4; + this->blob_bottom_->mutable_cpu_data()[3] = 2; + this->blob_bottom_->mutable_cpu_data()[4] = 3; + this->blob_bottom_->mutable_cpu_data()[5] = 2; + this->blob_bottom_->mutable_cpu_data()[6] = 4; + 
this->blob_bottom_->mutable_cpu_data()[7] = 2; + this->blob_bottom_->mutable_cpu_data()[8] = 1; + PoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 1); + EXPECT_EQ(this->blob_top_->channels(), 1); + EXPECT_EQ(this->blob_top_->height(), 3); + EXPECT_EQ(this->blob_top_->width(), 3); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + Dtype epsilon = 1e-8; + // Output: + // [ 1 4 4 ] + // [ 4 4 4 ] + // [ 4 4 1 ] + EXPECT_NEAR(this->blob_top_->cpu_data()[0], 1, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[1], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[2], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[3], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[4], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[5], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[6], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[7], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[8], 1, epsilon); +} + +TYPED_TEST(PoolingLayerTest, TestGradientMaxTopMask) { + typedef typename TypeParam::Dtype Dtype; + for (int kernel_h = 3; kernel_h <= 4; kernel_h++) { + for (int kernel_w = 3; kernel_w <= 4; kernel_w++) { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_h(kernel_h); + pooling_param->set_kernel_w(kernel_w); + pooling_param->set_stride(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_MAX); + this->blob_top_vec_.push_back(this->blob_top_mask_); + PoolingLayer layer(layer_param); + GradientChecker checker(1e-4, 1e-2); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); + this->blob_top_vec_.pop_back(); + } + } +} + +TYPED_TEST(PoolingLayerTest, TestForwardAve) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + 
pooling_param->set_kernel_size(3); + pooling_param->set_stride(1); + pooling_param->set_pad(1); + pooling_param->set_pool(PoolingParameter_PoolMethod_AVE); + this->blob_bottom_->Reshape(1, 1, 3, 3); + FillerParameter filler_param; + filler_param.set_value(Dtype(2)); + ConstantFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + PoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 1); + EXPECT_EQ(this->blob_top_->channels(), 1); + EXPECT_EQ(this->blob_top_->height(), 3); + EXPECT_EQ(this->blob_top_->width(), 3); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + Dtype epsilon = 1e-5; + EXPECT_NEAR(this->blob_top_->cpu_data()[0], 8.0 / 9, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[1], 4.0 / 3, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[2], 8.0 / 9, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[3], 4.0 / 3, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[4], 2.0 , epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[5], 4.0 / 3, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[6], 8.0 / 9, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[7], 4.0 / 3, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[8], 8.0 / 9, epsilon); +} + +TYPED_TEST(PoolingLayerTest, TestGradientAve) { + typedef typename TypeParam::Dtype Dtype; + for (int kernel_h = 3; kernel_h <= 4; kernel_h++) { + for (int kernel_w = 3; kernel_w <= 4; kernel_w++) { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_h(kernel_h); + pooling_param->set_kernel_w(kernel_w); + pooling_param->set_stride(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_AVE); + PoolingLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); + } + } +} + +TYPED_TEST(PoolingLayerTest, TestGradientAvePadded) { + typedef 
typename TypeParam::Dtype Dtype; + for (int kernel_h = 3; kernel_h <= 4; kernel_h++) { + for (int kernel_w = 3; kernel_w <= 4; kernel_w++) { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_h(kernel_h); + pooling_param->set_kernel_w(kernel_w); + pooling_param->set_stride(2); + pooling_param->set_pad(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_AVE); + PoolingLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); + } + } +} + +#ifdef USE_CUDNN +template +class CuDNNPoolingLayerTest : public ::testing::Test { + protected: + CuDNNPoolingLayerTest() + : blob_bottom_(new Blob()), + blob_top_(new Blob()), + blob_top_mask_(new Blob()) {} + virtual void SetUp() { + Caffe::set_random_seed(1701); + blob_bottom_->Reshape(2, 3, 6, 5); + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~CuDNNPoolingLayerTest() { + delete blob_bottom_; + delete blob_top_; + delete blob_top_mask_; + } + Blob* const blob_bottom_; + Blob* const blob_top_; + Blob* const blob_top_mask_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; + // Test for 2x 2 square pooling layer + void TestForwardSquare() { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_MAX); + const int num = 2; + const int channels = 2; + blob_bottom_->Reshape(num, channels, 3, 5); + // Input: 2x 2 channels of: + // [1 2 5 2 3] + // [9 4 1 4 8] + // [1 2 5 2 3] + for (int i = 0; i < 15 * num * channels; i += 15) { + blob_bottom_->mutable_cpu_data()[i + 0] = 1; + blob_bottom_->mutable_cpu_data()[i + 1] = 2; + 
blob_bottom_->mutable_cpu_data()[i + 2] = 5; + blob_bottom_->mutable_cpu_data()[i + 3] = 2; + blob_bottom_->mutable_cpu_data()[i + 4] = 3; + blob_bottom_->mutable_cpu_data()[i + 5] = 9; + blob_bottom_->mutable_cpu_data()[i + 6] = 4; + blob_bottom_->mutable_cpu_data()[i + 7] = 1; + blob_bottom_->mutable_cpu_data()[i + 8] = 4; + blob_bottom_->mutable_cpu_data()[i + 9] = 8; + blob_bottom_->mutable_cpu_data()[i + 10] = 1; + blob_bottom_->mutable_cpu_data()[i + 11] = 2; + blob_bottom_->mutable_cpu_data()[i + 12] = 5; + blob_bottom_->mutable_cpu_data()[i + 13] = 2; + blob_bottom_->mutable_cpu_data()[i + 14] = 3; + } + CuDNNPoolingLayer layer(layer_param); + layer.SetUp(blob_bottom_vec_, blob_top_vec_); + EXPECT_EQ(blob_top_->num(), num); + EXPECT_EQ(blob_top_->channels(), channels); + EXPECT_EQ(blob_top_->height(), 2); + EXPECT_EQ(blob_top_->width(), 4); + if (blob_top_vec_.size() > 1) { + EXPECT_EQ(blob_top_mask_->num(), num); + EXPECT_EQ(blob_top_mask_->channels(), channels); + EXPECT_EQ(blob_top_mask_->height(), 2); + EXPECT_EQ(blob_top_mask_->width(), 4); + } + layer.Forward(blob_bottom_vec_, blob_top_vec_); + // Expected output: 2x 2 channels of: + // [9 5 5 8] + // [9 5 5 8] + for (int i = 0; i < 8 * num * channels; i += 8) { + EXPECT_EQ(blob_top_->cpu_data()[i + 0], 9); + EXPECT_EQ(blob_top_->cpu_data()[i + 1], 5); + EXPECT_EQ(blob_top_->cpu_data()[i + 2], 5); + EXPECT_EQ(blob_top_->cpu_data()[i + 3], 8); + EXPECT_EQ(blob_top_->cpu_data()[i + 4], 9); + EXPECT_EQ(blob_top_->cpu_data()[i + 5], 5); + EXPECT_EQ(blob_top_->cpu_data()[i + 6], 5); + EXPECT_EQ(blob_top_->cpu_data()[i + 7], 8); + } + if (blob_top_vec_.size() > 1) { + // Expected mask output: 2x 2 channels of: + // [5 2 2 9] + // [5 12 12 9] + for (int i = 0; i < 8 * num * channels; i += 8) { + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 0], 5); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 1], 2); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 2], 2); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 3], 9); + 
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 4], 5); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 5], 12); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 6], 12); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 7], 9); + } + } + } + // Test for 3x 2 rectangular pooling layer with kernel_h > kernel_w + void TestForwardRectHigh() { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_h(3); + pooling_param->set_kernel_w(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_MAX); + const int num = 2; + const int channels = 2; + blob_bottom_->Reshape(num, channels, 6, 6); + // Input: 2x 2 channels of: + // [35 1 6 26 19 24] + // [ 3 32 7 21 23 25] + // [31 9 2 22 27 20] + // [ 8 28 33 17 10 15] + // [30 5 34 12 14 16] + // [ 4 36 29 13 18 11] + // (this is generated by magic(6) in MATLAB) + for (int i = 0; i < 36 * num * channels; i += 36) { + blob_bottom_->mutable_cpu_data()[i + 0] = 35; + blob_bottom_->mutable_cpu_data()[i + 1] = 1; + blob_bottom_->mutable_cpu_data()[i + 2] = 6; + blob_bottom_->mutable_cpu_data()[i + 3] = 26; + blob_bottom_->mutable_cpu_data()[i + 4] = 19; + blob_bottom_->mutable_cpu_data()[i + 5] = 24; + blob_bottom_->mutable_cpu_data()[i + 6] = 3; + blob_bottom_->mutable_cpu_data()[i + 7] = 32; + blob_bottom_->mutable_cpu_data()[i + 8] = 7; + blob_bottom_->mutable_cpu_data()[i + 9] = 21; + blob_bottom_->mutable_cpu_data()[i + 10] = 23; + blob_bottom_->mutable_cpu_data()[i + 11] = 25; + blob_bottom_->mutable_cpu_data()[i + 12] = 31; + blob_bottom_->mutable_cpu_data()[i + 13] = 9; + blob_bottom_->mutable_cpu_data()[i + 14] = 2; + blob_bottom_->mutable_cpu_data()[i + 15] = 22; + blob_bottom_->mutable_cpu_data()[i + 16] = 27; + blob_bottom_->mutable_cpu_data()[i + 17] = 20; + blob_bottom_->mutable_cpu_data()[i + 18] = 8; + blob_bottom_->mutable_cpu_data()[i + 19] = 28; + blob_bottom_->mutable_cpu_data()[i + 20] = 33; + blob_bottom_->mutable_cpu_data()[i + 21] = 17; + 
blob_bottom_->mutable_cpu_data()[i + 22] = 10; + blob_bottom_->mutable_cpu_data()[i + 23] = 15; + blob_bottom_->mutable_cpu_data()[i + 24] = 30; + blob_bottom_->mutable_cpu_data()[i + 25] = 5; + blob_bottom_->mutable_cpu_data()[i + 26] = 34; + blob_bottom_->mutable_cpu_data()[i + 27] = 12; + blob_bottom_->mutable_cpu_data()[i + 28] = 14; + blob_bottom_->mutable_cpu_data()[i + 29] = 16; + blob_bottom_->mutable_cpu_data()[i + 30] = 4; + blob_bottom_->mutable_cpu_data()[i + 31] = 36; + blob_bottom_->mutable_cpu_data()[i + 32] = 29; + blob_bottom_->mutable_cpu_data()[i + 33] = 13; + blob_bottom_->mutable_cpu_data()[i + 34] = 18; + blob_bottom_->mutable_cpu_data()[i + 35] = 11; + } + CuDNNPoolingLayer layer(layer_param); + layer.SetUp(blob_bottom_vec_, blob_top_vec_); + EXPECT_EQ(blob_top_->num(), num); + EXPECT_EQ(blob_top_->channels(), channels); + EXPECT_EQ(blob_top_->height(), 4); + EXPECT_EQ(blob_top_->width(), 5); + if (blob_top_vec_.size() > 1) { + EXPECT_EQ(blob_top_mask_->num(), num); + EXPECT_EQ(blob_top_mask_->channels(), channels); + EXPECT_EQ(blob_top_mask_->height(), 4); + EXPECT_EQ(blob_top_mask_->width(), 5); + } + layer.Forward(blob_bottom_vec_, blob_top_vec_); + // Expected output: 2x 2 channels of: + // [35 32 26 27 27] + // [32 33 33 27 27] + // [31 34 34 27 27] + // [36 36 34 18 18] + for (int i = 0; i < 20 * num * channels; i += 20) { + EXPECT_EQ(blob_top_->cpu_data()[i + 0], 35); + EXPECT_EQ(blob_top_->cpu_data()[i + 1], 32); + EXPECT_EQ(blob_top_->cpu_data()[i + 2], 26); + EXPECT_EQ(blob_top_->cpu_data()[i + 3], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 4], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 5], 32); + EXPECT_EQ(blob_top_->cpu_data()[i + 6], 33); + EXPECT_EQ(blob_top_->cpu_data()[i + 7], 33); + EXPECT_EQ(blob_top_->cpu_data()[i + 8], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 9], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 10], 31); + EXPECT_EQ(blob_top_->cpu_data()[i + 11], 34); + EXPECT_EQ(blob_top_->cpu_data()[i + 12], 34); + 
EXPECT_EQ(blob_top_->cpu_data()[i + 13], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 14], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 15], 36); + EXPECT_EQ(blob_top_->cpu_data()[i + 16], 36); + EXPECT_EQ(blob_top_->cpu_data()[i + 17], 34); + EXPECT_EQ(blob_top_->cpu_data()[i + 18], 18); + EXPECT_EQ(blob_top_->cpu_data()[i + 19], 18); + } + if (blob_top_vec_.size() > 1) { + // [ 1 8 4 17 17] + // [ 8 21 21 17 17] + // [13 27 27 17 17] + // [32 32 27 35 35] + for (int i = 0; i < 20 * num * channels; i += 20) { + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 0], 0); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 1], 7); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 2], 3); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 3], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 4], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 5], 7); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 6], 20); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 7], 20); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 8], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 9], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 10], 12); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 11], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 12], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 13], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 14], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 15], 31); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 16], 31); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 17], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 18], 34); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 19], 34); + } + } + } + // Test for rectangular pooling layer with kernel_w > kernel_h + void TestForwardRectWide() { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_h(2); + pooling_param->set_kernel_w(3); + pooling_param->set_pool(PoolingParameter_PoolMethod_MAX); + const int num = 2; + const int channels = 2; + blob_bottom_->Reshape(num, channels, 6, 6); + // 
Input: 2x 2 channels of: + // [35 1 6 26 19 24] + // [ 3 32 7 21 23 25] + // [31 9 2 22 27 20] + // [ 8 28 33 17 10 15] + // [30 5 34 12 14 16] + // [ 4 36 29 13 18 11] + // (this is generated by magic(6) in MATLAB) + for (int i = 0; i < 36 * num * channels; i += 36) { + blob_bottom_->mutable_cpu_data()[i + 0] = 35; + blob_bottom_->mutable_cpu_data()[i + 1] = 1; + blob_bottom_->mutable_cpu_data()[i + 2] = 6; + blob_bottom_->mutable_cpu_data()[i + 3] = 26; + blob_bottom_->mutable_cpu_data()[i + 4] = 19; + blob_bottom_->mutable_cpu_data()[i + 5] = 24; + blob_bottom_->mutable_cpu_data()[i + 6] = 3; + blob_bottom_->mutable_cpu_data()[i + 7] = 32; + blob_bottom_->mutable_cpu_data()[i + 8] = 7; + blob_bottom_->mutable_cpu_data()[i + 9] = 21; + blob_bottom_->mutable_cpu_data()[i + 10] = 23; + blob_bottom_->mutable_cpu_data()[i + 11] = 25; + blob_bottom_->mutable_cpu_data()[i + 12] = 31; + blob_bottom_->mutable_cpu_data()[i + 13] = 9; + blob_bottom_->mutable_cpu_data()[i + 14] = 2; + blob_bottom_->mutable_cpu_data()[i + 15] = 22; + blob_bottom_->mutable_cpu_data()[i + 16] = 27; + blob_bottom_->mutable_cpu_data()[i + 17] = 20; + blob_bottom_->mutable_cpu_data()[i + 18] = 8; + blob_bottom_->mutable_cpu_data()[i + 19] = 28; + blob_bottom_->mutable_cpu_data()[i + 20] = 33; + blob_bottom_->mutable_cpu_data()[i + 21] = 17; + blob_bottom_->mutable_cpu_data()[i + 22] = 10; + blob_bottom_->mutable_cpu_data()[i + 23] = 15; + blob_bottom_->mutable_cpu_data()[i + 24] = 30; + blob_bottom_->mutable_cpu_data()[i + 25] = 5; + blob_bottom_->mutable_cpu_data()[i + 26] = 34; + blob_bottom_->mutable_cpu_data()[i + 27] = 12; + blob_bottom_->mutable_cpu_data()[i + 28] = 14; + blob_bottom_->mutable_cpu_data()[i + 29] = 16; + blob_bottom_->mutable_cpu_data()[i + 30] = 4; + blob_bottom_->mutable_cpu_data()[i + 31] = 36; + blob_bottom_->mutable_cpu_data()[i + 32] = 29; + blob_bottom_->mutable_cpu_data()[i + 33] = 13; + blob_bottom_->mutable_cpu_data()[i + 34] = 18; + 
blob_bottom_->mutable_cpu_data()[i + 35] = 11; + } + CuDNNPoolingLayer layer(layer_param); + layer.SetUp(blob_bottom_vec_, blob_top_vec_); + EXPECT_EQ(blob_top_->num(), num); + EXPECT_EQ(blob_top_->channels(), channels); + EXPECT_EQ(blob_top_->height(), 5); + EXPECT_EQ(blob_top_->width(), 4); + if (blob_top_vec_.size() > 1) { + EXPECT_EQ(blob_top_mask_->num(), num); + EXPECT_EQ(blob_top_mask_->channels(), channels); + EXPECT_EQ(blob_top_mask_->height(), 5); + EXPECT_EQ(blob_top_mask_->width(), 4); + } + layer.Forward(blob_bottom_vec_, blob_top_vec_); + // Expected output: 2x 2 channels of: + // [35 32 26 26] + // [32 32 27 27] + // [33 33 33 27] + // [34 34 34 17] + // [36 36 34 18] + for (int i = 0; i < 20 * num * channels; i += 20) { + EXPECT_EQ(blob_top_->cpu_data()[i + 0], 35); + EXPECT_EQ(blob_top_->cpu_data()[i + 1], 32); + EXPECT_EQ(blob_top_->cpu_data()[i + 2], 26); + EXPECT_EQ(blob_top_->cpu_data()[i + 3], 26); + EXPECT_EQ(blob_top_->cpu_data()[i + 4], 32); + EXPECT_EQ(blob_top_->cpu_data()[i + 5], 32); + EXPECT_EQ(blob_top_->cpu_data()[i + 6], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 7], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 8], 33); + EXPECT_EQ(blob_top_->cpu_data()[i + 9], 33); + EXPECT_EQ(blob_top_->cpu_data()[i + 10], 33); + EXPECT_EQ(blob_top_->cpu_data()[i + 11], 27); + EXPECT_EQ(blob_top_->cpu_data()[i + 12], 34); + EXPECT_EQ(blob_top_->cpu_data()[i + 13], 34); + EXPECT_EQ(blob_top_->cpu_data()[i + 14], 34); + EXPECT_EQ(blob_top_->cpu_data()[i + 15], 17); + EXPECT_EQ(blob_top_->cpu_data()[i + 16], 36); + EXPECT_EQ(blob_top_->cpu_data()[i + 17], 36); + EXPECT_EQ(blob_top_->cpu_data()[i + 18], 34); + EXPECT_EQ(blob_top_->cpu_data()[i + 19], 18); + } + if (blob_top_vec_.size() > 1) { + // [ 1 8 4 4] + // [ 8 8 17 17] + // [21 21 21 17] + // [27 27 27 22] + // [32 32 27 35] + for (int i = 0; i < 20 * num * channels; i += 20) { + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 0], 0); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 1], 7); + 
EXPECT_EQ(blob_top_mask_->cpu_data()[i + 2], 3); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 3], 3); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 4], 7); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 5], 7); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 6], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 7], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 8], 20); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 9], 20); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 10], 20); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 11], 16); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 12], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 13], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 14], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 15], 21); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 16], 31); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 17], 31); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 18], 26); + EXPECT_EQ(blob_top_mask_->cpu_data()[i + 19], 34); + } + } + } +}; + +TYPED_TEST_CASE(CuDNNPoolingLayerTest, TestDtypes); + +TYPED_TEST(CuDNNPoolingLayerTest, TestSetupCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(2); + CuDNNPoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num()); + EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels()); + EXPECT_EQ(this->blob_top_->height(), 3); + EXPECT_EQ(this->blob_top_->width(), 2); +} + +// This test and all following cuDNN pooling tests with padding are commented +// for now, since cuDNN pooling does not currently support padding. 
+/* +TYPED_TEST(CuDNNPoolingLayerTest, TestSetupPaddedCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(2); + pooling_param->set_pad(1); + pooling_param->set_pool(PoolingParameter_PoolMethod_AVE); + CuDNNPoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num()); + EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels()); + EXPECT_EQ(this->blob_top_->height(), 4); + EXPECT_EQ(this->blob_top_->width(), 3); +} +*/ + +/* +TYPED_TEST(CuDNNPoolingLayerTest, PrintBackwardCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + layer_param.set_kernelsize(3); + layer_param.set_stride(2); + layer_param.set_pool(LayerParameter_PoolMethod_MAX); + CuDNNPoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + cout << "bottom data " << i << " " << this->blob_bottom_->cpu_data()[i] << endl; + } + for (int i = 0; i < this->blob_top_->count(); ++i) { + cout << "top data " << i << " " << this->blob_top_->cpu_data()[i] << endl; + } + + for (int i = 0; i < this->blob_top_->count(); ++i) { + this->blob_top_->mutable_cpu_diff()[i] = i; + } + layer.Backward(this->blob_top_vec_, true, this->blob_bottom_vec_); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + cout << "bottom diff " << i << " " << this->blob_bottom_->cpu_diff()[i] << endl; + } +} +*/ + +TYPED_TEST(CuDNNPoolingLayerTest, TestForwardMaxCuDNN) { + Caffe::set_mode(Caffe::GPU); + this->TestForwardSquare(); + this->TestForwardRectHigh(); + this->TestForwardRectWide(); +} + +// Currently, cuDNN does not support a top mask, so we comment this and +// the corresponding backward test. 
+/* +TYPED_TEST(CuDNNPoolingLayerTest, TestForwardMaxTopMaskCuDNN) { + Caffe::set_mode(Caffe::GPU); + this->blob_top_vec_.push_back(this->blob_top_mask_); + this->TestForwardSquare(); + this->TestForwardRectHigh(); + this->TestForwardRectWide(); +} +*/ + +TYPED_TEST(CuDNNPoolingLayerTest, TestGradientMaxCuDNN) { + Caffe::set_mode(Caffe::GPU); + for (int kernel_h = 3; kernel_h <= 4; kernel_h++) { + for (int kernel_w = 3; kernel_w <= 4; kernel_w++) { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_h(kernel_h); + pooling_param->set_kernel_w(kernel_w); + pooling_param->set_stride(2); + // currently, cuDNN pooling does not support padding + pooling_param->set_pad(0); + pooling_param->set_pool(PoolingParameter_PoolMethod_MAX); + CuDNNPoolingLayer<TypeParam> layer(layer_param); + GradientChecker<TypeParam> checker(1e-4, 1e-2); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); + } + } +} + +/* +TYPED_TEST(CuDNNPoolingLayerTest, TestForwardMaxPaddedCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(2); + pooling_param->set_pad(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_MAX); + this->blob_bottom_->Reshape(1, 1, 3, 3); + // Input: + // [ 1 2 4 ] + // [ 2 3 2 ] + // [ 4 2 1 ] + this->blob_bottom_->mutable_cpu_data()[0] = 1; + this->blob_bottom_->mutable_cpu_data()[1] = 2; + this->blob_bottom_->mutable_cpu_data()[2] = 4; + this->blob_bottom_->mutable_cpu_data()[3] = 2; + this->blob_bottom_->mutable_cpu_data()[4] = 3; + this->blob_bottom_->mutable_cpu_data()[5] = 2; + this->blob_bottom_->mutable_cpu_data()[6] = 4; + this->blob_bottom_->mutable_cpu_data()[7] = 2; + this->blob_bottom_->mutable_cpu_data()[8] = 1; + CuDNNPoolingLayer<TypeParam> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); +
EXPECT_EQ(this->blob_top_->num(), 1); + EXPECT_EQ(this->blob_top_->channels(), 1); + EXPECT_EQ(this->blob_top_->height(), 3); + EXPECT_EQ(this->blob_top_->width(), 3); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + TypeParam epsilon = 1e-8; + // Output: + // [ 1 4 4 ] + // [ 4 4 4 ] + // [ 4 4 1 ] + EXPECT_NEAR(this->blob_top_->cpu_data()[0], 1, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[1], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[2], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[3], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[4], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[5], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[6], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[7], 4, epsilon); + EXPECT_NEAR(this->blob_top_->cpu_data()[8], 1, epsilon); +} +*/ + +/* +TYPED_TEST(CuDNNPoolingLayerTest, TestGradientMaxTopMaskCuDNN) { + Caffe::set_mode(Caffe::GPU); + for (int kernel_h = 3; kernel_h <= 4; kernel_h++) { + for (int kernel_w = 3; kernel_w <= 4; kernel_w++) { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_h(kernel_h); + pooling_param->set_kernel_w(kernel_w); + pooling_param->set_stride(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_MAX); + this->blob_top_vec_.push_back(this->blob_top_mask_); + CuDNNPoolingLayer layer(layer_param); + GradientChecker checker(1e-4, 1e-2); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); + this->blob_top_vec_.pop_back(); + } + } +} +*/ + +TYPED_TEST(CuDNNPoolingLayerTest, TestForwardAveCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(1); + // Currently, cuDNN pooling does not support padding, so we use + // a simplified version of this test. 
+ pooling_param->set_pad(0); + pooling_param->set_pool(PoolingParameter_PoolMethod_AVE); + this->blob_bottom_->Reshape(1, 1, 3, 3); + FillerParameter filler_param; + filler_param.set_value(TypeParam(2)); + ConstantFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + CuDNNPoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), 1); + EXPECT_EQ(this->blob_top_->channels(), 1); + EXPECT_EQ(this->blob_top_->height(), 1); + EXPECT_EQ(this->blob_top_->width(), 1); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + TypeParam epsilon = 1e-5; + EXPECT_NEAR(this->blob_top_->cpu_data()[0], 2.0, epsilon); +} + +TYPED_TEST(CuDNNPoolingLayerTest, TestGradientAveCuDNN) { + Caffe::set_mode(Caffe::GPU); + for (int kernel_h = 3; kernel_h <= 4; kernel_h++) { + for (int kernel_w = 3; kernel_w <= 4; kernel_w++) { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_h(kernel_h); + pooling_param->set_kernel_w(kernel_w); + pooling_param->set_stride(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_AVE); + CuDNNPoolingLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); + } + } +} + +/* +TYPED_TEST(CuDNNPoolingLayerTest, TestGradientAvePaddedCuDNN) { + Caffe::set_mode(Caffe::GPU); + for (int kernel_h = 3; kernel_h <= 4; kernel_h++) { + for (int kernel_w = 3; kernel_w <= 4; kernel_w++) { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_h(kernel_h); + pooling_param->set_kernel_w(kernel_w); + pooling_param->set_stride(2); + pooling_param->set_pad(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_AVE); + CuDNNPoolingLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + checker.CheckGradientExhaustive(&layer, 
this->blob_bottom_vec_, + this->blob_top_vec_); + } + } +} +*/ + +#endif + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_power_layer.cpp b/caffe-crfrnn/src/caffe/test/test_power_layer.cpp new file mode 100644 index 00000000..0d52fa1c --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_power_layer.cpp @@ -0,0 +1,170 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class PowerLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + PowerLayerTest() + : blob_bottom_(new Blob(2, 3, 4, 5)), + blob_top_(new Blob()) { + Caffe::set_random_seed(1701); + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~PowerLayerTest() { delete blob_bottom_; delete blob_top_; } + + void TestForward(Dtype power, Dtype scale, Dtype shift) { + LayerParameter layer_param; + layer_param.mutable_power_param()->set_power(power); + layer_param.mutable_power_param()->set_scale(scale); + layer_param.mutable_power_param()->set_shift(shift); + PowerLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + const Dtype* bottom_data = this->blob_bottom_->cpu_data(); + const Dtype* top_data = this->blob_top_->cpu_data(); + const Dtype min_precision = 1e-5; + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + Dtype expected_value = pow(shift + scale * bottom_data[i], power); + if (power == Dtype(0) || power == Dtype(1) || power == Dtype(2)) { + EXPECT_FALSE(isnan(top_data[i])); + } + if 
(isnan(expected_value)) { + EXPECT_TRUE(isnan(top_data[i])); + } else { + Dtype precision = std::max( + Dtype(std::abs(expected_value * Dtype(1e-4))), min_precision); + EXPECT_NEAR(expected_value, top_data[i], precision); + } + } + } + + void TestBackward(Dtype power, Dtype scale, Dtype shift) { + LayerParameter layer_param; + layer_param.mutable_power_param()->set_power(power); + layer_param.mutable_power_param()->set_scale(scale); + layer_param.mutable_power_param()->set_shift(shift); + PowerLayer<Dtype> layer(layer_param); + if (power != Dtype(0) && power != Dtype(1) && power != Dtype(2)) { + // Avoid NaNs by forcing (shift + scale * x) >= 0 + Dtype* bottom_data = this->blob_bottom_->mutable_cpu_data(); + Dtype min_value = -shift / scale; + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + if (bottom_data[i] < min_value) { + bottom_data[i] = min_value + (min_value - bottom_data[i]); + } + } + } + GradientChecker<Dtype> checker(1e-2, 1e-2, 1701, 0., 0.01); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); + } + + Blob<Dtype>* const blob_bottom_; + Blob<Dtype>* const blob_top_; + vector<Blob<Dtype>*> blob_bottom_vec_; + vector<Blob<Dtype>*> blob_top_vec_; +}; + +TYPED_TEST_CASE(PowerLayerTest, TestDtypesAndDevices); + +TYPED_TEST(PowerLayerTest, TestPower) { + typedef typename TypeParam::Dtype Dtype; + Dtype power = 0.37; + Dtype scale = 0.83; + Dtype shift = -2.4; + this->TestForward(power, scale, shift); +} + +TYPED_TEST(PowerLayerTest, TestPowerGradient) { + typedef typename TypeParam::Dtype Dtype; + Dtype power = 0.37; + Dtype scale = 0.83; + Dtype shift = -2.4; + this->TestBackward(power, scale, shift); +} + +TYPED_TEST(PowerLayerTest, TestPowerGradientShiftZero) { + typedef typename TypeParam::Dtype Dtype; + Dtype power = 0.37; + Dtype scale = 0.83; + Dtype shift = 0.0; + this->TestBackward(power, scale, shift); +} + +TYPED_TEST(PowerLayerTest, TestPowerZero) { + typedef typename TypeParam::Dtype Dtype; + Dtype power = 0.0; + Dtype scale = 0.83; + Dtype shift =
-2.4; + this->TestForward(power, scale, shift); +} + +TYPED_TEST(PowerLayerTest, TestPowerZeroGradient) { + typedef typename TypeParam::Dtype Dtype; + Dtype power = 0.0; + Dtype scale = 0.83; + Dtype shift = -2.4; + this->TestBackward(power, scale, shift); +} + +TYPED_TEST(PowerLayerTest, TestPowerOne) { + typedef typename TypeParam::Dtype Dtype; + Dtype power = 1.0; + Dtype scale = 0.83; + Dtype shift = -2.4; + this->TestForward(power, scale, shift); +} + +TYPED_TEST(PowerLayerTest, TestPowerOneGradient) { + typedef typename TypeParam::Dtype Dtype; + Dtype power = 1.0; + Dtype scale = 0.83; + Dtype shift = -2.4; + this->TestBackward(power, scale, shift); +} + +TYPED_TEST(PowerLayerTest, TestPowerTwo) { + typedef typename TypeParam::Dtype Dtype; + Dtype power = 2.0; + Dtype scale = 0.34; + Dtype shift = -2.4; + this->TestForward(power, scale, shift); +} + +TYPED_TEST(PowerLayerTest, TestPowerTwoGradient) { + typedef typename TypeParam::Dtype Dtype; + Dtype power = 2.0; + Dtype scale = 0.83; + Dtype shift = -2.4; + this->TestBackward(power, scale, shift); +} + +TYPED_TEST(PowerLayerTest, TestPowerTwoScaleHalfGradient) { + typedef typename TypeParam::Dtype Dtype; + Dtype power = 2.0; + Dtype scale = 0.5; + Dtype shift = -2.4; + this->TestBackward(power, scale, shift); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_protobuf.cpp b/caffe-crfrnn/src/caffe/test/test_protobuf.cpp new file mode 100644 index 00000000..0c502d6d --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_protobuf.cpp @@ -0,0 +1,29 @@ +// This is simply a script that tries serializing protocol buffer in text +// format. Nothing special here and no actual code is being tested. 
+#include + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" + +#include "caffe/proto/caffe.pb.h" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +class ProtoTest : public ::testing::Test {}; + +TEST_F(ProtoTest, TestSerialization) { + LayerParameter param; + param.set_name("test"); + param.set_type(LayerParameter_LayerType_NONE); + std::cout << "Printing in binary format." << std::endl; + std::cout << param.SerializeAsString() << std::endl; + std::cout << "Printing in text format." << std::endl; + std::string str; + google::protobuf::TextFormat::PrintToString(param, &str); + std::cout << str << std::endl; + EXPECT_TRUE(true); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_random_number_generator.cpp b/caffe-crfrnn/src/caffe/test/test_random_number_generator.cpp new file mode 100644 index 00000000..98424c06 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_random_number_generator.cpp @@ -0,0 +1,521 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/common.hpp" +#include "caffe/syncedmem.hpp" +#include "caffe/util/math_functions.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class RandomNumberGeneratorTest : public ::testing::Test { + protected: + RandomNumberGeneratorTest() + : mean_bound_multiplier_(3.8), // ~99.99% confidence for test failure. 
+ sample_size_(10000), + seed_(1701), + data_(new SyncedMemory(sample_size_ * sizeof(Dtype))), + data_2_(new SyncedMemory(sample_size_ * sizeof(Dtype))), + int_data_(new SyncedMemory(sample_size_ * sizeof(int))), + int_data_2_(new SyncedMemory(sample_size_ * sizeof(int))) {} + + virtual void SetUp() { + Caffe::set_random_seed(this->seed_); + } + + Dtype sample_mean(const Dtype* const seqs, const int sample_size) { + Dtype sum = 0; + for (int i = 0; i < sample_size; ++i) { + sum += seqs[i]; + } + return sum / sample_size; + } + + Dtype sample_mean(const Dtype* const seqs) { + return sample_mean(seqs, sample_size_); + } + + Dtype sample_mean(const int* const seqs, const int sample_size) { + Dtype sum = 0; + for (int i = 0; i < sample_size; ++i) { + sum += Dtype(seqs[i]); + } + return sum / sample_size; + } + + Dtype sample_mean(const int* const seqs) { + return sample_mean(seqs, sample_size_); + } + + Dtype mean_bound(const Dtype std, const int sample_size) { + return mean_bound_multiplier_ * std / sqrt(static_cast(sample_size)); + } + + Dtype mean_bound(const Dtype std) { + return mean_bound(std, sample_size_); + } + + void RngGaussianFill(const Dtype mu, const Dtype sigma, void* cpu_data) { + Dtype* rng_data = static_cast(cpu_data); + caffe_rng_gaussian(sample_size_, mu, sigma, rng_data); + } + + void RngGaussianChecks(const Dtype mu, const Dtype sigma, + const void* cpu_data, const Dtype sparse_p = 0) { + const Dtype* rng_data = static_cast(cpu_data); + const Dtype true_mean = mu; + const Dtype true_std = sigma; + // Check that sample mean roughly matches true mean. + const Dtype bound = this->mean_bound(true_std); + const Dtype sample_mean = this->sample_mean( + static_cast(cpu_data)); + EXPECT_NEAR(sample_mean, true_mean, bound); + // Check that roughly half the samples are above the true mean. 
+ int num_above_mean = 0; + int num_below_mean = 0; + int num_mean = 0; + int num_nan = 0; + for (int i = 0; i < sample_size_; ++i) { + if (rng_data[i] > true_mean) { + ++num_above_mean; + } else if (rng_data[i] < true_mean) { + ++num_below_mean; + } else if (rng_data[i] == true_mean) { + ++num_mean; + } else { + ++num_nan; + } + } + EXPECT_EQ(0, num_nan); + if (sparse_p == Dtype(0)) { + EXPECT_EQ(0, num_mean); + } + const Dtype sample_p_above_mean = + static_cast(num_above_mean) / sample_size_; + const Dtype bernoulli_p = (1 - sparse_p) * 0.5; + const Dtype bernoulli_std = sqrt(bernoulli_p * (1 - bernoulli_p)); + const Dtype bernoulli_bound = this->mean_bound(bernoulli_std); + EXPECT_NEAR(bernoulli_p, sample_p_above_mean, bernoulli_bound); + } + + void RngUniformFill(const Dtype lower, const Dtype upper, void* cpu_data) { + CHECK_GE(upper, lower); + Dtype* rng_data = static_cast(cpu_data); + caffe_rng_uniform(sample_size_, lower, upper, rng_data); + } + + void RngUniformChecks(const Dtype lower, const Dtype upper, + const void* cpu_data, const Dtype sparse_p = 0) { + const Dtype* rng_data = static_cast(cpu_data); + const Dtype true_mean = (lower + upper) / 2; + const Dtype true_std = (upper - lower) / sqrt(12); + // Check that sample mean roughly matches true mean. + const Dtype bound = this->mean_bound(true_std); + const Dtype sample_mean = this->sample_mean(rng_data); + EXPECT_NEAR(sample_mean, true_mean, bound); + // Check that roughly half the samples are above the true mean, and none are + // above upper or below lower. 
+ int num_above_mean = 0; + int num_below_mean = 0; + int num_mean = 0; + int num_nan = 0; + int num_above_upper = 0; + int num_below_lower = 0; + for (int i = 0; i < sample_size_; ++i) { + if (rng_data[i] > true_mean) { + ++num_above_mean; + } else if (rng_data[i] < true_mean) { + ++num_below_mean; + } else if (rng_data[i] == true_mean) { + ++num_mean; + } else { + ++num_nan; + } + if (rng_data[i] > upper) { + ++num_above_upper; + } else if (rng_data[i] < lower) { + ++num_below_lower; + } + } + EXPECT_EQ(0, num_nan); + EXPECT_EQ(0, num_above_upper); + EXPECT_EQ(0, num_below_lower); + if (sparse_p == Dtype(0)) { + EXPECT_EQ(0, num_mean); + } + const Dtype sample_p_above_mean = + static_cast(num_above_mean) / sample_size_; + const Dtype bernoulli_p = (1 - sparse_p) * 0.5; + const Dtype bernoulli_std = sqrt(bernoulli_p * (1 - bernoulli_p)); + const Dtype bernoulli_bound = this->mean_bound(bernoulli_std); + EXPECT_NEAR(bernoulli_p, sample_p_above_mean, bernoulli_bound); + } + + void RngBernoulliFill(const Dtype p, void* cpu_data) { + int* rng_data = static_cast(cpu_data); + caffe_rng_bernoulli(sample_size_, p, rng_data); + } + + void RngBernoulliChecks(const Dtype p, const void* cpu_data) { + const int* rng_data = static_cast(cpu_data); + const Dtype true_mean = p; + const Dtype true_std = sqrt(p * (1 - p)); + const Dtype bound = this->mean_bound(true_std); + const Dtype sample_mean = this->sample_mean(rng_data); + EXPECT_NEAR(sample_mean, true_mean, bound); + } + +#ifndef CPU_ONLY + + void RngGaussianFillGPU(const Dtype mu, const Dtype sigma, void* gpu_data) { + Dtype* rng_data = static_cast(gpu_data); + caffe_gpu_rng_gaussian(sample_size_, mu, sigma, rng_data); + } + + void RngUniformFillGPU(const Dtype lower, const Dtype upper, void* gpu_data) { + CHECK_GE(upper, lower); + Dtype* rng_data = static_cast(gpu_data); + caffe_gpu_rng_uniform(sample_size_, lower, upper, rng_data); + } + + // Fills with uniform integers in [0, UINT_MAX] using 2 argument form of + // 
caffe_gpu_rng_uniform. + void RngUniformIntFillGPU(void* gpu_data) { + unsigned int* rng_data = static_cast(gpu_data); + caffe_gpu_rng_uniform(sample_size_, rng_data); + } + +#endif + + int num_above_mean; + int num_below_mean; + + Dtype mean_bound_multiplier_; + + size_t sample_size_; + uint32_t seed_; + + shared_ptr data_; + shared_ptr data_2_; + shared_ptr int_data_; + shared_ptr int_data_2_; +}; + +TYPED_TEST_CASE(RandomNumberGeneratorTest, TestDtypes); + +TYPED_TEST(RandomNumberGeneratorTest, TestRngGaussian) { + const TypeParam mu = 0; + const TypeParam sigma = 1; + void* gaussian_data = this->data_->mutable_cpu_data(); + this->RngGaussianFill(mu, sigma, gaussian_data); + this->RngGaussianChecks(mu, sigma, gaussian_data); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngGaussian2) { + const TypeParam mu = -2; + const TypeParam sigma = 3; + void* gaussian_data = this->data_->mutable_cpu_data(); + this->RngGaussianFill(mu, sigma, gaussian_data); + this->RngGaussianChecks(mu, sigma, gaussian_data); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngUniform) { + const TypeParam lower = 0; + const TypeParam upper = 1; + void* uniform_data = this->data_->mutable_cpu_data(); + this->RngUniformFill(lower, upper, uniform_data); + this->RngUniformChecks(lower, upper, uniform_data); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngUniform2) { + const TypeParam lower = -7.3; + const TypeParam upper = -2.3; + void* uniform_data = this->data_->mutable_cpu_data(); + this->RngUniformFill(lower, upper, uniform_data); + this->RngUniformChecks(lower, upper, uniform_data); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngBernoulli) { + const TypeParam p = 0.3; + void* bernoulli_data = this->int_data_->mutable_cpu_data(); + this->RngBernoulliFill(p, bernoulli_data); + this->RngBernoulliChecks(p, bernoulli_data); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngBernoulli2) { + const TypeParam p = 0.9; + void* bernoulli_data = this->int_data_->mutable_cpu_data(); 
+ this->RngBernoulliFill(p, bernoulli_data); + this->RngBernoulliChecks(p, bernoulli_data); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngGaussianTimesGaussian) { + const TypeParam mu = 0; + const TypeParam sigma = 1; + + // Sample from 0 mean Gaussian. + TypeParam* gaussian_data_1 = + static_cast(this->data_->mutable_cpu_data()); + this->RngGaussianFill(mu, sigma, gaussian_data_1); + + // Sample from 0 mean Gaussian again. + TypeParam* gaussian_data_2 = + static_cast(this->data_2_->mutable_cpu_data()); + this->RngGaussianFill(mu, sigma, gaussian_data_2); + + // Multiply Gaussians. + for (int i = 0; i < this->sample_size_; ++i) { + gaussian_data_1[i] *= gaussian_data_2[i]; + } + + // Check that result has mean 0. + TypeParam mu_product = pow(mu, 2); + TypeParam sigma_product = sqrt(pow(sigma, 2) / 2); + this->RngGaussianChecks(mu_product, sigma_product, gaussian_data_1); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngUniformTimesUniform) { + // Sample from Uniform on [-2, 2]. + const TypeParam lower_1 = -2; + const TypeParam upper_1 = -lower_1; + TypeParam* uniform_data_1 = + static_cast(this->data_->mutable_cpu_data()); + this->RngUniformFill(lower_1, upper_1, uniform_data_1); + + // Sample from Uniform on [-3, 3]. + const TypeParam lower_2 = -3; + const TypeParam upper_2 = -lower_2; + TypeParam* uniform_data_2 = + static_cast(this->data_2_->mutable_cpu_data()); + this->RngUniformFill(lower_2, upper_2, uniform_data_2); + + // Multiply Uniforms. + for (int i = 0; i < this->sample_size_; ++i) { + uniform_data_1[i] *= uniform_data_2[i]; + } + + // Check that result does not violate checked properties of Uniform on [-6, 6] + // (though it is not actually uniformly distributed). + const TypeParam lower_prod = lower_1 * upper_2; + const TypeParam upper_prod = -lower_prod; + this->RngUniformChecks(lower_prod, upper_prod, uniform_data_1); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngGaussianTimesBernoulli) { + // Sample from 0 mean Gaussian. 
+ const TypeParam mu = 0; + const TypeParam sigma = 1; + TypeParam* gaussian_data = + static_cast(this->data_->mutable_cpu_data()); + this->RngGaussianFill(mu, sigma, gaussian_data); + + // Sample from Bernoulli with p = 0.3. + const TypeParam bernoulli_p = 0.3; + int* bernoulli_data = + static_cast(this->int_data_->mutable_cpu_data()); + this->RngBernoulliFill(bernoulli_p, bernoulli_data); + + // Multiply Gaussian by Bernoulli. + for (int i = 0; i < this->sample_size_; ++i) { + gaussian_data[i] *= bernoulli_data[i]; + } + + // Check that result does not violate checked properties of sparsified + // Gaussian (though it is not actually a Gaussian). + this->RngGaussianChecks(mu, sigma, gaussian_data, 1 - bernoulli_p); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngUniformTimesBernoulli) { + // Sample from Uniform on [-1, 1]. + const TypeParam lower = -1; + const TypeParam upper = 1; + TypeParam* uniform_data = + static_cast(this->data_->mutable_cpu_data()); + this->RngUniformFill(lower, upper, uniform_data); + + // Sample from Bernoulli with p = 0.3. + const TypeParam bernoulli_p = 0.3; + int* bernoulli_data = + static_cast(this->int_data_->mutable_cpu_data()); + this->RngBernoulliFill(bernoulli_p, bernoulli_data); + + // Multiply Uniform by Bernoulli. + for (int i = 0; i < this->sample_size_; ++i) { + uniform_data[i] *= bernoulli_data[i]; + } + + // Check that result does not violate checked properties of sparsified + // Uniform on [-1, 1] (though it is not actually uniformly distributed). + this->RngUniformChecks(lower, upper, uniform_data, 1 - bernoulli_p); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngBernoulliTimesBernoulli) { + // Sample from Bernoulli with p = 0.5. + const TypeParam p_a = 0.5; + int* bernoulli_data_a = + static_cast(this->int_data_->mutable_cpu_data()); + this->RngBernoulliFill(p_a, bernoulli_data_a); + + // Sample from Bernoulli with p = 0.3. 
+ const TypeParam p_b = 0.3; + int* bernoulli_data_b = + static_cast(this->int_data_2_->mutable_cpu_data()); + this->RngBernoulliFill(p_b, bernoulli_data_b); + + // Multiply Bernoullis. + for (int i = 0; i < this->sample_size_; ++i) { + bernoulli_data_a[i] *= bernoulli_data_b[i]; + } + int num_ones = 0; + for (int i = 0; i < this->sample_size_; ++i) { + if (bernoulli_data_a[i] != TypeParam(0)) { + EXPECT_EQ(TypeParam(1), bernoulli_data_a[i]); + ++num_ones; + } + } + + // Check that resulting product has roughly p_a * p_b ones. + const TypeParam sample_p = this->sample_mean(bernoulli_data_a); + const TypeParam true_mean = p_a * p_b; + const TypeParam true_std = sqrt(true_mean * (1 - true_mean)); + const TypeParam bound = this->mean_bound(true_std); + EXPECT_NEAR(true_mean, sample_p, bound); +} + +#ifndef CPU_ONLY + +TYPED_TEST(RandomNumberGeneratorTest, TestRngGaussianGPU) { + const TypeParam mu = 0; + const TypeParam sigma = 1; + void* gaussian_gpu_data = this->data_->mutable_gpu_data(); + this->RngGaussianFillGPU(mu, sigma, gaussian_gpu_data); + const void* gaussian_data = this->data_->cpu_data(); + this->RngGaussianChecks(mu, sigma, gaussian_data); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngGaussian2GPU) { + const TypeParam mu = -2; + const TypeParam sigma = 3; + void* gaussian_gpu_data = this->data_->mutable_gpu_data(); + this->RngGaussianFillGPU(mu, sigma, gaussian_gpu_data); + const void* gaussian_data = this->data_->cpu_data(); + this->RngGaussianChecks(mu, sigma, gaussian_data); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngUniformGPU) { + const TypeParam lower = 0; + const TypeParam upper = 1; + void* uniform_gpu_data = this->data_->mutable_gpu_data(); + this->RngUniformFillGPU(lower, upper, uniform_gpu_data); + const void* uniform_data = this->data_->cpu_data(); + this->RngUniformChecks(lower, upper, uniform_data); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngUniform2GPU) { + const TypeParam lower = -7.3; + const TypeParam upper = 
-2.3; + void* uniform_gpu_data = this->data_->mutable_gpu_data(); + this->RngUniformFillGPU(lower, upper, uniform_gpu_data); + const void* uniform_data = this->data_->cpu_data(); + this->RngUniformChecks(lower, upper, uniform_data); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngUniformIntGPU) { + unsigned int* uniform_uint_gpu_data = + static_cast(this->int_data_->mutable_gpu_data()); + this->RngUniformIntFillGPU(uniform_uint_gpu_data); + const unsigned int* uniform_uint_data = + static_cast(this->int_data_->cpu_data()); + TypeParam* uniform_data = + static_cast(this->data_->mutable_cpu_data()); + for (int i = 0; i < this->sample_size_; ++i) { + uniform_data[i] = static_cast(uniform_uint_data[i]); + } + const TypeParam lower = 0; + const TypeParam upper = UINT_MAX; + this->RngUniformChecks(lower, upper, uniform_data); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngGaussianTimesGaussianGPU) { + const TypeParam mu = 0; + const TypeParam sigma = 1; + + // Sample from 0 mean Gaussian. + TypeParam* gaussian_gpu_data_1 = + static_cast(this->data_->mutable_gpu_data()); + this->RngGaussianFillGPU(mu, sigma, gaussian_gpu_data_1); + + // Sample from 0 mean Gaussian again. + TypeParam* gaussian_gpu_data_2 = + static_cast(this->data_2_->mutable_gpu_data()); + this->RngGaussianFillGPU(mu, sigma, gaussian_gpu_data_2); + + // Multiply Gaussians. + TypeParam* gaussian_data_1 = + static_cast(this->data_->mutable_cpu_data()); + const TypeParam* gaussian_data_2 = + static_cast(this->data_2_->cpu_data()); + for (int i = 0; i < this->sample_size_; ++i) { + gaussian_data_1[i] *= gaussian_data_2[i]; + } + + // Check that result does not violate checked properties of Gaussian + // (though it is not actually a Gaussian). 
+ TypeParam mu_product = pow(mu, 2); + TypeParam sigma_product = sqrt(pow(sigma, 2) / 2); + this->RngGaussianChecks(mu_product, sigma_product, gaussian_data_1); +} + + +TYPED_TEST(RandomNumberGeneratorTest, TestRngUniformTimesUniformGPU) { + // Sample from Uniform on [-2, 2]. + const TypeParam lower_1 = -2; + const TypeParam upper_1 = -lower_1; + TypeParam* uniform_gpu_data_1 = + static_cast<TypeParam*>(this->data_->mutable_gpu_data()); + this->RngUniformFillGPU(lower_1, upper_1, uniform_gpu_data_1); + + // Sample from Uniform on [-3, 3]. + const TypeParam lower_2 = -3; + const TypeParam upper_2 = -lower_2; + TypeParam* uniform_gpu_data_2 = + static_cast<TypeParam*>(this->data_2_->mutable_gpu_data()); + this->RngUniformFillGPU(lower_2, upper_2, uniform_gpu_data_2); + + // Multiply Uniforms. + TypeParam* uniform_data_1 = + static_cast<TypeParam*>(this->data_->mutable_cpu_data()); + const TypeParam* uniform_data_2 = + static_cast<const TypeParam*>(this->data_2_->cpu_data()); + for (int i = 0; i < this->sample_size_; ++i) { + uniform_data_1[i] *= uniform_data_2[i]; + } + + // Check that result does not violate properties of Uniform on [-6, 6].
+ const TypeParam lower_prod = lower_1 * upper_2; + const TypeParam upper_prod = -lower_prod; + this->RngUniformChecks(lower_prod, upper_prod, uniform_data_1); +} + +#endif + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp b/caffe-crfrnn/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp new file mode 100644 index 00000000..e5737e43 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp @@ -0,0 +1,122 @@ +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class SigmoidCrossEntropyLossLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + SigmoidCrossEntropyLossLayerTest() + : blob_bottom_data_(new Blob(10, 5, 1, 1)), + blob_bottom_targets_(new Blob(10, 5, 1, 1)), + blob_top_loss_(new Blob()) { + // Fill the data vector + FillerParameter data_filler_param; + data_filler_param.set_std(1); + GaussianFiller data_filler(data_filler_param); + data_filler.Fill(blob_bottom_data_); + blob_bottom_vec_.push_back(blob_bottom_data_); + // Fill the targets vector + FillerParameter targets_filler_param; + targets_filler_param.set_min(0); + targets_filler_param.set_max(1); + UniformFiller targets_filler(targets_filler_param); + targets_filler.Fill(blob_bottom_targets_); + blob_bottom_vec_.push_back(blob_bottom_targets_); + blob_top_vec_.push_back(blob_top_loss_); + } + virtual ~SigmoidCrossEntropyLossLayerTest() { + delete blob_bottom_data_; + delete blob_bottom_targets_; + delete blob_top_loss_; + } + + Dtype SigmoidCrossEntropyLossReference(const int count, const int num, + const Dtype* input, + const Dtype* target) { + Dtype loss = 0; + for (int i = 0; i < count; ++i) { + 
const Dtype prediction = 1 / (1 + exp(-input[i])); + EXPECT_LE(prediction, 1); + EXPECT_GE(prediction, 0); + EXPECT_LE(target[i], 1); + EXPECT_GE(target[i], 0); + loss -= target[i] * log(prediction + (target[i] == Dtype(0))); + loss -= (1 - target[i]) * log(1 - prediction + (target[i] == Dtype(1))); + } + return loss / num; + } + + void TestForward() { + LayerParameter layer_param; + const Dtype kLossWeight = 3.7; + layer_param.add_loss_weight(kLossWeight); + FillerParameter data_filler_param; + data_filler_param.set_std(1); + GaussianFiller data_filler(data_filler_param); + FillerParameter targets_filler_param; + targets_filler_param.set_min(0.0); + targets_filler_param.set_max(1.0); + UniformFiller targets_filler(targets_filler_param); + Dtype eps = 2e-2; + for (int i = 0; i < 100; ++i) { + // Fill the data vector + data_filler.Fill(this->blob_bottom_data_); + // Fill the targets vector + targets_filler.Fill(this->blob_bottom_targets_); + SigmoidCrossEntropyLossLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + Dtype layer_loss = + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + const int count = this->blob_bottom_data_->count(); + const int num = this->blob_bottom_data_->num(); + const Dtype* blob_bottom_data = this->blob_bottom_data_->cpu_data(); + const Dtype* blob_bottom_targets = + this->blob_bottom_targets_->cpu_data(); + Dtype reference_loss = kLossWeight * SigmoidCrossEntropyLossReference( + count, num, blob_bottom_data, blob_bottom_targets); + EXPECT_NEAR(reference_loss, layer_loss, eps) << "debug: trial #" << i; + } + } + + Blob* const blob_bottom_data_; + Blob* const blob_bottom_targets_; + Blob* const blob_top_loss_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(SigmoidCrossEntropyLossLayerTest, TestDtypesAndDevices); + +TYPED_TEST(SigmoidCrossEntropyLossLayerTest, TestSigmoidCrossEntropyLoss) { + this->TestForward(); +} + +TYPED_TEST(SigmoidCrossEntropyLossLayerTest, 
TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + const Dtype kLossWeight = 3.7; + layer_param.add_loss_weight(kLossWeight); + SigmoidCrossEntropyLossLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + GradientChecker checker(1e-2, 1e-2, 1701); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_slice_layer.cpp b/caffe-crfrnn/src/caffe/test/test_slice_layer.cpp new file mode 100644 index 00000000..395be280 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_slice_layer.cpp @@ -0,0 +1,189 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class SliceLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + SliceLayerTest() + : blob_bottom_(new Blob(6, 12, 2, 3)), + blob_top_0_(new Blob()), + blob_top_1_(new Blob()), + blob_top_2_(new Blob()) {} + virtual void SetUp() { + // fill the values + Caffe::set_random_seed(1701); + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_top_vec_0_.push_back(blob_top_0_); + blob_top_vec_0_.push_back(blob_top_1_); + blob_top_vec_1_.push_back(blob_top_0_); + blob_top_vec_1_.push_back(blob_top_1_); + blob_top_vec_1_.push_back(blob_top_2_); + blob_bottom_vec_.push_back(blob_bottom_); + } + + virtual void ReduceBottomBlobSize() { + blob_bottom_->Reshape(4, 5, 2, 2); + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + } + + virtual ~SliceLayerTest() { + delete blob_top_0_; delete blob_top_1_; + delete blob_top_2_; delete blob_bottom_; + } + + 
Blob* const blob_bottom_; + Blob* const blob_top_0_; + Blob* const blob_top_1_; + Blob* const blob_top_2_; + vector*> blob_top_vec_0_, blob_top_vec_1_; + vector*> blob_bottom_vec_; +}; + +TYPED_TEST_CASE(SliceLayerTest, TestDtypesAndDevices); + +TYPED_TEST(SliceLayerTest, TestSetupNum) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.mutable_slice_param()->set_slice_dim(0); + SliceLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_1_); + EXPECT_EQ(this->blob_bottom_->num(), 3 * this->blob_top_0_->num()); + EXPECT_EQ(this->blob_top_0_->num(), this->blob_top_1_->num()); + EXPECT_EQ(this->blob_top_0_->num(), this->blob_top_2_->num()); + EXPECT_EQ(this->blob_bottom_->channels(), this->blob_top_0_->channels()); + EXPECT_EQ(this->blob_bottom_->height(), this->blob_top_0_->height()); + EXPECT_EQ(this->blob_bottom_->width(), this->blob_top_0_->width()); +} + +TYPED_TEST(SliceLayerTest, TestSetupChannels) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.mutable_slice_param()->add_slice_point(3); + SliceLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_0_); + EXPECT_EQ(this->blob_top_0_->num(), this->blob_bottom_->num()); + EXPECT_EQ(this->blob_top_0_->channels(), 3); + EXPECT_EQ(this->blob_top_1_->channels(), 9); + EXPECT_EQ(this->blob_bottom_->channels(), + this->blob_top_0_->channels() + this->blob_top_1_->channels()); + EXPECT_EQ(this->blob_bottom_->height(), this->blob_top_0_->height()); + EXPECT_EQ(this->blob_bottom_->width(), this->blob_top_0_->width()); +} + +TYPED_TEST(SliceLayerTest, TestSliceAcrossNum) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.mutable_slice_param()->set_slice_dim(0); + SliceLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_0_); + const int top_num = this->blob_bottom_->num() / 2; + ASSERT_EQ(top_num, this->blob_top_0_->num()); 
+ ASSERT_EQ(top_num, this->blob_top_1_->num()); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_0_); + for (int n = 0; n < top_num; ++n) { + for (int c = 0; c < this->blob_top_0_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w), + this->blob_top_0_->data_at(n, c, h, w)); + } + } + } + for (int c = 0; c < this->blob_top_1_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_EQ(this->blob_bottom_->data_at(n + 3, c, h, w), + this->blob_top_1_->data_at(n, c, h, w)); + } + } + } + } +} + +TYPED_TEST(SliceLayerTest, TestSliceAcrossChannels) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + // Slice at 2, 8: should produce output blobs with #channels 2, 6, 4. + const int kSlicePoint0 = 2; + const int kSlicePoint1 = 8; + layer_param.mutable_slice_param()->add_slice_point(kSlicePoint0); + layer_param.mutable_slice_param()->add_slice_point(kSlicePoint1); + SliceLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_1_); + ASSERT_EQ(kSlicePoint0, this->blob_top_0_->channels()); + ASSERT_EQ(kSlicePoint1 - kSlicePoint0, this->blob_top_1_->channels()); + ASSERT_EQ(this->blob_bottom_->channels() - kSlicePoint1, + this->blob_top_2_->channels()); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_1_); + for (int n = 0; n < this->blob_bottom_->num(); ++n) { + for (int c = 0; c < this->blob_top_0_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w), + this->blob_top_0_->data_at(n, c, h, w)); + } + } + } + for (int c = 0; c < this->blob_top_1_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < 
this->blob_bottom_->width(); ++w) { + EXPECT_EQ(this->blob_bottom_->data_at(n, c + kSlicePoint0, h, w), + this->blob_top_1_->data_at(n, c, h, w)); + } + } + } + for (int c = 0; c < this->blob_top_2_->channels(); ++c) { + for (int h = 0; h < this->blob_bottom_->height(); ++h) { + for (int w = 0; w < this->blob_bottom_->width(); ++w) { + EXPECT_EQ(this->blob_bottom_->data_at(n, c + kSlicePoint1, h, w), + this->blob_top_2_->data_at(n, c, h, w)); + } + } + } + } +} + +TYPED_TEST(SliceLayerTest, TestGradientAcrossNum) { + typedef typename TypeParam::Dtype Dtype; + // Gradient checks are slow; reduce blob size. + this->ReduceBottomBlobSize(); + LayerParameter layer_param; + layer_param.mutable_slice_param()->set_slice_dim(0); + SliceLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_0_); +} + +TYPED_TEST(SliceLayerTest, TestGradientAcrossChannels) { + typedef typename TypeParam::Dtype Dtype; + // Gradient checks are slow; reduce blob size. 
+ this->ReduceBottomBlobSize(); + LayerParameter layer_param; + const int kSlicePoint = 4; + layer_param.mutable_slice_param()->add_slice_point(kSlicePoint); + SliceLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_0_); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_softmax_layer.cpp b/caffe-crfrnn/src/caffe/test/test_softmax_layer.cpp new file mode 100644 index 00000000..f6674422 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_softmax_layer.cpp @@ -0,0 +1,151 @@ +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class SoftmaxLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + protected: + SoftmaxLayerTest() + : blob_bottom_(new Blob(2, 10, 2, 3)), + blob_top_(new Blob()) { + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~SoftmaxLayerTest() { delete blob_bottom_; delete blob_top_; } + Blob* const blob_bottom_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(SoftmaxLayerTest, TestDtypesAndDevices); + +TYPED_TEST(SoftmaxLayerTest, TestForward) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + SoftmaxLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Test sum + for (int i = 0; i < this->blob_bottom_->num(); ++i) { + for (int k = 0; k < this->blob_bottom_->height(); ++k) { + for (int l = 0; l < 
this->blob_bottom_->width(); ++l) { + Dtype sum = 0; + for (int j = 0; j < this->blob_top_->channels(); ++j) { + sum += this->blob_top_->data_at(i, j, k, l); + } + EXPECT_GE(sum, 0.999); + EXPECT_LE(sum, 1.001); + // Test exact values + Dtype scale = 0; + for (int j = 0; j < this->blob_bottom_->channels(); ++j) { + scale += exp(this->blob_bottom_->data_at(i, j, k, l)); + } + for (int j = 0; j < this->blob_bottom_->channels(); ++j) { + EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4, + exp(this->blob_bottom_->data_at(i, j, k, l)) / scale) + << "debug: " << i << " " << j; + EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4, + exp(this->blob_bottom_->data_at(i, j, k, l)) / scale) + << "debug: " << i << " " << j; + } + } + } + } +} + +TYPED_TEST(SoftmaxLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + SoftmaxLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +#ifdef USE_CUDNN +template +class CuDNNSoftmaxLayerTest : public ::testing::Test { + protected: + CuDNNSoftmaxLayerTest() + : blob_bottom_(new Blob(2, 10, 2, 3)), + blob_top_(new Blob()) { + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~CuDNNSoftmaxLayerTest() { delete blob_bottom_; delete blob_top_; } + Blob* const blob_bottom_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(CuDNNSoftmaxLayerTest, TestDtypes); + +TYPED_TEST(CuDNNSoftmaxLayerTest, TestForwardCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + CuDNNSoftmaxLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Test sum + for (int i = 0; i 
< this->blob_bottom_->num(); ++i) { + for (int k = 0; k < this->blob_bottom_->height(); ++k) { + for (int l = 0; l < this->blob_bottom_->width(); ++l) { + TypeParam sum = 0; + for (int j = 0; j < this->blob_top_->channels(); ++j) { + sum += this->blob_top_->data_at(i, j, k, l); + } + EXPECT_GE(sum, 0.999); + EXPECT_LE(sum, 1.001); + // Test exact values + TypeParam scale = 0; + for (int j = 0; j < this->blob_bottom_->channels(); ++j) { + scale += exp(this->blob_bottom_->data_at(i, j, k, l)); + } + for (int j = 0; j < this->blob_bottom_->channels(); ++j) { + EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4, + exp(this->blob_bottom_->data_at(i, j, k, l)) / scale) + << "debug: " << i << " " << j; + EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4, + exp(this->blob_bottom_->data_at(i, j, k, l)) / scale) + << "debug: " << i << " " << j; + } + } + } + } +} + +TYPED_TEST(CuDNNSoftmaxLayerTest, TestGradientCuDNN) { + Caffe::set_mode(Caffe::GPU); + LayerParameter layer_param; + CuDNNSoftmaxLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +#endif + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_softmax_with_loss_layer.cpp b/caffe-crfrnn/src/caffe/test/test_softmax_with_loss_layer.cpp new file mode 100644 index 00000000..badda3b5 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_softmax_with_loss_layer.cpp @@ -0,0 +1,64 @@ +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class SoftmaxWithLossLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + SoftmaxWithLossLayerTest() + : blob_bottom_data_(new Blob(10, 5, 2, 3)), + 
blob_bottom_label_(new Blob(10, 1, 2, 3)), + blob_top_loss_(new Blob()) { + // fill the values + FillerParameter filler_param; + filler_param.set_std(10); + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_data_); + blob_bottom_vec_.push_back(blob_bottom_data_); + for (int i = 0; i < blob_bottom_label_->count(); ++i) { + blob_bottom_label_->mutable_cpu_data()[i] = caffe_rng_rand() % 5; + } + blob_bottom_vec_.push_back(blob_bottom_label_); + blob_top_vec_.push_back(blob_top_loss_); + } + virtual ~SoftmaxWithLossLayerTest() { + delete blob_bottom_data_; + delete blob_bottom_label_; + delete blob_top_loss_; + } + Blob* const blob_bottom_data_; + Blob* const blob_bottom_label_; + Blob* const blob_top_loss_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(SoftmaxWithLossLayerTest, TestDtypesAndDevices); + + +TYPED_TEST(SoftmaxWithLossLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + layer_param.add_loss_weight(3); + SoftmaxWithLossLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2, 1701); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_, 0); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_solver.cpp b/caffe-crfrnn/src/caffe/test/test_solver.cpp new file mode 100644 index 00000000..a7dbf77f --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_solver.cpp @@ -0,0 +1,107 @@ +#include +#include +#include + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" + +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +using std::ostringstream; + +namespace caffe { + +template +class SolverTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + virtual void InitSolverFromProtoString(const string& proto) { + SolverParameter param; + 
CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param)); + // Set the solver_mode according to current Caffe::mode. + switch (Caffe::mode()) { + case Caffe::CPU: + param.set_solver_mode(SolverParameter_SolverMode_CPU); + break; + case Caffe::GPU: + param.set_solver_mode(SolverParameter_SolverMode_GPU); + break; + default: + LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode(); + } + solver_.reset(new SGDSolver<Dtype>(param)); + } + + shared_ptr<Solver<Dtype> > solver_; +}; + +TYPED_TEST_CASE(SolverTest, TestDtypesAndDevices); + +TYPED_TEST(SolverTest, TestInitTrainTestNets) { + const string& proto = + "test_interval: 10 " + "test_iter: 10 " + "test_state: { stage: 'with-softmax' }" + "test_iter: 10 " + "test_state: {}" + "net_param { " + " name: 'TestNetwork' " + " layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: 5 " + " channels: 3 " + " height: 10 " + " width: 10 " + " num: 5 " + " channels: 1 " + " height: 1 " + " width: 1 " + " } " + " top: 'data' " + " top: 'label' " + " } " + " layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 10 " + " } " + " bottom: 'data' " + " top: 'innerprod' " + " } " + " layers: { " + " name: 'accuracy' " + " type: ACCURACY " + " bottom: 'innerprod' " + " bottom: 'label' " + " top: 'accuracy' " + " exclude: { phase: TRAIN } " + " } " + " layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + " include: { phase: TRAIN } " + " include: { phase: TEST stage: 'with-softmax' } " + " } " + "} "; + this->InitSolverFromProtoString(proto); + ASSERT_TRUE(this->solver_->net() != NULL); + EXPECT_TRUE(this->solver_->net()->has_layer("loss")); + EXPECT_FALSE(this->solver_->net()->has_layer("accuracy")); + ASSERT_EQ(2, this->solver_->test_nets().size()); + EXPECT_TRUE(this->solver_->test_nets()[0]->has_layer("loss")); + EXPECT_TRUE(this->solver_->test_nets()[0]->has_layer("accuracy")); +
EXPECT_FALSE(this->solver_->test_nets()[1]->has_layer("loss")); + EXPECT_TRUE(this->solver_->test_nets()[1]->has_layer("accuracy")); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_split_layer.cpp b/caffe-crfrnn/src/caffe/test/test_split_layer.cpp new file mode 100644 index 00000000..38e76219 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_split_layer.cpp @@ -0,0 +1,1013 @@ +#include +#include +#include + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/insert_splits.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +template +class SplitLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + SplitLayerTest() + : blob_bottom_(new Blob(2, 3, 6, 5)), + blob_top_a_(new Blob()), + blob_top_b_(new Blob()) { + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_a_); + blob_top_vec_.push_back(blob_top_b_); + } + virtual ~SplitLayerTest() { + delete blob_bottom_; + delete blob_top_a_; + delete blob_top_b_; + } + Blob* const blob_bottom_; + Blob* const blob_top_a_; + Blob* const blob_top_b_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(SplitLayerTest, TestDtypesAndDevices); + +TYPED_TEST(SplitLayerTest, TestSetup) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + SplitLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_a_->num(), 2); + EXPECT_EQ(this->blob_top_a_->channels(), 3); + EXPECT_EQ(this->blob_top_a_->height(), 6); + EXPECT_EQ(this->blob_top_a_->width(), 5); 
+ EXPECT_EQ(this->blob_top_b_->num(), 2); + EXPECT_EQ(this->blob_top_b_->channels(), 3); + EXPECT_EQ(this->blob_top_b_->height(), 6); + EXPECT_EQ(this->blob_top_b_->width(), 5); +} + +TYPED_TEST(SplitLayerTest, Test) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + SplitLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + Dtype bottom_value = this->blob_bottom_->cpu_data()[i]; + EXPECT_EQ(bottom_value, this->blob_top_a_->cpu_data()[i]); + EXPECT_EQ(bottom_value, this->blob_top_b_->cpu_data()[i]); + } +} + +TYPED_TEST(SplitLayerTest, TestGradient) { + typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + SplitLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + + +class SplitLayerInsertionTest : public ::testing::Test { + protected: + void RunInsertionTest( + const string& input_param_string, const string& output_param_string) { + // Test that InsertSplits called on the proto specified by + // input_param_string results in the proto specified by + // output_param_string. + NetParameter input_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + input_param_string, &input_param)); + NetParameter expected_output_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + output_param_string, &expected_output_param)); + NetParameter actual_output_param; + InsertSplits(input_param, &actual_output_param); + EXPECT_EQ(expected_output_param.DebugString(), + actual_output_param.DebugString()); + // Also test idempotence. 
+ NetParameter double_split_insert_param; + InsertSplits(actual_output_param, &double_split_insert_param); + EXPECT_EQ(actual_output_param.DebugString(), + double_split_insert_param.DebugString()); + } +}; + +TEST_F(SplitLayerInsertionTest, TestNoInsertion1) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunInsertionTest(input_proto, input_proto); +} + +TEST_F(SplitLayerInsertionTest, TestNoInsertion2) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'data_split' " + " type: SPLIT " + " bottom: 'data' " + " top: 'data_split_0' " + " top: 'data_split_1' " + "} " + "layers: { " + " name: 'innerprod1' " + " type: INNER_PRODUCT " + " bottom: 'data_split_0' " + " top: 'innerprod1' " + "} " + "layers: { " + " name: 'innerprod2' " + " type: INNER_PRODUCT " + " bottom: 'data_split_1' " + " top: 'innerprod2' " + "} " + "layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod1' " + " bottom: 'innerprod2' " + "} "; + this->RunInsertionTest(input_proto, input_proto); +} + +TEST_F(SplitLayerInsertionTest, TestNoInsertionImageNet) { + const string& input_proto = + "name: 'CaffeNet' " + "layers { " + " name: 'data' " + " type: DATA " + " data_param { " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " batch_size: 256 " + " } " + " transform_param { " + " crop_size: 227 " + " mirror: true " + " mean_file: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " name: 'conv1' " + " type: CONVOLUTION " + " 
convolution_param { " + " num_output: 96 " + " kernel_size: 11 " + " stride: 4 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'data' " + " top: 'conv1' " + "} " + "layers { " + " name: 'relu1' " + " type: RELU " + " bottom: 'conv1' " + " top: 'conv1' " + "} " + "layers { " + " name: 'pool1' " + " type: POOLING " + " pooling_param { " + " pool: MAX " + " kernel_size: 3 " + " stride: 2 " + " } " + " bottom: 'conv1' " + " top: 'pool1' " + "} " + "layers { " + " name: 'norm1' " + " type: LRN " + " lrn_param { " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool1' " + " top: 'norm1' " + "} " + "layers { " + " name: 'conv2' " + " type: CONVOLUTION " + " convolution_param { " + " num_output: 256 " + " group: 2 " + " kernel_size: 5 " + " pad: 2 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'norm1' " + " top: 'conv2' " + "} " + "layers { " + " name: 'relu2' " + " type: RELU " + " bottom: 'conv2' " + " top: 'conv2' " + "} " + "layers { " + " name: 'pool2' " + " type: POOLING " + " pooling_param { " + " pool: MAX " + " kernel_size: 3 " + " stride: 2 " + " } " + " bottom: 'conv2' " + " top: 'pool2' " + "} " + "layers { " + " name: 'norm2' " + " type: LRN " + " lrn_param { " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool2' " + " top: 'norm2' " + "} " + "layers { " + " name: 'conv3' " + " type: CONVOLUTION " + " convolution_param { " + " num_output: 384 " + " kernel_size: 3 " + " pad: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. 
" + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'norm2' " + " top: 'conv3' " + "} " + "layers { " + " name: 'relu3' " + " type: RELU " + " bottom: 'conv3' " + " top: 'conv3' " + "} " + "layers { " + " name: 'conv4' " + " type: CONVOLUTION " + " convolution_param { " + " num_output: 384 " + " group: 2 " + " kernel_size: 3 " + " pad: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'conv3' " + " top: 'conv4' " + "} " + "layers { " + " name: 'relu4' " + " type: RELU " + " bottom: 'conv4' " + " top: 'conv4' " + "} " + "layers { " + " name: 'conv5' " + " type: CONVOLUTION " + " convolution_param { " + " num_output: 256 " + " group: 2 " + " kernel_size: 3 " + " pad: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'conv4' " + " top: 'conv5' " + "} " + "layers { " + " name: 'relu5' " + " type: RELU " + " bottom: 'conv5' " + " top: 'conv5' " + "} " + "layers { " + " name: 'pool5' " + " type: POOLING " + " pooling_param { " + " kernel_size: 3 " + " pool: MAX " + " stride: 2 " + " } " + " bottom: 'conv5' " + " top: 'pool5' " + "} " + "layers { " + " name: 'fc6' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 4096 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " bottom: 'pool5' " + " top: 'fc6' " + "} " + "layers { " + " name: 'relu6' " + " type: RELU " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " name: 'drop6' " + " type: DROPOUT " + " dropout_param { " + " dropout_ratio: 0.5 " + " } " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " name: 'fc7' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 4096 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'fc6' " + " top: 'fc7' " + "} " + "layers { " + " name: 'relu7' " + " type: RELU " + " bottom: 'fc7' " + " top: 'fc7' " + "} " + "layers { " + " name: 'drop7' " + " type: DROPOUT " + " dropout_param { " + " dropout_ratio: 0.5 " + " } " + " bottom: 'fc7' " + " top: 'fc7' " + "} " + "layers { " + " name: 'fc8' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " bottom: 'fc7' " + " top: 'fc8' " + "} " + "layers { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'fc8' " + " bottom: 'label' " + "} "; + this->RunInsertionTest(input_proto, input_proto); +} + +TEST_F(SplitLayerInsertionTest, TestNoInsertionWithInPlace) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod' " + "} " + "layers: { " + " name: 'relu' " + " type: RELU " + " bottom: 'innerprod' " + " top: 'innerprod' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunInsertionTest(input_proto, input_proto); +} + +TEST_F(SplitLayerInsertionTest, TestLossInsertion) { + const string& input_proto = + "name: 'UnsharedWeightsNetwork' " + "force_backward: true " + "layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: 5 " + " channels: 2 " + " height: 3 " + " width: 4 " + " data_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " } " + " top: 'data' " + "} " + "layers: { " + " name: 'innerproduct1' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 10 " + " bias_term: false " + " weight_filler { " + " type: 'gaussian' " + " std: 10 " + " } " + " } " + " param: 'unsharedweights1' " + " bottom: 'data' " + " top: 'innerproduct1' " + " loss_weight: 2.5 " + "} " + "layers: { " + " name: 'innerproduct2' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 10 " + " bias_term: false " + " weight_filler { " + " type: 'gaussian' " + " std: 10 " + " } " + " } " + " param: 'unsharedweights2' " + " bottom: 'data' " + " top: 'innerproduct2' " + "} " + "layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerproduct1' " + " bottom: 'innerproduct2' " + "} "; + const string& 
expected_output_proto = + "name: 'UnsharedWeightsNetwork' " + "force_backward: true " + "layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: 5 " + " channels: 2 " + " height: 3 " + " width: 4 " + " data_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " } " + " top: 'data' " + "} " + "layers: { " + " name: 'data_data_0_split' " + " type: SPLIT " + " bottom: 'data' " + " top: 'data_data_0_split_0' " + " top: 'data_data_0_split_1' " + "} " + "layers: { " + " name: 'innerproduct1' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 10 " + " bias_term: false " + " weight_filler { " + " type: 'gaussian' " + " std: 10 " + " } " + " } " + " param: 'unsharedweights1' " + " bottom: 'data_data_0_split_0' " + " top: 'innerproduct1' " + "} " + "layers: { " + " name: 'innerproduct1_innerproduct1_0_split' " + " type: SPLIT " + " bottom: 'innerproduct1' " + " top: 'innerproduct1_innerproduct1_0_split_0' " + " top: 'innerproduct1_innerproduct1_0_split_1' " + " loss_weight: 2.5 " + " loss_weight: 0 " + "} " + "layers: { " + " name: 'innerproduct2' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 10 " + " bias_term: false " + " weight_filler { " + " type: 'gaussian' " + " std: 10 " + " } " + " } " + " param: 'unsharedweights2' " + " bottom: 'data_data_0_split_1' " + " top: 'innerproduct2' " + "} " + "layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerproduct1_innerproduct1_0_split_1' " + " bottom: 'innerproduct2' " + "} "; + this->RunInsertionTest(input_proto, expected_output_proto); +} + +TEST_F(SplitLayerInsertionTest, TestInsertion) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod1' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod1' " + "} " + "layers: { " + " name: 'innerprod2' " + " type: INNER_PRODUCT " + " 
bottom: 'data' " + " top: 'innerprod2' " + "} " + "layers: { " + " name: 'innerprod3' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod3' " + "} " + "layers: { " + " name: 'loss1' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod1' " + " bottom: 'innerprod2' " + "} " + "layers: { " + " name: 'loss2' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod2' " + " bottom: 'innerprod3' " + "} "; + const string& expected_output_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'data_data_0_split' " + " type: SPLIT " + " bottom: 'data' " + " top: 'data_data_0_split_0' " + " top: 'data_data_0_split_1' " + " top: 'data_data_0_split_2' " + "} " + "layers: { " + " name: 'innerprod1' " + " type: INNER_PRODUCT " + " bottom: 'data_data_0_split_0' " + " top: 'innerprod1' " + "} " + "layers: { " + " name: 'innerprod2' " + " type: INNER_PRODUCT " + " bottom: 'data_data_0_split_1' " + " top: 'innerprod2' " + "} " + "layers: { " + " name: 'innerprod2_innerprod2_0_split' " + " type: SPLIT " + " bottom: 'innerprod2' " + " top: 'innerprod2_innerprod2_0_split_0' " + " top: 'innerprod2_innerprod2_0_split_1' " + "} " + "layers: { " + " name: 'innerprod3' " + " type: INNER_PRODUCT " + " bottom: 'data_data_0_split_2' " + " top: 'innerprod3' " + "} " + "layers: { " + " name: 'loss1' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod1' " + " bottom: 'innerprod2_innerprod2_0_split_0' " + "} " + "layers: { " + " name: 'loss2' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod2_innerprod2_0_split_1' " + " bottom: 'innerprod3' " + "} "; + this->RunInsertionTest(input_proto, expected_output_proto); +} + +TEST_F(SplitLayerInsertionTest, TestInsertionTwoTop) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod1' " + " type: INNER_PRODUCT " + " 
bottom: 'data' " + " top: 'innerprod1' " + "} " + "layers: { " + " name: 'innerprod2' " + " type: INNER_PRODUCT " + " bottom: 'label' " + " top: 'innerprod2' " + "} " + "layers: { " + " name: 'innerprod3' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod3' " + "} " + "layers: { " + " name: 'innerprod4' " + " type: INNER_PRODUCT " + " bottom: 'label' " + " top: 'innerprod4' " + "} " + "layers: { " + " name: 'loss1' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod1' " + " bottom: 'innerprod3' " + "} " + "layers: { " + " name: 'loss2' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod2' " + " bottom: 'innerprod4' " + "} "; + const string& expected_output_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'data_data_0_split' " + " type: SPLIT " + " bottom: 'data' " + " top: 'data_data_0_split_0' " + " top: 'data_data_0_split_1' " + "} " + "layers: { " + " name: 'label_data_1_split' " + " type: SPLIT " + " bottom: 'label' " + " top: 'label_data_1_split_0' " + " top: 'label_data_1_split_1' " + "} " + "layers: { " + " name: 'innerprod1' " + " type: INNER_PRODUCT " + " bottom: 'data_data_0_split_0' " + " top: 'innerprod1' " + "} " + "layers: { " + " name: 'innerprod2' " + " type: INNER_PRODUCT " + " bottom: 'label_data_1_split_0' " + " top: 'innerprod2' " + "} " + "layers: { " + " name: 'innerprod3' " + " type: INNER_PRODUCT " + " bottom: 'data_data_0_split_1' " + " top: 'innerprod3' " + "} " + "layers: { " + " name: 'innerprod4' " + " type: INNER_PRODUCT " + " bottom: 'label_data_1_split_1' " + " top: 'innerprod4' " + "} " + "layers: { " + " name: 'loss1' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod1' " + " bottom: 'innerprod3' " + "} " + "layers: { " + " name: 'loss2' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod2' " + " bottom: 'innerprod4' " + "} "; + this->RunInsertionTest(input_proto, expected_output_proto); +} + 
+TEST_F(SplitLayerInsertionTest, TestInputInsertion) { + const string& input_proto = + "name: 'TestNetwork' " + "input: 'data' " + "input_dim: 10 " + "input_dim: 3 " + "input_dim: 227 " + "input_dim: 227 " + "layers: { " + " name: 'innerprod1' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod1' " + "} " + "layers: { " + " name: 'innerprod2' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod2' " + "} " + "layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod1' " + " bottom: 'innerprod2' " + "} "; + const string& expected_output_proto = + "name: 'TestNetwork' " + "input: 'data' " + "input_dim: 10 " + "input_dim: 3 " + "input_dim: 227 " + "input_dim: 227 " + "layers: { " + " name: 'data_input_0_split' " + " type: SPLIT " + " bottom: 'data' " + " top: 'data_input_0_split_0' " + " top: 'data_input_0_split_1' " + "} " + "layers: { " + " name: 'innerprod1' " + " type: INNER_PRODUCT " + " bottom: 'data_input_0_split_0' " + " top: 'innerprod1' " + "} " + "layers: { " + " name: 'innerprod2' " + " type: INNER_PRODUCT " + " bottom: 'data_input_0_split_1' " + " top: 'innerprod2' " + "} " + "layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod1' " + " bottom: 'innerprod2' " + "} "; + this->RunInsertionTest(input_proto, expected_output_proto); +} + +TEST_F(SplitLayerInsertionTest, TestWithInPlace) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'innerprod1' " + " type: INNER_PRODUCT " + " bottom: 'data' " + " top: 'innerprod1' " + "} " + "layers: { " + " name: 'relu1' " + " type: RELU " + " bottom: 'innerprod1' " + " top: 'innerprod1' " + "} " + "layers: { " + " name: 'innerprod2' " + " type: INNER_PRODUCT " + " bottom: 'innerprod1' " + " top: 'innerprod2' " + "} " + "layers: { " + " name: 'loss1' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod1' " + " 
bottom: 'label' " + "} " + "layers: { " + " name: 'loss2' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod2' " + " bottom: 'data' " + "} "; + const string& expected_output_proto = + "name: 'TestNetwork' " + "layers: { " + " name: 'data' " + " type: DATA " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " name: 'data_data_0_split' " + " type: SPLIT " + " bottom: 'data' " + " top: 'data_data_0_split_0' " + " top: 'data_data_0_split_1' " + "} " + "layers: { " + " name: 'innerprod1' " + " type: INNER_PRODUCT " + " bottom: 'data_data_0_split_0' " + " top: 'innerprod1' " + "} " + "layers: { " + " name: 'relu1' " + " type: RELU " + " bottom: 'innerprod1' " + " top: 'innerprod1' " + "} " + "layers: { " + " name: 'innerprod1_relu1_0_split' " + " type: SPLIT " + " bottom: 'innerprod1' " + " top: 'innerprod1_relu1_0_split_0' " + " top: 'innerprod1_relu1_0_split_1' " + "} " + "layers: { " + " name: 'innerprod2' " + " type: INNER_PRODUCT " + " bottom: 'innerprod1_relu1_0_split_0' " + " top: 'innerprod2' " + "} " + "layers: { " + " name: 'loss1' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod1_relu1_0_split_1' " + " bottom: 'label' " + "} " + "layers: { " + " name: 'loss2' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod2' " + " bottom: 'data_data_0_split_1' " + "} "; + this->RunInsertionTest(input_proto, expected_output_proto); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_stochastic_pooling.cpp b/caffe-crfrnn/src/caffe/test/test_stochastic_pooling.cpp new file mode 100644 index 00000000..ad515100 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_stochastic_pooling.cpp @@ -0,0 +1,163 @@ +#include +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +using std::min; + +namespace caffe { + +template +class 
StochasticPoolingLayerTest : public ::testing::Test { + protected: + StochasticPoolingLayerTest() + : blob_bottom_(new Blob()), + blob_top_(new Blob()) {} + virtual void SetUp() { + Caffe::set_random_seed(1701); + blob_bottom_->Reshape(2, 3, 6, 5); + // fill the values + FillerParameter filler_param; + filler_param.set_min(0.1); + filler_param.set_max(1.); + UniformFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + + virtual ~StochasticPoolingLayerTest() { + delete blob_bottom_; delete blob_top_; + } + + Blob* const blob_bottom_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +TYPED_TEST_CASE(StochasticPoolingLayerTest, TestDtypes); + +TYPED_TEST(StochasticPoolingLayerTest, TestSetup) { + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(2); + PoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num()); + EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels()); + EXPECT_EQ(this->blob_top_->height(), 3); + EXPECT_EQ(this->blob_top_->width(), 2); +} + +TYPED_TEST(StochasticPoolingLayerTest, TestStochasticGPU) { + Caffe::set_mode(Caffe::GPU); + Caffe::set_phase(Caffe::TRAIN); + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_STOCHASTIC); + PoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + // Check if the output is correct - it should do random sampling + const TypeParam* bottom_data = this->blob_bottom_->cpu_data(); + const 
TypeParam* top_data = this->blob_top_->cpu_data(); + TypeParam total = 0; + for (int n = 0; n < this->blob_top_->num(); ++n) { + for (int c = 0; c < this->blob_top_->channels(); ++c) { + for (int ph = 0; ph < this->blob_top_->height(); ++ph) { + for (int pw = 0; pw < this->blob_top_->width(); ++pw) { + TypeParam pooled = top_data[this->blob_top_->offset(n, c, ph, pw)]; + total += pooled; + int hstart = ph * 2; + int hend = min(hstart + 3, this->blob_bottom_->height()); + int wstart = pw * 2; + int wend = min(wstart + 3, this->blob_bottom_->width()); + bool has_equal = false; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + has_equal |= (pooled == bottom_data[this->blob_bottom_-> + offset(n, c, h, w)]); + } + } + EXPECT_TRUE(has_equal); + } + } + } + } + // When we are doing stochastic pooling, the average we get should be higher + // than the simple data average since we are weighting more on higher-valued + // ones. + EXPECT_GE(total / this->blob_top_->count(), 0.55); +} + +TYPED_TEST(StochasticPoolingLayerTest, TestStochasticGPUTestPhase) { + Caffe::set_mode(Caffe::GPU); + Caffe::set_phase(Caffe::TEST); + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_STOCHASTIC); + PoolingLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + + // Check if the output is correct - it should do random sampling + const TypeParam* bottom_data = this->blob_bottom_->cpu_data(); + const TypeParam* top_data = this->blob_top_->cpu_data(); + for (int n = 0; n < this->blob_top_->num(); ++n) { + for (int c = 0; c < this->blob_top_->channels(); ++c) { + for (int ph = 0; ph < this->blob_top_->height(); ++ph) { + for (int pw = 0; pw < this->blob_top_->width(); ++pw) { + TypeParam pooled = 
top_data[this->blob_top_->offset(n, c, ph, pw)]; + int hstart = ph * 2; + int hend = min(hstart + 3, this->blob_bottom_->height()); + int wstart = pw * 2; + int wend = min(wstart + 3, this->blob_bottom_->width()); + bool smaller_than_max = false; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + smaller_than_max |= (pooled <= bottom_data[this->blob_bottom_-> + offset(n, c, h, w)]); + } + } + EXPECT_TRUE(smaller_than_max); + } + } + } + } +} + +TYPED_TEST(StochasticPoolingLayerTest, TestGradientGPU) { + Caffe::set_mode(Caffe::GPU); + Caffe::set_phase(Caffe::TRAIN); + LayerParameter layer_param; + PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); + pooling_param->set_kernel_size(3); + pooling_param->set_stride(2); + pooling_param->set_pool(PoolingParameter_PoolMethod_STOCHASTIC); + PoolingLayer layer(layer_param); + GradientChecker checker(1e-4, 1e-2); + // it is too expensive to call curand multiple times, so we don't do an + // exhaustive gradient check. 
+ checker.CheckGradient(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + + + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_syncedmem.cpp b/caffe-crfrnn/src/caffe/test/test_syncedmem.cpp new file mode 100644 index 00000000..b946233d --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_syncedmem.cpp @@ -0,0 +1,126 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/common.hpp" +#include "caffe/syncedmem.hpp" +#include "caffe/util/device_alternate.hpp" +#include "caffe/util/math_functions.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +class SyncedMemoryTest : public ::testing::Test {}; + +TEST_F(SyncedMemoryTest, TestInitialization) { + SyncedMemory mem(10); + EXPECT_EQ(mem.head(), SyncedMemory::UNINITIALIZED); + EXPECT_EQ(mem.size(), 10); + SyncedMemory* p_mem = new SyncedMemory(10 * sizeof(float)); + EXPECT_EQ(p_mem->size(), 10 * sizeof(float)); + delete p_mem; +} + +#ifndef CPU_ONLY // GPU test + +TEST_F(SyncedMemoryTest, TestAllocationCPUGPU) { + SyncedMemory mem(10); + EXPECT_TRUE(mem.cpu_data()); + EXPECT_TRUE(mem.gpu_data()); + EXPECT_TRUE(mem.mutable_cpu_data()); + EXPECT_TRUE(mem.mutable_gpu_data()); +} + +#endif + +TEST_F(SyncedMemoryTest, TestAllocationCPU) { + SyncedMemory mem(10); + EXPECT_TRUE(mem.cpu_data()); + EXPECT_TRUE(mem.mutable_cpu_data()); +} + +#ifndef CPU_ONLY // GPU test + +TEST_F(SyncedMemoryTest, TestAllocationGPU) { + SyncedMemory mem(10); + EXPECT_TRUE(mem.gpu_data()); + EXPECT_TRUE(mem.mutable_gpu_data()); +} + +#endif + +TEST_F(SyncedMemoryTest, TestCPUWrite) { + SyncedMemory mem(10); + void* cpu_data = mem.mutable_cpu_data(); + EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_CPU); + caffe_memset(mem.size(), 1, cpu_data); + for (int i = 0; i < mem.size(); ++i) { + EXPECT_EQ((static_cast(cpu_data))[i], 1); + } + // do another round + cpu_data = mem.mutable_cpu_data(); + EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_CPU); + caffe_memset(mem.size(), 2, cpu_data); + 
for (int i = 0; i < mem.size(); ++i) { + EXPECT_EQ((static_cast(cpu_data))[i], 2); + } +} + +#ifndef CPU_ONLY // GPU test + +TEST_F(SyncedMemoryTest, TestGPURead) { + SyncedMemory mem(10); + void* cpu_data = mem.mutable_cpu_data(); + EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_CPU); + caffe_memset(mem.size(), 1, cpu_data); + const void* gpu_data = mem.gpu_data(); + EXPECT_EQ(mem.head(), SyncedMemory::SYNCED); + // check if values are the same + char* recovered_value = new char[10]; + caffe_gpu_memcpy(10, gpu_data, recovered_value); + for (int i = 0; i < mem.size(); ++i) { + EXPECT_EQ((static_cast(recovered_value))[i], 1); + } + // do another round + cpu_data = mem.mutable_cpu_data(); + EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_CPU); + caffe_memset(mem.size(), 2, cpu_data); + for (int i = 0; i < mem.size(); ++i) { + EXPECT_EQ((static_cast(cpu_data))[i], 2); + } + gpu_data = mem.gpu_data(); + EXPECT_EQ(mem.head(), SyncedMemory::SYNCED); + // check if values are the same + caffe_gpu_memcpy(10, gpu_data, recovered_value); + for (int i = 0; i < mem.size(); ++i) { + EXPECT_EQ((static_cast(recovered_value))[i], 2); + } + delete[] recovered_value; +} + +TEST_F(SyncedMemoryTest, TestGPUWrite) { + SyncedMemory mem(10); + void* gpu_data = mem.mutable_gpu_data(); + EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_GPU); + caffe_gpu_memset(mem.size(), 1, gpu_data); + const void* cpu_data = mem.cpu_data(); + for (int i = 0; i < mem.size(); ++i) { + EXPECT_EQ((static_cast(cpu_data))[i], 1); + } + EXPECT_EQ(mem.head(), SyncedMemory::SYNCED); + + gpu_data = mem.mutable_gpu_data(); + EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_GPU); + caffe_gpu_memset(mem.size(), 2, gpu_data); + cpu_data = mem.cpu_data(); + for (int i = 0; i < mem.size(); ++i) { + EXPECT_EQ((static_cast(cpu_data))[i], 2); + } + EXPECT_EQ(mem.head(), SyncedMemory::SYNCED); +} + +#endif + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_tanh_layer.cpp b/caffe-crfrnn/src/caffe/test/test_tanh_layer.cpp 
new file mode 100644 index 00000000..5dc92832 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_tanh_layer.cpp @@ -0,0 +1,101 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/common_layers.hpp" +#include "caffe/filler.hpp" + +#include "caffe/test/test_caffe_main.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +namespace caffe { + +double tanh_naive(double x) { + if (x < -40) { + // avoid negative overflow + return -1; + } else if (x > 40) { + // avoid positive overflow + return 1; + } else { + // exact expression for tanh, which is unstable for large x + double exp2x = exp(2 * x); + return (exp2x - 1.0) / (exp2x + 1.0); + } +} + +template +class TanHLayerTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + TanHLayerTest() + : blob_bottom_(new Blob(2, 3, 4, 5)), + blob_top_(new Blob()) { + Caffe::set_random_seed(1701); + FillerParameter filler_param; + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~TanHLayerTest() { delete blob_bottom_; delete blob_top_; } + + void TestForward(Dtype filler_std) { + FillerParameter filler_param; + filler_param.set_std(filler_std); + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + + LayerParameter layer_param; + TanHLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + // Now, check values + const Dtype* bottom_data = this->blob_bottom_->cpu_data(); + const Dtype* top_data = this->blob_top_->cpu_data(); + const Dtype min_precision = 1e-5; + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + Dtype expected_value = tanh_naive(bottom_data[i]); + Dtype precision = std::max( + Dtype(std::abs(expected_value * Dtype(1e-4))), min_precision); + EXPECT_NEAR(expected_value, top_data[i], precision); + } + } + + void TestBackward(Dtype 
filler_std) {
+    FillerParameter filler_param;
+    filler_param.set_std(filler_std);
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+
+    LayerParameter layer_param;
+    TanHLayer<Dtype> layer(layer_param);
+    GradientChecker<Dtype> checker(1e-2, 1e-2, 1701);
+    checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+        this->blob_top_vec_);
+  }
+
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(TanHLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(TanHLayerTest, TestTanH) {
+  this->TestForward(1.0);
+}
+
+TYPED_TEST(TanHLayerTest, TestTanHOverflow) {
+  // this will fail if tanh overflow is not properly handled
+  this->TestForward(10000.0);
+}
+
+TYPED_TEST(TanHLayerTest, TestTanHGradient) {
+  this->TestBackward(1.0);
+}
+
+}  // namespace caffe
diff --git a/caffe-crfrnn/src/caffe/test/test_threshold_layer.cpp b/caffe-crfrnn/src/caffe/test/test_threshold_layer.cpp
new file mode 100644
index 00000000..05ce8212
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/test/test_threshold_layer.cpp
@@ -0,0 +1,98 @@
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+template <typename TypeParam>
+class ThresholdLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+ protected:
+  ThresholdLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 6, 5)),
+        blob_top_(new Blob<Dtype>()) {
+    Caffe::set_random_seed(1701);
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~ThresholdLayerTest() { delete blob_bottom_; delete blob_top_; }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(ThresholdLayerTest, TestDtypesAndDevices);
+
+
+TYPED_TEST(ThresholdLayerTest, TestSetup) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ThresholdLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
+  EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
+  EXPECT_EQ(this->blob_top_->height(), this->blob_bottom_->height());
+  EXPECT_EQ(this->blob_top_->width(), this->blob_bottom_->width());
+}
+
+TYPED_TEST(ThresholdLayerTest, Test) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ThresholdLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  const Dtype threshold_ = layer_param.threshold_param().threshold();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_GE(top_data[i], 0.);
+    EXPECT_LE(top_data[i], 1.);
+    if (top_data[i] == 0) {
+      EXPECT_LE(bottom_data[i], threshold_);
+    }
+    if (top_data[i] == 1) {
+      EXPECT_GT(bottom_data[i], threshold_);
+    }
+  }
+}
+
+TYPED_TEST(ThresholdLayerTest, Test2) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ThresholdParameter* threshold_param =
+    layer_param.mutable_threshold_param();
+  threshold_param->set_threshold(0.5);
+  ThresholdLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  const Dtype threshold_ = layer_param.threshold_param().threshold();
+  EXPECT_FLOAT_EQ(threshold_, 0.5);
+  for (int i = 0; i < this->blob_bottom_->count();
++i) {
+    EXPECT_GE(top_data[i], 0.);
+    EXPECT_LE(top_data[i], 1.);
+    if (top_data[i] == 0) {
+      EXPECT_LE(bottom_data[i], threshold_);
+    }
+    if (top_data[i] == 1) {
+      EXPECT_GT(bottom_data[i], threshold_);
+    }
+  }
+}
+
+}  // namespace caffe
diff --git a/caffe-crfrnn/src/caffe/test/test_upgrade_proto.cpp b/caffe-crfrnn/src/caffe/test/test_upgrade_proto.cpp
new file mode 100644
index 00000000..52e7f1f9
--- /dev/null
+++ b/caffe-crfrnn/src/caffe/test/test_upgrade_proto.cpp
@@ -0,0 +1,2443 @@
+#include <cstring>
+#include <string>
+#include <vector>
+
+#include "google/protobuf/text_format.h"
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/util/upgrade_proto.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+class PaddingLayerUpgradeTest : public ::testing::Test {
+ protected:
+  void RunPaddingUpgradeTest(
+      const string& input_param_string, const string& output_param_string) {
+    // Test that UpgradeV0PaddingLayers called on the proto specified by
+    // input_param_string results in the proto specified by
+    // output_param_string.
+    NetParameter input_param;
+    CHECK(google::protobuf::TextFormat::ParseFromString(
+        input_param_string, &input_param));
+    NetParameter expected_output_param;
+    CHECK(google::protobuf::TextFormat::ParseFromString(
+        output_param_string, &expected_output_param));
+    NetParameter actual_output_param;
+    UpgradeV0PaddingLayers(input_param, &actual_output_param);
+    EXPECT_EQ(expected_output_param.DebugString(),
+        actual_output_param.DebugString());
+    // Also test idempotence.
+ NetParameter double_pad_upgrade_param; + UpgradeV0PaddingLayers(actual_output_param, &double_pad_upgrade_param); + EXPECT_EQ(actual_output_param.DebugString(), + double_pad_upgrade_param.DebugString()); + } +}; + +TEST_F(PaddingLayerUpgradeTest, TestSimple) { + const string& input_proto = + "name: 'CaffeNet' " + "layers { " + " layer { " + " name: 'data' " + " type: 'data' " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " meanfile: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " batchsize: 256 " + " cropsize: 227 " + " mirror: true " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " layer { " + " name: 'pad1' " + " type: 'padding' " + " pad: 2 " + " } " + " bottom: 'data' " + " top: 'pad1' " + "} " + "layers { " + " layer { " + " name: 'conv1' " + " type: 'conv' " + " num_output: 96 " + " kernelsize: 11 " + " stride: 4 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'pad1' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'fc8' " + " type: 'innerproduct' " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'conv1' " + " top: 'fc8' " + "} " + "layers { " + " layer { " + " name: 'loss' " + " type: 'softmax_loss' " + " } " + " bottom: 'fc8' " + " bottom: 'label' " + "} "; + const string& expected_output_proto = + "name: 'CaffeNet' " + "layers { " + " layer { " + " name: 'data' " + " type: 'data' " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " meanfile: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " batchsize: 256 " + " cropsize: 227 " + " mirror: true " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " layer { " + " name: 'conv1' " + " type: 'conv' " + " num_output: 96 " + " kernelsize: 11 " + " stride: 4 " + " pad: 2 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'data' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'fc8' " + " type: 'innerproduct' " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'conv1' " + " top: 'fc8' " + "} " + "layers { " + " layer { " + " name: 'loss' " + " type: 'softmax_loss' " + " } " + " bottom: 'fc8' " + " bottom: 'label' " + "} "; + this->RunPaddingUpgradeTest(input_proto, expected_output_proto); +} + +TEST_F(PaddingLayerUpgradeTest, TestTwoTops) { + const string& input_proto = + "name: 'CaffeNet' " + "layers { " + " layer { " + " name: 'data' " + " type: 'data' " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " meanfile: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " batchsize: 256 " + " cropsize: 227 " + " mirror: true " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " layer { " + " name: 'pad1' " + " type: 'padding' " + " pad: 2 " + " } " + " bottom: 'data' " + " top: 'pad1' " + "} " + "layers { " + " layer { " + " name: 'conv1' " + " type: 'conv' " + " num_output: 96 " + " kernelsize: 11 " + " stride: 4 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'pad1' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'fc8' " + " type: 'innerproduct' " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'conv1' " + " top: 'fc8' " + "} " + "layers { " + " layer { " + " name: 'conv2' " + " type: 'conv' " + " num_output: 96 " + " kernelsize: 11 " + " stride: 4 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'pad1' " + " top: 'conv2' " + "} " + "layers { " + " layer { " + " name: 'loss' " + " type: 'softmax_loss' " + " } " + " bottom: 'fc8' " + " bottom: 'label' " + "} "; + const string& expected_output_proto = + "name: 'CaffeNet' " + "layers { " + " layer { " + " name: 'data' " + " type: 'data' " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " meanfile: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " batchsize: 256 " + " cropsize: 227 " + " mirror: true " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " layer { " + " name: 'conv1' " + " type: 'conv' " + " num_output: 96 " + " kernelsize: 11 " + " stride: 4 " + " pad: 2 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'data' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'fc8' " + " type: 'innerproduct' " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'conv1' " + " top: 'fc8' " + "} " + "layers { " + " layer { " + " name: 'conv2' " + " type: 'conv' " + " num_output: 96 " + " kernelsize: 11 " + " stride: 4 " + " pad: 2 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'data' " + " top: 'conv2' " + "} " + "layers { " + " layer { " + " name: 'loss' " + " type: 'softmax_loss' " + " } " + " bottom: 'fc8' " + " bottom: 'label' " + "} "; + this->RunPaddingUpgradeTest(input_proto, expected_output_proto); +} + +TEST_F(PaddingLayerUpgradeTest, TestImageNet) { + const string& input_proto = + "name: 'CaffeNet' " + "layers { " + " layer { " + " name: 'data' " + " type: 'data' " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " meanfile: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " batchsize: 256 " + " cropsize: 227 " + " mirror: true " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " layer { " + " name: 'conv1' " + " type: 'conv' " + " num_output: 96 " + " kernelsize: 11 " + " stride: 4 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'data' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'relu1' " + " type: 'relu' " + " } " + " bottom: 'conv1' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'pool1' " + " type: 'pool' " + " pool: MAX " + " kernelsize: 3 " + " stride: 2 " + " } " + " bottom: 'conv1' " + " top: 'pool1' " + "} " + "layers { " + " layer { " + " name: 'norm1' " + " type: 'lrn' " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool1' " + " top: 'norm1' " + "} " + "layers { " + " layer { " + " name: 'pad2' " + " type: 'padding' " + " pad: 2 " + " } " + " bottom: 'norm1' " + " top: 'pad2' " + "} " + "layers { " + " layer { " + " name: 'conv2' " + " type: 'conv' " + " num_output: 256 " + " group: 2 " + " kernelsize: 5 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. 
" + " weight_decay: 0. " + " } " + " bottom: 'pad2' " + " top: 'conv2' " + "} " + "layers { " + " layer { " + " name: 'relu2' " + " type: 'relu' " + " } " + " bottom: 'conv2' " + " top: 'conv2' " + "} " + "layers { " + " layer { " + " name: 'pool2' " + " type: 'pool' " + " pool: MAX " + " kernelsize: 3 " + " stride: 2 " + " } " + " bottom: 'conv2' " + " top: 'pool2' " + "} " + "layers { " + " layer { " + " name: 'norm2' " + " type: 'lrn' " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool2' " + " top: 'norm2' " + "} " + "layers { " + " layer { " + " name: 'pad3' " + " type: 'padding' " + " pad: 1 " + " } " + " bottom: 'norm2' " + " top: 'pad3' " + "} " + "layers { " + " layer { " + " name: 'conv3' " + " type: 'conv' " + " num_output: 384 " + " kernelsize: 3 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'pad3' " + " top: 'conv3' " + "} " + "layers { " + " layer { " + " name: 'relu3' " + " type: 'relu' " + " } " + " bottom: 'conv3' " + " top: 'conv3' " + "} " + "layers { " + " layer { " + " name: 'pad4' " + " type: 'padding' " + " pad: 1 " + " } " + " bottom: 'conv3' " + " top: 'pad4' " + "} " + "layers { " + " layer { " + " name: 'conv4' " + " type: 'conv' " + " num_output: 384 " + " group: 2 " + " kernelsize: 3 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'pad4' " + " top: 'conv4' " + "} " + "layers { " + " layer { " + " name: 'relu4' " + " type: 'relu' " + " } " + " bottom: 'conv4' " + " top: 'conv4' " + "} " + "layers { " + " layer { " + " name: 'pad5' " + " type: 'padding' " + " pad: 1 " + " } " + " bottom: 'conv4' " + " top: 'pad5' " + "} " + "layers { " + " layer { " + " name: 'conv5' " + " type: 'conv' " + " num_output: 256 " + " group: 2 " + " kernelsize: 3 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'pad5' " + " top: 'conv5' " + "} " + "layers { " + " layer { " + " name: 'relu5' " + " type: 'relu' " + " } " + " bottom: 'conv5' " + " top: 'conv5' " + "} " + "layers { " + " layer { " + " name: 'pool5' " + " type: 'pool' " + " kernelsize: 3 " + " pool: MAX " + " stride: 2 " + " } " + " bottom: 'conv5' " + " top: 'pool5' " + "} " + "layers { " + " layer { " + " name: 'fc6' " + " type: 'innerproduct' " + " num_output: 4096 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'pool5' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'relu6' " + " type: 'relu' " + " } " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'drop6' " + " type: 'dropout' " + " dropout_ratio: 0.5 " + " } " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'fc7' " + " type: 'innerproduct' " + " num_output: 4096 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'fc6' " + " top: 'fc7' " + "} " + "layers { " + " layer { " + " name: 'relu7' " + " type: 'relu' " + " } " + " bottom: 'fc7' " + " top: 'fc7' " + "} " + "layers { " + " layer { " + " name: 'drop7' " + " type: 'dropout' " + " dropout_ratio: 0.5 " + " } " + " bottom: 'fc7' " + " top: 'fc7' " + "} " + "layers { " + " layer { " + " name: 'fc8' " + " type: 'innerproduct' " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'fc7' " + " top: 'fc8' " + "} " + "layers { " + " layer { " + " name: 'loss' " + " type: 'softmax_loss' " + " } " + " bottom: 'fc8' " + " bottom: 'label' " + "} "; + const string& expected_output_proto = + "name: 'CaffeNet' " + "layers { " + " layer { " + " name: 'data' " + " type: 'data' " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " meanfile: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " batchsize: 256 " + " cropsize: 227 " + " mirror: true " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " layer { " + " name: 'conv1' " + " type: 'conv' " + " num_output: 96 " + " kernelsize: 11 " + " stride: 4 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'data' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'relu1' " + " type: 'relu' " + " } " + " bottom: 'conv1' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'pool1' " + " type: 'pool' " + " pool: MAX " + " kernelsize: 3 " + " stride: 2 " + " } " + " bottom: 'conv1' " + " top: 'pool1' " + "} " + "layers { " + " layer { " + " name: 'norm1' " + " type: 'lrn' " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool1' " + " top: 'norm1' " + "} " + "layers { " + " layer { " + " name: 'conv2' " + " type: 'conv' " + " num_output: 256 " + " group: 2 " + " kernelsize: 5 " + " pad: 2 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'norm1' " + " top: 'conv2' " + "} " + "layers { " + " layer { " + " name: 'relu2' " + " type: 'relu' " + " } " + " bottom: 'conv2' " + " top: 'conv2' " + "} " + "layers { " + " layer { " + " name: 'pool2' " + " type: 'pool' " + " pool: MAX " + " kernelsize: 3 " + " stride: 2 " + " } " + " bottom: 'conv2' " + " top: 'pool2' " + "} " + "layers { " + " layer { " + " name: 'norm2' " + " type: 'lrn' " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool2' " + " top: 'norm2' " + "} " + "layers { " + " layer { " + " name: 'conv3' " + " type: 'conv' " + " num_output: 384 " + " kernelsize: 3 " + " pad: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'norm2' " + " top: 'conv3' " + "} " + "layers { " + " layer { " + " name: 'relu3' " + " type: 'relu' " + " } " + " bottom: 'conv3' " + " top: 'conv3' " + "} " + "layers { " + " layer { " + " name: 'conv4' " + " type: 'conv' " + " num_output: 384 " + " group: 2 " + " kernelsize: 3 " + " pad: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'conv3' " + " top: 'conv4' " + "} " + "layers { " + " layer { " + " name: 'relu4' " + " type: 'relu' " + " } " + " bottom: 'conv4' " + " top: 'conv4' " + "} " + "layers { " + " layer { " + " name: 'conv5' " + " type: 'conv' " + " num_output: 256 " + " group: 2 " + " kernelsize: 3 " + " pad: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'conv4' " + " top: 'conv5' " + "} " + "layers { " + " layer { " + " name: 'relu5' " + " type: 'relu' " + " } " + " bottom: 'conv5' " + " top: 'conv5' " + "} " + "layers { " + " layer { " + " name: 'pool5' " + " type: 'pool' " + " kernelsize: 3 " + " pool: MAX " + " stride: 2 " + " } " + " bottom: 'conv5' " + " top: 'pool5' " + "} " + "layers { " + " layer { " + " name: 'fc6' " + " type: 'innerproduct' " + " num_output: 4096 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'pool5' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'relu6' " + " type: 'relu' " + " } " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'drop6' " + " type: 'dropout' " + " dropout_ratio: 0.5 " + " } " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'fc7' " + " type: 'innerproduct' " + " num_output: 4096 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'fc6' " + " top: 'fc7' " + "} " + "layers { " + " layer { " + " name: 'relu7' " + " type: 'relu' " + " } " + " bottom: 'fc7' " + " top: 'fc7' " + "} " + "layers { " + " layer { " + " name: 'drop7' " + " type: 'dropout' " + " dropout_ratio: 0.5 " + " } " + " bottom: 'fc7' " + " top: 'fc7' " + "} " + "layers { " + " layer { " + " name: 'fc8' " + " type: 'innerproduct' " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'fc7' " + " top: 'fc8' " + "} " + "layers { " + " layer { " + " name: 'loss' " + " type: 'softmax_loss' " + " } " + " bottom: 'fc8' " + " bottom: 'label' " + "} "; + this->RunPaddingUpgradeTest(input_proto, expected_output_proto); +} + +class V0UpgradeTest : public ::testing::Test { + protected: + void RunV0UpgradeTest( + const string& input_param_string, const string& output_param_string) { + // Test that UpgradeV0Net called on the NetParameter proto specified by + // input_param_string results in the NetParameter proto specified by + // output_param_string. 
+ NetParameter input_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + input_param_string, &input_param)); + NetParameter expected_output_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + output_param_string, &expected_output_param)); + NetParameter actual_output_param; + UpgradeV0Net(input_param, &actual_output_param); + EXPECT_EQ(expected_output_param.DebugString(), + actual_output_param.DebugString()); + } +}; + +TEST_F(V0UpgradeTest, TestSimple) { + const string& input_proto = + "name: 'CaffeNet' " + "layers { " + " layer { " + " name: 'data' " + " type: 'data' " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " meanfile: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " batchsize: 256 " + " cropsize: 227 " + " mirror: true " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " layer { " + " name: 'pad1' " + " type: 'padding' " + " pad: 2 " + " } " + " bottom: 'data' " + " top: 'pad1' " + "} " + "layers { " + " layer { " + " name: 'conv1' " + " type: 'conv' " + " num_output: 96 " + " kernelsize: 11 " + " stride: 4 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'pad1' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'fc8' " + " type: 'innerproduct' " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'conv1' " + " top: 'fc8' " + "} " + "layers { " + " layer { " + " name: 'loss' " + " type: 'softmax_loss' " + " } " + " bottom: 'fc8' " + " bottom: 'label' " + "} "; + const string& expected_output_proto = + "name: 'CaffeNet' " + "layers { " + " name: 'data' " + " type: DATA " + " data_param { " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " batch_size: 256 " + " } " + " transform_param { " + " crop_size: 227 " + " mirror: true " + " mean_file: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " name: 'conv1' " + " type: CONVOLUTION " + " convolution_param { " + " num_output: 96 " + " kernel_size: 11 " + " stride: 4 " + " pad: 2 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'data' " + " top: 'conv1' " + "} " + "layers { " + " name: 'fc8' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'conv1' " + " top: 'fc8' " + "} " + "layers { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'fc8' " + " bottom: 'label' " + "} "; + this->RunV0UpgradeTest(input_proto, expected_output_proto); +} + +// Test any layer or parameter upgrades not covered by other tests. 
+TEST_F(V0UpgradeTest, TestAllParams) { + const string& input_proto = + "name: 'CaffeNet' " + "input: 'input_data' " + "input_dim: 64 " + "input_dim: 3 " + "input_dim: 32 " + "input_dim: 32 " + "layers { " + " layer { " + " name: 'data' " + " type: 'data' " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " meanfile: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " batchsize: 256 " + " cropsize: 227 " + " mirror: true " + " scale: 0.25 " + " rand_skip: 73 " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " layer { " + " name: 'images' " + " type: 'images' " + " source: '/home/jiayq/Data/ILSVRC12/train-images' " + " meanfile: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " batchsize: 256 " + " cropsize: 227 " + " mirror: true " + " scale: 0.25 " + " rand_skip: 73 " + " shuffle_images: true " + " new_height: 40 " + " new_width: 30 " + " } " + " top: 'images_data' " + " top: 'images_label' " + "} " + "layers { " + " layer { " + " name: 'window_data' " + " type: 'window_data' " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " meanfile: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " batchsize: 256 " + " cropsize: 227 " + " mirror: true " + " det_fg_threshold: 0.25 " + " det_bg_threshold: 0.75 " + " det_fg_fraction: 0.5 " + " det_context_pad: 16 " + " det_crop_mode: 'square' " + " } " + " top: 'window_data' " + " top: 'window_label' " + "} " + "layers { " + " layer { " + " name: 'hdf5data' " + " type: 'hdf5_data' " + " source: '/my/hdf5/data' " + " batchsize: 256 " + " } " + " top: 'hdf5data' " + "} " + "layers { " + " layer { " + " name: 'conv1' " + " type: 'conv' " + " num_output: 96 " + " biasterm: false " + " pad: 4 " + " kernelsize: 11 " + " stride: 4 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 3. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'data' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'pool1ave' " + " type: 'pool' " + " pool: AVE " + " kernelsize: 3 " + " stride: 2 " + " } " + " bottom: 'conv1' " + " top: 'pool1ave' " + "} " + "layers { " + " layer { " + " name: 'pool1stoch' " + " type: 'pool' " + " pool: STOCHASTIC " + " kernelsize: 4 " + " stride: 5 " + " } " + " bottom: 'conv1' " + " top: 'pool1stoch' " + "} " + "layers { " + " layer { " + " name: 'concat' " + " type: 'concat' " + " concat_dim: 2 " + " } " + " bottom: 'pool1ave' " + " bottom: 'pool1stoch' " + " top: 'pool1concat' " + "} " + "layers { " + " layer { " + " name: 'norm1' " + " type: 'lrn' " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool1concat' " + " top: 'norm1' " + "} " + "layers { " + " layer { " + " name: 'fc6' " + " type: 'innerproduct' " + " num_output: 4096 " + " biasterm: false " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'norm1' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'relu6' " + " type: 'relu' " + " } " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'drop6' " + " type: 'dropout' " + " dropout_ratio: 0.2 " + " } " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'loss' " + " type: 'infogain_loss' " + " source: '/my/infogain/matrix' " + " } " + " bottom: 'fc6' " + " bottom: 'label' " + "} " + "layers { " + " layer { " + " name: 'accuracy' " + " type: 'accuracy' " + " } " + "} " + "layers { " + " layer { " + " name: 'bnll' " + " type: 'bnll' " + " } " + "} " + "layers { " + " layer { " + " name: 'euclidean_loss' " + " type: 'euclidean_loss' " + " } " + "} " + "layers { " + " layer { " + " name: 'flatten' " + " type: 'flatten' " + " } " + "} " + "layers { " + " layer { " + " name: 'hdf5_output' " + " type: 'hdf5_output' " + " hdf5_output_param { " + " file_name: '/my/hdf5/output/file' " + " } " + " } " + "} " + "layers { " + " layer { " + " name: 'im2col' " + " type: 'im2col' " + " } " + "} " + "layers { " + " layer { " + " name: 'images' " + " type: 'images' " + " } " + "} " + "layers { " + " layer { " + " name: 'multinomial_logistic_loss' " + " type: 'multinomial_logistic_loss' " + " } " + "} " + "layers { " + " layer { " + " name: 'sigmoid' " + " type: 'sigmoid' " + " } " + "} " + "layers { " + " layer { " + " name: 'softmax' " + " type: 'softmax' " + " } " + "} " + "layers { " + " layer { " + " name: 'split' " + " type: 'split' " + " } " + "} " + "layers { " + " layer { " + " name: 'tanh' " + " type: 'tanh' " + " } " + "} "; + const string& expected_output_proto = + "name: 'CaffeNet' " + "input: 'input_data' " + "input_dim: 64 " + "input_dim: 3 " + "input_dim: 32 " + "input_dim: 32 " + "layers { " + " name: 'data' " + " type: DATA " + " data_param { " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " batch_size: 256 " + " rand_skip: 73 " + " } " + " 
transform_param { " + " crop_size: 227 " + " mirror: true " + " scale: 0.25 " + " mean_file: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " name: 'images' " + " type: IMAGE_DATA " + " image_data_param { " + " source: '/home/jiayq/Data/ILSVRC12/train-images' " + " batch_size: 256 " + " rand_skip: 73 " + " shuffle: true " + " new_height: 40 " + " new_width: 30 " + " } " + " transform_param {" + " mean_file: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " crop_size: 227 " + " mirror: true " + " scale: 0.25 " + " } " + " top: 'images_data' " + " top: 'images_label' " + "} " + "layers { " + " name: 'window_data' " + " type: WINDOW_DATA " + " window_data_param { " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " batch_size: 256 " + " fg_threshold: 0.25 " + " bg_threshold: 0.75 " + " fg_fraction: 0.5 " + " context_pad: 16 " + " crop_mode: 'square' " + " } " + " transform_param { " + " mirror: true " + " crop_size: 227 " + " mean_file: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " }" + " top: 'window_data' " + " top: 'window_label' " + "} " + "layers { " + " name: 'hdf5data' " + " type: HDF5_DATA " + " hdf5_data_param { " + " source: '/my/hdf5/data' " + " batch_size: 256 " + " } " + " top: 'hdf5data' " + "} " + "layers { " + " name: 'conv1' " + " type: CONVOLUTION " + " convolution_param { " + " num_output: 96 " + " bias_term: false " + " pad: 4 " + " kernel_size: 11 " + " stride: 4 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 3. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " bottom: 'data' " + " top: 'conv1' " + "} " + "layers { " + " name: 'pool1ave' " + " type: POOLING " + " pooling_param { " + " pool: AVE " + " kernel_size: 3 " + " stride: 2 " + " } " + " bottom: 'conv1' " + " top: 'pool1ave' " + "} " + "layers { " + " name: 'pool1stoch' " + " type: POOLING " + " pooling_param { " + " pool: STOCHASTIC " + " kernel_size: 4 " + " stride: 5 " + " } " + " bottom: 'conv1' " + " top: 'pool1stoch' " + "} " + "layers { " + " name: 'concat' " + " type: CONCAT " + " concat_param { " + " concat_dim: 2 " + " } " + " bottom: 'pool1ave' " + " bottom: 'pool1stoch' " + " top: 'pool1concat' " + "} " + "layers { " + " name: 'norm1' " + " type: LRN " + " lrn_param { " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool1concat' " + " top: 'norm1' " + "} " + "layers { " + " name: 'fc6' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 4096 " + " bias_term: false " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " bottom: 'norm1' " + " top: 'fc6' " + "} " + "layers { " + " name: 'relu6' " + " type: RELU " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " name: 'drop6' " + " type: DROPOUT " + " dropout_param { " + " dropout_ratio: 0.2 " + " } " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " name: 'loss' " + " type: INFOGAIN_LOSS " + " infogain_loss_param { " + " source: '/my/infogain/matrix' " + " } " + " bottom: 'fc6' " + " bottom: 'label' " + "} " + "layers { " + " name: 'accuracy' " + " type: ACCURACY " + "} " + "layers { " + " name: 'bnll' " + " type: BNLL " + "} " + "layers { " + " name: 'euclidean_loss' " + " type: EUCLIDEAN_LOSS " + "} " + "layers { " + " name: 'flatten' " + " type: FLATTEN " + "} " + "layers { " + " name: 'hdf5_output' " + " type: HDF5_OUTPUT " + " hdf5_output_param { " + " file_name: '/my/hdf5/output/file' " + " } " + "} " + "layers { " + " name: 'im2col' " + " type: IM2COL " + "} " + "layers { " + " name: 'images' " + " type: IMAGE_DATA " + "} " + "layers { " + " name: 'multinomial_logistic_loss' " + " type: MULTINOMIAL_LOGISTIC_LOSS " + "} " + "layers { " + " name: 'sigmoid' " + " type: SIGMOID " + "} " + "layers { " + " name: 'softmax' " + " type: SOFTMAX " + "} " + "layers { " + " name: 'split' " + " type: SPLIT " + "} " + "layers { " + " name: 'tanh' " + " type: TANH " + "} "; + this->RunV0UpgradeTest(input_proto, expected_output_proto); +} + +TEST_F(V0UpgradeTest, TestImageNet) { + const string& input_proto = + "name: 'CaffeNet' " + "layers { " + " layer { " + " name: 'data' " + " type: 'data' " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " meanfile: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " batchsize: 256 " + " cropsize: 227 " + " mirror: true " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " layer { " + " name: 'conv1' " + " type: 'conv' " + " num_output: 96 " + " kernelsize: 11 " + " stride: 4 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " 
} " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'data' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'relu1' " + " type: 'relu' " + " } " + " bottom: 'conv1' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'pool1' " + " type: 'pool' " + " pool: MAX " + " kernelsize: 3 " + " stride: 2 " + " } " + " bottom: 'conv1' " + " top: 'pool1' " + "} " + "layers { " + " layer { " + " name: 'norm1' " + " type: 'lrn' " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool1' " + " top: 'norm1' " + "} " + "layers { " + " layer { " + " name: 'pad2' " + " type: 'padding' " + " pad: 2 " + " } " + " bottom: 'norm1' " + " top: 'pad2' " + "} " + "layers { " + " layer { " + " name: 'conv2' " + " type: 'conv' " + " num_output: 256 " + " group: 2 " + " kernelsize: 5 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'pad2' " + " top: 'conv2' " + "} " + "layers { " + " layer { " + " name: 'relu2' " + " type: 'relu' " + " } " + " bottom: 'conv2' " + " top: 'conv2' " + "} " + "layers { " + " layer { " + " name: 'pool2' " + " type: 'pool' " + " pool: MAX " + " kernelsize: 3 " + " stride: 2 " + " } " + " bottom: 'conv2' " + " top: 'pool2' " + "} " + "layers { " + " layer { " + " name: 'norm2' " + " type: 'lrn' " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool2' " + " top: 'norm2' " + "} " + "layers { " + " layer { " + " name: 'pad3' " + " type: 'padding' " + " pad: 1 " + " } " + " bottom: 'norm2' " + " top: 'pad3' " + "} " + "layers { " + " layer { " + " name: 'conv3' " + " type: 'conv' " + " num_output: 384 " + " kernelsize: 3 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'pad3' " + " top: 'conv3' " + "} " + "layers { " + " layer { " + " name: 'relu3' " + " type: 'relu' " + " } " + " bottom: 'conv3' " + " top: 'conv3' " + "} " + "layers { " + " layer { " + " name: 'pad4' " + " type: 'padding' " + " pad: 1 " + " } " + " bottom: 'conv3' " + " top: 'pad4' " + "} " + "layers { " + " layer { " + " name: 'conv4' " + " type: 'conv' " + " num_output: 384 " + " group: 2 " + " kernelsize: 3 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'pad4' " + " top: 'conv4' " + "} " + "layers { " + " layer { " + " name: 'relu4' " + " type: 'relu' " + " } " + " bottom: 'conv4' " + " top: 'conv4' " + "} " + "layers { " + " layer { " + " name: 'pad5' " + " type: 'padding' " + " pad: 1 " + " } " + " bottom: 'conv4' " + " top: 'pad5' " + "} " + "layers { " + " layer { " + " name: 'conv5' " + " type: 'conv' " + " num_output: 256 " + " group: 2 " + " kernelsize: 3 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'pad5' " + " top: 'conv5' " + "} " + "layers { " + " layer { " + " name: 'relu5' " + " type: 'relu' " + " } " + " bottom: 'conv5' " + " top: 'conv5' " + "} " + "layers { " + " layer { " + " name: 'pool5' " + " type: 'pool' " + " kernelsize: 3 " + " pool: MAX " + " stride: 2 " + " } " + " bottom: 'conv5' " + " top: 'pool5' " + "} " + "layers { " + " layer { " + " name: 'fc6' " + " type: 'innerproduct' " + " num_output: 4096 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'pool5' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'relu6' " + " type: 'relu' " + " } " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'drop6' " + " type: 'dropout' " + " dropout_ratio: 0.5 " + " } " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'fc7' " + " type: 'innerproduct' " + " num_output: 4096 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " } " + " bottom: 'fc6' " + " top: 'fc7' " + "} " + "layers { " + " layer { " + " name: 'relu7' " + " type: 'relu' " + " } " + " bottom: 'fc7' " + " top: 'fc7' " + "} " + "layers { " + " layer { " + " name: 'drop7' " + " type: 'dropout' " + " dropout_ratio: 0.5 " + " } " + " bottom: 'fc7' " + " top: 'fc7' " + "} " + "layers { " + " layer { " + " name: 'fc8' " + " type: 'innerproduct' " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'fc7' " + " top: 'fc8' " + "} " + "layers { " + " layer { " + " name: 'loss' " + " type: 'softmax_loss' " + " } " + " bottom: 'fc8' " + " bottom: 'label' " + "} "; + const string& expected_output_proto = + "name: 'CaffeNet' " + "layers { " + " name: 'data' " + " type: DATA " + " data_param { " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " batch_size: 256 " + " } " + " transform_param { " + " crop_size: 227 " + " mirror: true " + " mean_file: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " name: 'conv1' " + " type: CONVOLUTION " + " convolution_param { " + " num_output: 96 " + " kernel_size: 11 " + " stride: 4 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " bottom: 'data' " + " top: 'conv1' " + "} " + "layers { " + " name: 'relu1' " + " type: RELU " + " bottom: 'conv1' " + " top: 'conv1' " + "} " + "layers { " + " name: 'pool1' " + " type: POOLING " + " pooling_param { " + " pool: MAX " + " kernel_size: 3 " + " stride: 2 " + " } " + " bottom: 'conv1' " + " top: 'pool1' " + "} " + "layers { " + " name: 'norm1' " + " type: LRN " + " lrn_param { " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool1' " + " top: 'norm1' " + "} " + "layers { " + " name: 'conv2' " + " type: CONVOLUTION " + " convolution_param { " + " num_output: 256 " + " group: 2 " + " kernel_size: 5 " + " pad: 2 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'norm1' " + " top: 'conv2' " + "} " + "layers { " + " name: 'relu2' " + " type: RELU " + " bottom: 'conv2' " + " top: 'conv2' " + "} " + "layers { " + " name: 'pool2' " + " type: POOLING " + " pooling_param { " + " pool: MAX " + " kernel_size: 3 " + " stride: 2 " + " } " + " bottom: 'conv2' " + " top: 'pool2' " + "} " + "layers { " + " name: 'norm2' " + " type: LRN " + " lrn_param { " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool2' " + " top: 'norm2' " + "} " + "layers { " + " name: 'conv3' " + " type: CONVOLUTION " + " convolution_param { " + " num_output: 384 " + " kernel_size: 3 " + " pad: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " bottom: 'norm2' " + " top: 'conv3' " + "} " + "layers { " + " name: 'relu3' " + " type: RELU " + " bottom: 'conv3' " + " top: 'conv3' " + "} " + "layers { " + " name: 'conv4' " + " type: CONVOLUTION " + " convolution_param { " + " num_output: 384 " + " group: 2 " + " kernel_size: 3 " + " pad: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'conv3' " + " top: 'conv4' " + "} " + "layers { " + " name: 'relu4' " + " type: RELU " + " bottom: 'conv4' " + " top: 'conv4' " + "} " + "layers { " + " name: 'conv5' " + " type: CONVOLUTION " + " convolution_param { " + " num_output: 256 " + " group: 2 " + " kernel_size: 3 " + " pad: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'conv4' " + " top: 'conv5' " + "} " + "layers { " + " name: 'relu5' " + " type: RELU " + " bottom: 'conv5' " + " top: 'conv5' " + "} " + "layers { " + " name: 'pool5' " + " type: POOLING " + " pooling_param { " + " kernel_size: 3 " + " pool: MAX " + " stride: 2 " + " } " + " bottom: 'conv5' " + " top: 'pool5' " + "} " + "layers { " + " name: 'fc6' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 4096 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " bottom: 'pool5' " + " top: 'fc6' " + "} " + "layers { " + " name: 'relu6' " + " type: RELU " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " name: 'drop6' " + " type: DROPOUT " + " dropout_param { " + " dropout_ratio: 0.5 " + " } " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " name: 'fc7' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 4096 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " bottom: 'fc6' " + " top: 'fc7' " + "} " + "layers { " + " name: 'relu7' " + " type: RELU " + " bottom: 'fc7' " + " top: 'fc7' " + "} " + "layers { " + " name: 'drop7' " + " type: DROPOUT " + " dropout_param { " + " dropout_ratio: 0.5 " + " } " + " bottom: 'fc7' " + " top: 'fc7' " + "} " + "layers { " + " name: 'fc8' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. 
" + " bottom: 'fc7' " + " top: 'fc8' " + "} " + "layers { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'fc8' " + " bottom: 'label' " + "} "; + this->RunV0UpgradeTest(input_proto, expected_output_proto); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/test/test_util_blas.cpp b/caffe-crfrnn/src/caffe/test/test_util_blas.cpp new file mode 100644 index 00000000..8770f309 --- /dev/null +++ b/caffe-crfrnn/src/caffe/test/test_util_blas.cpp @@ -0,0 +1,134 @@ +#ifndef CPU_ONLY // CPU-GPU test + +#include + +#include "gtest/gtest.h" + +#include "caffe/blob.hpp" +#include "caffe/util/device_alternate.hpp" +#include "caffe/util/math_functions.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; + +template +class GemmTest : public ::testing::Test {}; + +TYPED_TEST_CASE(GemmTest, TestDtypes); + +TYPED_TEST(GemmTest, TestGemmCPUGPU) { + Blob A(1, 1, 2, 3); + Blob B(1, 1, 3, 4); + Blob C(1, 1, 2, 4); + TypeParam data[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + TypeParam A_reshape_data[6] = {1, 4, 2, 5, 3, 6}; + TypeParam B_reshape_data[12] = {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12}; + TypeParam result[8] = {38, 44, 50, 56, 83, 98, 113, 128}; + caffe_copy(6, data, A.mutable_cpu_data()); + caffe_copy(12, data, B.mutable_cpu_data()); + + if (sizeof(TypeParam) == 4 || CAFFE_TEST_CUDA_PROP.major >= 2) { + // [1, 2, 3; 4 5 6] * [1, 2, 3, 4; 5, 6, 7, 8; 9, 10, 11, 12]; + caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, 2, 4, 3, 1., + A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data()); + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(C.cpu_data()[i], result[i]); + } + caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, 2, 4, 3, 1., + A.gpu_data(), B.gpu_data(), 0., C.mutable_gpu_data()); + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(C.cpu_data()[i], result[i]); + } + + // Test when we have a transposed A + A.Reshape(1, 1, 3, 2); + caffe_copy(6, A_reshape_data, A.mutable_cpu_data()); + caffe_cpu_gemm(CblasTrans, 
CblasNoTrans, 2, 4, 3, 1., + A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data()); + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(C.cpu_data()[i], result[i]); + } + caffe_gpu_gemm<TypeParam>(CblasTrans, CblasNoTrans, 2, 4, 3, 1., + A.gpu_data(), B.gpu_data(), 0., C.mutable_gpu_data()); + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(C.cpu_data()[i], result[i]); + } + + // Test when we have a transposed A and a transposed B too + B.Reshape(1, 1, 4, 3); + caffe_copy(12, B_reshape_data, B.mutable_cpu_data()); + caffe_cpu_gemm<TypeParam>(CblasTrans, CblasTrans, 2, 4, 3, 1., + A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data()); + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(C.cpu_data()[i], result[i]); + } + caffe_gpu_gemm<TypeParam>(CblasTrans, CblasTrans, 2, 4, 3, 1., + A.gpu_data(), B.gpu_data(), 0., C.mutable_gpu_data()); + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(C.cpu_data()[i], result[i]); + } + + // Test when we have a transposed B + A.Reshape(1, 1, 2, 3); + caffe_copy(6, data, A.mutable_cpu_data()); + caffe_cpu_gemm<TypeParam>(CblasNoTrans, CblasTrans, 2, 4, 3, 1., + A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data()); + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(C.cpu_data()[i], result[i]); + } + caffe_gpu_gemm<TypeParam>(CblasNoTrans, CblasTrans, 2, 4, 3, 1., + A.gpu_data(), B.gpu_data(), 0., C.mutable_gpu_data()); + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(C.cpu_data()[i], result[i]); + } + } else { + LOG(ERROR) << "Skipping test due to old architecture."; + } +} + + +TYPED_TEST(GemmTest, TestGemvCPUGPU) { + Blob<TypeParam> A(1, 1, 2, 3); + Blob<TypeParam> x(1, 1, 1, 3); + Blob<TypeParam> y(1, 1, 1, 2); + TypeParam data[6] = {1, 2, 3, 4, 5, 6}; + TypeParam result_2[2] = {14, 32}; + TypeParam result_3[3] = {9, 12, 15}; + caffe_copy(6, data, A.mutable_cpu_data()); + caffe_copy(3, data, x.mutable_cpu_data()); + + if (sizeof(TypeParam) == 4 || CAFFE_TEST_CUDA_PROP.major >= 2) { + caffe_cpu_gemv<TypeParam>(CblasNoTrans, 2, 3, 1., A.cpu_data(), + x.cpu_data(), 0., y.mutable_cpu_data()); + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(y.cpu_data()[i], result_2[i]); + } 
+ caffe_gpu_gemv<TypeParam>(CblasNoTrans, 2, 3, 1., A.gpu_data(), + x.gpu_data(), 0., y.mutable_gpu_data()); + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(y.cpu_data()[i], result_2[i]); + } + + // Test transpose case + caffe_copy(2, data, y.mutable_cpu_data()); + caffe_cpu_gemv<TypeParam>(CblasTrans, 2, 3, 1., A.cpu_data(), + y.cpu_data(), 0., x.mutable_cpu_data()); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(x.cpu_data()[i], result_3[i]); + } + caffe_gpu_gemv<TypeParam>(CblasTrans, 2, 3, 1., A.gpu_data(), + y.gpu_data(), 0., x.mutable_gpu_data()); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(x.cpu_data()[i], result_3[i]); + } + } else { + LOG(ERROR) << "Skipping test due to old architecture."; + } +} + +} // namespace caffe + +#endif // CPU_ONLY diff --git a/caffe-crfrnn/src/caffe/util/benchmark.cpp b/caffe-crfrnn/src/caffe/util/benchmark.cpp new file mode 100644 index 00000000..1d269c35 --- /dev/null +++ b/caffe-crfrnn/src/caffe/util/benchmark.cpp @@ -0,0 +1,168 @@ +#include <boost/date_time/posix_time/posix_time.hpp> + +#include "caffe/common.hpp" +#include "caffe/util/benchmark.hpp" + +namespace caffe { + +Timer::Timer() + : initted_(false), + running_(false), + has_run_at_least_once_(false) { + Init(); +} + +Timer::~Timer() { + if (Caffe::mode() == Caffe::GPU) { +#ifndef CPU_ONLY + CUDA_CHECK(cudaEventDestroy(start_gpu_)); + CUDA_CHECK(cudaEventDestroy(stop_gpu_)); +#else + NO_GPU; +#endif + } +} + +void Timer::Start() { + if (!running()) { + if (Caffe::mode() == Caffe::GPU) { +#ifndef CPU_ONLY + CUDA_CHECK(cudaEventRecord(start_gpu_, 0)); +#else + NO_GPU; +#endif + } else { + start_cpu_ = boost::posix_time::microsec_clock::local_time(); + } + running_ = true; + has_run_at_least_once_ = true; + } +} + +void Timer::Stop() { + if (running()) { + if (Caffe::mode() == Caffe::GPU) { +#ifndef CPU_ONLY + CUDA_CHECK(cudaEventRecord(stop_gpu_, 0)); + CUDA_CHECK(cudaEventSynchronize(stop_gpu_)); +#else + NO_GPU; +#endif + } else { + stop_cpu_ = boost::posix_time::microsec_clock::local_time(); + } + running_ = false; + } +} + + +float 
Timer::MicroSeconds() { + if (!has_run_at_least_once()) { + LOG(WARNING) << "Timer has never been run before reading time."; + return 0; + } + if (running()) { + Stop(); + } + if (Caffe::mode() == Caffe::GPU) { +#ifndef CPU_ONLY + CUDA_CHECK(cudaEventElapsedTime(&elapsed_milliseconds_, start_gpu_, + stop_gpu_)); + // CUDA events report elapsed time in milliseconds + elapsed_microseconds_ = elapsed_milliseconds_ * 1000; +#else + NO_GPU; +#endif + } else { + elapsed_microseconds_ = (stop_cpu_ - start_cpu_).total_microseconds(); + } + return elapsed_microseconds_; +} + +float Timer::MilliSeconds() { + if (!has_run_at_least_once()) { + LOG(WARNING) << "Timer has never been run before reading time."; + return 0; + } + if (running()) { + Stop(); + } + if (Caffe::mode() == Caffe::GPU) { +#ifndef CPU_ONLY + CUDA_CHECK(cudaEventElapsedTime(&elapsed_milliseconds_, start_gpu_, + stop_gpu_)); +#else + NO_GPU; +#endif + } else { + elapsed_milliseconds_ = (stop_cpu_ - start_cpu_).total_milliseconds(); + } + return elapsed_milliseconds_; +} + +float Timer::Seconds() { + return MilliSeconds() / 1000.; +} + +void Timer::Init() { + if (!initted()) { + if (Caffe::mode() == Caffe::GPU) { +#ifndef CPU_ONLY + CUDA_CHECK(cudaEventCreate(&start_gpu_)); + CUDA_CHECK(cudaEventCreate(&stop_gpu_)); +#else + NO_GPU; +#endif + } + initted_ = true; + } +} + +CPUTimer::CPUTimer() { + this->initted_ = true; + this->running_ = false; + this->has_run_at_least_once_ = false; +} + +void CPUTimer::Start() { + if (!running()) { + this->start_cpu_ = boost::posix_time::microsec_clock::local_time(); + this->running_ = true; + this->has_run_at_least_once_ = true; + } +} + +void CPUTimer::Stop() { + if (running()) { + this->stop_cpu_ = boost::posix_time::microsec_clock::local_time(); + this->running_ = false; + } +} + +float CPUTimer::MilliSeconds() { + if (!has_run_at_least_once()) { + LOG(WARNING) << "Timer has never been run before reading time."; + return 0; + } + if (running()) { + Stop(); + } + 
this->elapsed_milliseconds_ = (this->stop_cpu_ - + this->start_cpu_).total_milliseconds(); + return this->elapsed_milliseconds_; +} + +float CPUTimer::MicroSeconds() { + if (!has_run_at_least_once()) { + LOG(WARNING) << "Timer has never been run before reading time."; + return 0; + } + if (running()) { + Stop(); + } + this->elapsed_microseconds_ = (this->stop_cpu_ - + this->start_cpu_).total_microseconds(); + return this->elapsed_microseconds_; +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/util/cudnn.cpp b/caffe-crfrnn/src/caffe/util/cudnn.cpp new file mode 100644 index 00000000..1ea84b0e --- /dev/null +++ b/caffe-crfrnn/src/caffe/util/cudnn.cpp @@ -0,0 +1,24 @@ +#ifdef USE_CUDNN +#include "caffe/util/cudnn.hpp" + +namespace caffe { +namespace cudnn { + +float dataType<float>::oneval = 1.0; +float dataType<float>::zeroval = 0.0; +const void* dataType<float>::one = + static_cast<void *>(&dataType<float>::oneval); +const void* dataType<float>::zero = + static_cast<void *>(&dataType<float>::zeroval); + +double dataType<double>::oneval = 1.0; +double dataType<double>::zeroval = 0.0; +const void* dataType<double>::one = + static_cast<void *>(&dataType<double>::oneval); +const void* dataType<double>::zero = + static_cast<void *>(&dataType<double>::zeroval); + +} // namespace cudnn +} // namespace caffe +#endif + diff --git a/caffe-crfrnn/src/caffe/util/im2col.cpp b/caffe-crfrnn/src/caffe/util/im2col.cpp new file mode 100644 index 00000000..c48f31f3 --- /dev/null +++ b/caffe-crfrnn/src/caffe/util/im2col.cpp @@ -0,0 +1,83 @@ +#include <cmath> +#include <cstdlib> +#include <cstring> + +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template <typename Dtype> +void im2col_cpu(const Dtype* data_im, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + Dtype* data_col) { + int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1; + int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1; + int channels_col = channels * kernel_h * kernel_w; 
+ for (int c = 0; c < channels_col; ++c) { + int w_offset = c % kernel_w; + int h_offset = (c / kernel_w) % kernel_h; + int c_im = c / kernel_h / kernel_w; + for (int h = 0; h < height_col; ++h) { + for (int w = 0; w < width_col; ++w) { + int h_pad = h * stride_h - pad_h + h_offset; + int w_pad = w * stride_w - pad_w + w_offset; + if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width) + data_col[(c * height_col + h) * width_col + w] = + data_im[(c_im * height + h_pad) * width + w_pad]; + else + data_col[(c * height_col + h) * width_col + w] = 0; + } + } + } +} + +// Explicit instantiation +template void im2col_cpu<float>(const float* data_im, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, float* data_col); +template void im2col_cpu<double>(const double* data_im, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, double* data_col); + +template <typename Dtype> +void col2im_cpu(const Dtype* data_col, const int channels, + const int height, const int width, const int patch_h, const int patch_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + Dtype* data_im) { + caffe_set(height * width * channels, Dtype(0), data_im); + int height_col = (height + 2 * pad_h - patch_h) / stride_h + 1; + int width_col = (width + 2 * pad_w - patch_w) / stride_w + 1; + int channels_col = channels * patch_h * patch_w; + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % patch_w; + int h_offset = (c / patch_w) % patch_h; + int c_im = c / patch_h / patch_w; + for (int h = 0; h < height_col; ++h) { + for (int w = 0; w < width_col; ++w) { + int h_pad = h * stride_h - pad_h + h_offset; + int w_pad = w * stride_w - pad_w + w_offset; + if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width) + data_im[(c_im * height + 
h_pad) * width + w_pad] += + data_col[(c * height_col + h) * width_col + w]; + } + } + } +} + +// Explicit instantiation +template void col2im_cpu<float>(const float* data_col, const int channels, + const int height, const int width, const int patch_h, const int patch_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, float* data_im); +template void col2im_cpu<double>(const double* data_col, const int channels, + const int height, const int width, const int patch_h, const int patch_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, double* data_im); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/util/im2col.cu b/caffe-crfrnn/src/caffe/util/im2col.cu new file mode 100644 index 00000000..c90f93eb --- /dev/null +++ b/caffe-crfrnn/src/caffe/util/im2col.cu @@ -0,0 +1,144 @@ +#include <algorithm> +#include <cmath> +#include <cstdlib> +#include <cstring> + +#include "caffe/common.hpp" +#include "caffe/util/im2col.hpp" + +namespace caffe { + +template <typename Dtype> +__global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int height_col, const int width_col, + Dtype* data_col) { + CUDA_KERNEL_LOOP(index, n) { + int w_out = index % width_col; + int h_index = index / width_col; + int h_out = h_index % height_col; + int channel_in = h_index / height_col; + int channel_out = channel_in * kernel_h * kernel_w; + int h_in = h_out * stride_h - pad_h; + int w_in = w_out * stride_w - pad_w; + Dtype* data_col_ptr = data_col; + data_col_ptr += (channel_out * height_col + h_out) * width_col + w_out; + const Dtype* data_im_ptr = data_im; + data_im_ptr += (channel_in * height + h_in) * width + w_in; + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + int h = h_in + i; + int w = w_in + j; + *data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ? 
+ data_im_ptr[i * width + j] : 0; + data_col_ptr += height_col * width_col; + } + } + } +} + +template <typename Dtype> +void im2col_gpu(const Dtype* data_im, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + Dtype* data_col) { + // We are going to launch channels * height_col * width_col kernels, each + // kernel responsible for copying a single-channel grid. + int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1; + int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1; + int num_kernels = channels * height_col * width_col; + // NOLINT_NEXT_LINE(whitespace/operators) + im2col_gpu_kernel<Dtype><<<CAFFE_GET_BLOCKS(num_kernels), + CAFFE_CUDA_NUM_THREADS>>>( + num_kernels, data_im, height, width, kernel_h, kernel_w, pad_h, + pad_w, stride_h, stride_w, height_col, + width_col, data_col); + CUDA_POST_KERNEL_CHECK; +} + + +// Explicit instantiation +template void im2col_gpu<float>(const float* data_im, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + float* data_col); +template void im2col_gpu<double>(const double* data_im, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + double* data_col); + +template <typename Dtype> +__global__ void col2im_gpu_kernel(const int n, const Dtype* data_col, + const int height, const int width, const int channels, + const int patch_h, const int patch_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int height_col, const int width_col, + Dtype* data_im) { + CUDA_KERNEL_LOOP(index, n) { + Dtype val = 0; + int w = index % width + pad_w; + int h = (index / width) % height + pad_h; + int c = index / (width * height); + // compute the start and end of the output + int w_col_start = (w < patch_w) ? 
0 : (w - patch_w) / stride_w + 1; + int w_col_end = min(w / stride_w + 1, width_col); + int h_col_start = (h < patch_h) ? 0 : (h - patch_h) / stride_h + 1; + int h_col_end = min(h / stride_h + 1, height_col); + /* + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + // the col location: [c * width * height + h_out, w_out] + int c_col = c * patch_h * patch_w + (h - h_col * stride_h) * ksize + + (w - w_col * stride_w); + val += data_col[(c_col * height_col + h_col) * width_col + w_col]; + } + } + */ + // equivalent implementation + int offset = + (c * patch_h * patch_w + h * patch_w + w) * height_col * width_col; + int coeff_h_col = (1 - stride_h * patch_w * height_col) * width_col; + int coeff_w_col = (1 - stride_w * height_col * width_col); + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col]; + } + } + data_im[index] = val; + } +} + +template <typename Dtype> +void col2im_gpu(const Dtype* data_col, const int channels, + const int height, const int width, const int patch_h, const int patch_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, Dtype* data_im) { + int height_col = (height + 2 * pad_h - patch_h) / stride_h + 1; + int width_col = (width + 2 * pad_w - patch_w) / stride_w + 1; + int num_kernels = channels * height * width; + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. 
+ // NOLINT_NEXT_LINE(whitespace/operators) + col2im_gpu_kernel<<>>( + num_kernels, data_col, height, width, channels, patch_h, patch_w, + pad_h, pad_w, stride_h, stride_w, + height_col, width_col, data_im); + CUDA_POST_KERNEL_CHECK; +} + +// Explicit instantiation +template void col2im_gpu(const float* data_col, const int channels, + const int height, const int width, const int patch_h, const int patch_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, float* data_im); +template void col2im_gpu(const double* data_col, const int channels, + const int height, const int width, const int patch_h, const int patch_w, + const int pad_h, const int pad_w, const int stride_h, + const int stride_w, double* data_im); + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/util/insert_splits.cpp b/caffe-crfrnn/src/caffe/util/insert_splits.cpp new file mode 100644 index 00000000..f20efdae --- /dev/null +++ b/caffe-crfrnn/src/caffe/util/insert_splits.cpp @@ -0,0 +1,144 @@ +#include +#include +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/util/insert_splits.hpp" + +namespace caffe { + +void InsertSplits(const NetParameter& param, NetParameter* param_split) { + // Initialize by copying from the input NetParameter. + param_split->CopyFrom(param); + param_split->clear_layers(); + map > blob_name_to_last_top_idx; + map, pair > bottom_idx_to_source_top_idx; + map, int> top_idx_to_bottom_count; + map, float> top_idx_to_loss_weight; + map, int> top_idx_to_bottom_split_idx; + map layer_idx_to_layer_name; + layer_idx_to_layer_name[-1] = "input"; + // Determine the number of times each blob is used as an input (bottom) blob. 
+ for (int i = 0; i < param.input_size(); ++i) { + const string& blob_name = param.input(i); + blob_name_to_last_top_idx[blob_name] = make_pair(-1, i); + } + for (int i = 0; i < param.layers_size(); ++i) { + const LayerParameter& layer_param = param.layers(i); + layer_idx_to_layer_name[i] = layer_param.name(); + for (int j = 0; j < layer_param.bottom_size(); ++j) { + const string& blob_name = layer_param.bottom(j); + if (blob_name_to_last_top_idx.find(blob_name) == + blob_name_to_last_top_idx.end()) { + LOG(FATAL) << "Unknown blob input " << blob_name << " to layer " << j; + } + const pair& bottom_idx = make_pair(i, j); + const pair& top_idx = blob_name_to_last_top_idx[blob_name]; + bottom_idx_to_source_top_idx[bottom_idx] = top_idx; + ++top_idx_to_bottom_count[top_idx]; + } + for (int j = 0; j < layer_param.top_size(); ++j) { + const string& blob_name = layer_param.top(j); + blob_name_to_last_top_idx[blob_name] = make_pair(i, j); + } + // A use of a top blob as a loss should be handled similarly to the use of + // a top blob as an input (bottom) blob to another layer. + const int last_loss = + std::min(layer_param.loss_weight_size(), layer_param.top_size()); + for (int j = 0; j < last_loss; ++j) { + const string& blob_name = layer_param.top(j); + const pair& top_idx = blob_name_to_last_top_idx[blob_name]; + top_idx_to_loss_weight[top_idx] = layer_param.loss_weight(j); + if (top_idx_to_loss_weight[top_idx]) { + ++top_idx_to_bottom_count[top_idx]; + } + } + } + // Create split layer for any input blobs used by other layers as bottom + // blobs more than once. 
+ for (int i = 0; i < param.input_size(); ++i) { + const int split_count = top_idx_to_bottom_count[make_pair(-1, i)]; + if (split_count > 1) { + const string& layer_name = layer_idx_to_layer_name[-1]; + const string& blob_name = param.input(i); + LayerParameter* split_layer_param = param_split->add_layers(); + const float kZeroLossWeight = 0; + ConfigureSplitLayer(layer_name, blob_name, i, split_count, + kZeroLossWeight, split_layer_param); + } + } + for (int i = 0; i < param.layers_size(); ++i) { + LayerParameter* layer_param = param_split->add_layers(); + layer_param->CopyFrom(param.layers(i)); + // Replace any shared bottom blobs with split layer outputs. + for (int j = 0; j < layer_param->bottom_size(); ++j) { + const pair& top_idx = + bottom_idx_to_source_top_idx[make_pair(i, j)]; + const int split_count = top_idx_to_bottom_count[top_idx]; + if (split_count > 1) { + const string& layer_name = layer_idx_to_layer_name[top_idx.first]; + const string& blob_name = layer_param->bottom(j); + layer_param->set_bottom(j, SplitBlobName(layer_name, + blob_name, top_idx.second, top_idx_to_bottom_split_idx[top_idx]++)); + } + } + // Create split layer for any top blobs used by other layers as bottom + // blobs more than once. 
+ for (int j = 0; j < layer_param->top_size(); ++j) { + const pair& top_idx = make_pair(i, j); + const int split_count = top_idx_to_bottom_count[top_idx]; + if (split_count > 1) { + const string& layer_name = layer_idx_to_layer_name[i]; + const string& blob_name = layer_param->top(j); + LayerParameter* split_layer_param = param_split->add_layers(); + const float loss_weight = top_idx_to_loss_weight[top_idx]; + ConfigureSplitLayer(layer_name, blob_name, j, split_count, + loss_weight, split_layer_param); + if (loss_weight) { + layer_param->clear_loss_weight(); + top_idx_to_bottom_split_idx[top_idx]++; + } + } + } + } +} + +void ConfigureSplitLayer(const string& layer_name, const string& blob_name, + const int blob_idx, const int split_count, const float loss_weight, + LayerParameter* split_layer_param) { + split_layer_param->Clear(); + split_layer_param->add_bottom(blob_name); + split_layer_param->set_name(SplitLayerName(layer_name, blob_name, blob_idx)); + split_layer_param->set_type(LayerParameter_LayerType_SPLIT); + for (int k = 0; k < split_count; ++k) { + split_layer_param->add_top( + SplitBlobName(layer_name, blob_name, blob_idx, k)); + if (loss_weight) { + if (k == 0) { + split_layer_param->add_loss_weight(loss_weight); + } else { + split_layer_param->add_loss_weight(0); + } + } + } +} + +string SplitLayerName(const string& layer_name, const string& blob_name, + const int blob_idx) { + ostringstream split_layer_name; + split_layer_name << blob_name << "_" << layer_name << "_" << blob_idx + << "_split"; + return split_layer_name.str(); +} + +string SplitBlobName(const string& layer_name, const string& blob_name, + const int blob_idx, const int split_idx) { + ostringstream split_blob_name; + split_blob_name << blob_name << "_" << layer_name << "_" << blob_idx + << "_split_" << split_idx; + return split_blob_name.str(); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/util/io.cpp b/caffe-crfrnn/src/caffe/util/io.cpp new file mode 100644 index 
00000000..b136bc8a --- /dev/null +++ b/caffe-crfrnn/src/caffe/util/io.cpp @@ -0,0 +1,253 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include // NOLINT(readability/streams) +#include +#include + +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/io.hpp" + +namespace caffe { + +using google::protobuf::io::FileInputStream; +using google::protobuf::io::FileOutputStream; +using google::protobuf::io::ZeroCopyInputStream; +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::ZeroCopyOutputStream; +using google::protobuf::io::CodedOutputStream; +using google::protobuf::Message; + +bool ReadProtoFromTextFile(const char* filename, Message* proto) { + int fd = open(filename, O_RDONLY); + CHECK_NE(fd, -1) << "File not found: " << filename; + FileInputStream* input = new FileInputStream(fd); + bool success = google::protobuf::TextFormat::Parse(input, proto); + delete input; + close(fd); + return success; +} + +void WriteProtoToTextFile(const Message& proto, const char* filename) { + int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); + FileOutputStream* output = new FileOutputStream(fd); + CHECK(google::protobuf::TextFormat::Print(proto, output)); + delete output; + close(fd); +} + +bool ReadProtoFromBinaryFile(const char* filename, Message* proto) { + int fd = open(filename, O_RDONLY); + CHECK_NE(fd, -1) << "File not found: " << filename; + ZeroCopyInputStream* raw_input = new FileInputStream(fd); + CodedInputStream* coded_input = new CodedInputStream(raw_input); + coded_input->SetTotalBytesLimit(1073741824, 536870912); + + bool success = proto->ParseFromCodedStream(coded_input); + + delete coded_input; + delete raw_input; + close(fd); + return success; +} + +void WriteProtoToBinaryFile(const Message& proto, const char* filename) { + fstream output(filename, ios::out | ios::trunc | ios::binary); + CHECK(proto.SerializeToOstream(&output)); +} + 
+cv::Mat ReadImageToCVMat(const string& filename, + const int height, const int width, const bool is_color) { + cv::Mat cv_img; + int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR : + CV_LOAD_IMAGE_GRAYSCALE); + cv::Mat cv_img_origin = cv::imread(filename, cv_read_flag); + if (!cv_img_origin.data) { + LOG(ERROR) << "Could not open or find file " << filename; + return cv_img_origin; + } + if (height > 0 && width > 0) { + cv::resize(cv_img_origin, cv_img, cv::Size(width, height)); + } else { + cv_img = cv_img_origin; + } + return cv_img; +} + +bool ReadImageToDatum(const string& filename, const int label, + const int height, const int width, const bool is_color, Datum* datum) { + cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color); + if (cv_img.data) { + CVMatToDatum(cv_img, datum); + datum->set_label(label); + return true; + } else { + return false; + } +} + +bool ReadFileToDatum(const string& filename, const int label, + Datum* datum) { + std::streampos size; + + fstream file(filename.c_str(), ios::in|ios::binary|ios::ate); + if (file.is_open()) { + size = file.tellg(); + std::string buffer(size, ' '); + file.seekg(0, ios::beg); + file.read(&buffer[0], size); + file.close(); + datum->set_data(buffer); + datum->set_label(label); + datum->set_encoded(true); + return true; + } else { + return false; + } +} + +cv::Mat DecodeDatumToCVMat(const Datum& datum, + const int height, const int width, const bool is_color) { + cv::Mat cv_img; + CHECK(datum.encoded()) << "Datum not encoded"; + int cv_read_flag = (is_color ? 
CV_LOAD_IMAGE_COLOR : + CV_LOAD_IMAGE_GRAYSCALE); + const string& data = datum.data(); + std::vector<char> vec_data(data.c_str(), data.c_str() + data.size()); + if (height > 0 && width > 0) { + cv::Mat cv_img_origin = cv::imdecode(cv::Mat(vec_data), cv_read_flag); + cv::resize(cv_img_origin, cv_img, cv::Size(width, height)); + } else { + cv_img = cv::imdecode(vec_data, cv_read_flag); + } + if (!cv_img.data) { + LOG(ERROR) << "Could not decode datum "; + } + return cv_img; +} + +// If the Datum is encoded, decode it using DecodeDatumToCVMat and CVMatToDatum; +// if height and width are set, the image is also resized. +// If the Datum is not encoded, do nothing. +bool DecodeDatum(const int height, const int width, const bool is_color, + Datum* datum) { + if (datum->encoded()) { + cv::Mat cv_img = DecodeDatumToCVMat((*datum), height, width, is_color); + CVMatToDatum(cv_img, datum); + return true; + } else { + return false; + } +} + +void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) { + CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte"; + datum->set_channels(cv_img.channels()); + datum->set_height(cv_img.rows); + datum->set_width(cv_img.cols); + datum->clear_data(); + datum->clear_float_data(); + datum->set_encoded(false); + int datum_channels = datum->channels(); + int datum_height = datum->height(); + int datum_width = datum->width(); + int datum_size = datum_channels * datum_height * datum_width; + std::string buffer(datum_size, ' '); + for (int h = 0; h < datum_height; ++h) { + const uchar* ptr = cv_img.ptr<uchar>(h); + int img_index = 0; + for (int w = 0; w < datum_width; ++w) { + for (int c = 0; c < datum_channels; ++c) { + int datum_index = (c * datum_height + h) * datum_width + w; + buffer[datum_index] = static_cast<char>(ptr[img_index++]); + } + } + } + datum->set_data(buffer); +} + +// Verifies format of data stored in HDF5 file and reshapes blob accordingly. 
+template <typename Dtype> +void hdf5_load_nd_dataset_helper( + hid_t file_id, const char* dataset_name_, int min_dim, int max_dim, + Blob<Dtype>* blob) { + // Verify that the dataset exists. + CHECK(H5LTfind_dataset(file_id, dataset_name_)) + << "Failed to find HDF5 dataset " << dataset_name_; + // Verify that the number of dimensions is in the accepted range. + herr_t status; + int ndims; + status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims); + CHECK_GE(status, 0) << "Failed to get dataset ndims for " << dataset_name_; + CHECK_GE(ndims, min_dim); + CHECK_LE(ndims, max_dim); + + // Verify that the data format is what we expect: float or double. + std::vector<hsize_t> dims(ndims); + H5T_class_t class_; + status = H5LTget_dataset_info( + file_id, dataset_name_, dims.data(), &class_, NULL); + CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_; + CHECK_EQ(class_, H5T_FLOAT) << "Expected float or double data"; + + blob->Reshape( + dims[0], + (dims.size() > 1) ? dims[1] : 1, + (dims.size() > 2) ? dims[2] : 1, + (dims.size() > 3) ? 
dims[3] : 1); +} + +template <> +void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_, + int min_dim, int max_dim, Blob<float>* blob) { + hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob); + herr_t status = H5LTread_dataset_float( + file_id, dataset_name_, blob->mutable_cpu_data()); + CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_; +} + +template <> +void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_, + int min_dim, int max_dim, Blob<double>* blob) { + hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob); + herr_t status = H5LTread_dataset_double( + file_id, dataset_name_, blob->mutable_cpu_data()); + CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_; +} + +template <> +void hdf5_save_nd_dataset<float>( + const hid_t file_id, const string dataset_name, const Blob<float>& blob) { + hsize_t dims[HDF5_NUM_DIMS]; + dims[0] = blob.num(); + dims[1] = blob.channels(); + dims[2] = blob.height(); + dims[3] = blob.width(); + herr_t status = H5LTmake_dataset_float( + file_id, dataset_name.c_str(), HDF5_NUM_DIMS, dims, blob.cpu_data()); + CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name; +} + +template <> +void hdf5_save_nd_dataset<double>( + const hid_t file_id, const string dataset_name, const Blob<double>& blob) { + hsize_t dims[HDF5_NUM_DIMS]; + dims[0] = blob.num(); + dims[1] = blob.channels(); + dims[2] = blob.height(); + dims[3] = blob.width(); + herr_t status = H5LTmake_dataset_double( + file_id, dataset_name.c_str(), HDF5_NUM_DIMS, dims, blob.cpu_data()); + CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name; +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/util/math_functions.cpp b/caffe-crfrnn/src/caffe/util/math_functions.cpp new file mode 100644 index 00000000..13e17be5 --- /dev/null +++ b/caffe-crfrnn/src/caffe/util/math_functions.cpp @@ -0,0 +1,387 @@ +#include +#include + +#include + +#include "caffe/common.hpp" +#include 
"caffe/util/math_functions.hpp" +#include "caffe/util/rng.hpp" + +namespace caffe { + +template<> +void caffe_cpu_gemm(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, + const float alpha, const float* A, const float* B, const float beta, + float* C) { + int lda = (TransA == CblasNoTrans) ? K : M; + int ldb = (TransB == CblasNoTrans) ? N : K; + cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, + ldb, beta, C, N); +} + +template<> +void caffe_cpu_gemm(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, + const double alpha, const double* A, const double* B, const double beta, + double* C) { + int lda = (TransA == CblasNoTrans) ? K : M; + int ldb = (TransB == CblasNoTrans) ? N : K; + cblas_dgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, + ldb, beta, C, N); +} + +template <> +void caffe_cpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, + const int N, const float alpha, const float* A, const float* x, + const float beta, float* y) { + cblas_sgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1); +} + +template <> +void caffe_cpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, + const int N, const double alpha, const double* A, const double* x, + const double beta, double* y) { + cblas_dgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1); +} + +template <> +void caffe_axpy(const int N, const float alpha, const float* X, + float* Y) { cblas_saxpy(N, alpha, X, 1, Y, 1); } + +template <> +void caffe_axpy(const int N, const double alpha, const double* X, + double* Y) { cblas_daxpy(N, alpha, X, 1, Y, 1); } + +template +void caffe_set(const int N, const Dtype alpha, Dtype* Y) { + if (alpha == 0) { + memset(Y, 0, sizeof(Dtype) * N); // NOLINT(caffe/alt_fn) + return; + } + for (int i = 0; i < N; ++i) { + Y[i] = alpha; + } +} + +template void caffe_set(const int N, const int alpha, int* Y); +template void 
caffe_set(const int N, const float alpha, float* Y); +template void caffe_set(const int N, const double alpha, double* Y); + +template <> +void caffe_add_scalar(const int N, const float alpha, float* Y) { + for (int i = 0; i < N; ++i) { + Y[i] += alpha; + } +} + +template <> +void caffe_add_scalar(const int N, const double alpha, double* Y) { + for (int i = 0; i < N; ++i) { + Y[i] += alpha; + } +} + +template +void caffe_copy(const int N, const Dtype* X, Dtype* Y) { + if (X != Y) { + if (Caffe::mode() == Caffe::GPU) { +#ifndef CPU_ONLY + // NOLINT_NEXT_LINE(caffe/alt_fn) + CUDA_CHECK(cudaMemcpy(Y, X, sizeof(Dtype) * N, cudaMemcpyDefault)); +#else + NO_GPU; +#endif + } else { + memcpy(Y, X, sizeof(Dtype) * N); // NOLINT(caffe/alt_fn) + } + } +} + +template void caffe_copy(const int N, const int* X, int* Y); +template void caffe_copy(const int N, const unsigned int* X, + unsigned int* Y); +template void caffe_copy(const int N, const float* X, float* Y); +template void caffe_copy(const int N, const double* X, double* Y); + +template <> +void caffe_scal(const int N, const float alpha, float *X) { + cblas_sscal(N, alpha, X, 1); +} + +template <> +void caffe_scal(const int N, const double alpha, double *X) { + cblas_dscal(N, alpha, X, 1); +} + +template <> +void caffe_cpu_axpby(const int N, const float alpha, const float* X, + const float beta, float* Y) { + cblas_saxpby(N, alpha, X, 1, beta, Y, 1); +} + +template <> +void caffe_cpu_axpby(const int N, const double alpha, const double* X, + const double beta, double* Y) { + cblas_daxpby(N, alpha, X, 1, beta, Y, 1); +} + +template <> +void caffe_add(const int n, const float* a, const float* b, + float* y) { + vsAdd(n, a, b, y); +} + +template <> +void caffe_add(const int n, const double* a, const double* b, + double* y) { + vdAdd(n, a, b, y); +} + +template <> +void caffe_sub(const int n, const float* a, const float* b, + float* y) { + vsSub(n, a, b, y); +} + +template <> +void caffe_sub(const int n, const double* a, const 
double* b, + double* y) { + vdSub(n, a, b, y); +} + +template <> +void caffe_mul(const int n, const float* a, const float* b, + float* y) { + vsMul(n, a, b, y); +} + +template <> +void caffe_mul(const int n, const double* a, const double* b, + double* y) { + vdMul(n, a, b, y); +} + +template <> +void caffe_div(const int n, const float* a, const float* b, + float* y) { + vsDiv(n, a, b, y); +} + +template <> +void caffe_div(const int n, const double* a, const double* b, + double* y) { + vdDiv(n, a, b, y); +} + +template <> +void caffe_powx(const int n, const float* a, const float b, + float* y) { + vsPowx(n, a, b, y); +} + +template <> +void caffe_powx(const int n, const double* a, const double b, + double* y) { + vdPowx(n, a, b, y); +} + +template <> +void caffe_sqr(const int n, const float* a, float* y) { + vsSqr(n, a, y); +} + +template <> +void caffe_sqr(const int n, const double* a, double* y) { + vdSqr(n, a, y); +} + +template <> +void caffe_exp(const int n, const float* a, float* y) { + vsExp(n, a, y); +} + +template <> +void caffe_exp(const int n, const double* a, double* y) { + vdExp(n, a, y); +} + +template <> +void caffe_abs(const int n, const float* a, float* y) { + vsAbs(n, a, y); +} + +template <> +void caffe_abs(const int n, const double* a, double* y) { + vdAbs(n, a, y); +} + +unsigned int caffe_rng_rand() { + return (*caffe_rng())(); +} + +template +Dtype caffe_nextafter(const Dtype b) { + return boost::math::nextafter( + b, std::numeric_limits::max()); +} + +template +float caffe_nextafter(const float b); + +template +double caffe_nextafter(const double b); + +template +void caffe_rng_uniform(const int n, const Dtype a, const Dtype b, Dtype* r) { + CHECK_GE(n, 0); + CHECK(r); + CHECK_LE(a, b); + boost::uniform_real random_distribution(a, caffe_nextafter(b)); + boost::variate_generator > + variate_generator(caffe_rng(), random_distribution); + for (int i = 0; i < n; ++i) { + r[i] = variate_generator(); + } +} + +template +void caffe_rng_uniform(const 
int n, const float a, const float b, + float* r); + +template +void caffe_rng_uniform(const int n, const double a, const double b, + double* r); + +template +void caffe_rng_gaussian(const int n, const Dtype a, + const Dtype sigma, Dtype* r) { + CHECK_GE(n, 0); + CHECK(r); + CHECK_GT(sigma, 0); + boost::normal_distribution random_distribution(a, sigma); + boost::variate_generator > + variate_generator(caffe_rng(), random_distribution); + for (int i = 0; i < n; ++i) { + r[i] = variate_generator(); + } +} + +template +void caffe_rng_gaussian(const int n, const float mu, + const float sigma, float* r); + +template +void caffe_rng_gaussian(const int n, const double mu, + const double sigma, double* r); + +template +void caffe_rng_bernoulli(const int n, const Dtype p, int* r) { + CHECK_GE(n, 0); + CHECK(r); + CHECK_GE(p, 0); + CHECK_LE(p, 1); + boost::bernoulli_distribution random_distribution(p); + boost::variate_generator > + variate_generator(caffe_rng(), random_distribution); + for (int i = 0; i < n; ++i) { + r[i] = variate_generator(); + } +} + +template +void caffe_rng_bernoulli(const int n, const double p, int* r); + +template +void caffe_rng_bernoulli(const int n, const float p, int* r); + +template +void caffe_rng_bernoulli(const int n, const Dtype p, unsigned int* r) { + CHECK_GE(n, 0); + CHECK(r); + CHECK_GE(p, 0); + CHECK_LE(p, 1); + boost::bernoulli_distribution random_distribution(p); + boost::variate_generator > + variate_generator(caffe_rng(), random_distribution); + for (int i = 0; i < n; ++i) { + r[i] = static_cast(variate_generator()); + } +} + +template +void caffe_rng_bernoulli(const int n, const double p, unsigned int* r); + +template +void caffe_rng_bernoulli(const int n, const float p, unsigned int* r); + +template <> +float caffe_cpu_strided_dot(const int n, const float* x, const int incx, + const float* y, const int incy) { + return cblas_sdot(n, x, incx, y, incy); +} + +template <> +double caffe_cpu_strided_dot(const int n, const double* x, + 
const int incx, const double* y, const int incy) { + return cblas_ddot(n, x, incx, y, incy); +} + +template +Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y) { + return caffe_cpu_strided_dot(n, x, 1, y, 1); +} + +template +float caffe_cpu_dot(const int n, const float* x, const float* y); + +template +double caffe_cpu_dot(const int n, const double* x, const double* y); + +template <> +int caffe_cpu_hamming_distance(const int n, const float* x, + const float* y) { + int dist = 0; + for (int i = 0; i < n; ++i) { + dist += __builtin_popcount(static_cast(x[i]) ^ + static_cast(y[i])); + } + return dist; +} + +template <> +int caffe_cpu_hamming_distance(const int n, const double* x, + const double* y) { + int dist = 0; + for (int i = 0; i < n; ++i) { + dist += __builtin_popcountl(static_cast(x[i]) ^ + static_cast(y[i])); + } + return dist; +} + +template <> +float caffe_cpu_asum(const int n, const float* x) { + return cblas_sasum(n, x, 1); +} + +template <> +double caffe_cpu_asum(const int n, const double* x) { + return cblas_dasum(n, x, 1); +} + +template <> +void caffe_cpu_scale(const int n, const float alpha, const float *x, + float* y) { + cblas_scopy(n, x, 1, y, 1); + cblas_sscal(n, alpha, y, 1); +} + +template <> +void caffe_cpu_scale(const int n, const double alpha, const double *x, + double* y) { + cblas_dcopy(n, x, 1, y, 1); + cblas_dscal(n, alpha, y, 1); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/util/math_functions.cu b/caffe-crfrnn/src/caffe/util/math_functions.cu new file mode 100644 index 00000000..43e65eb9 --- /dev/null +++ b/caffe-crfrnn/src/caffe/util/math_functions.cu @@ -0,0 +1,444 @@ +#include // CUDA's, not caffe's, for fabs, signbit +#include +#include // thrust::plus +#include + +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/util/math_functions.hpp" + +namespace caffe { + +template <> +void caffe_gpu_gemm(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, 
const int N, const int K, + const float alpha, const float* A, const float* B, const float beta, + float* C) { + // Note that cublas follows fortran order. + int lda = (TransA == CblasNoTrans) ? K : M; + int ldb = (TransB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + CUBLAS_CHECK(cublasSgemm(Caffe::cublas_handle(), cuTransB, cuTransA, + N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); +} + +template <> +void caffe_gpu_gemm(const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, + const double alpha, const double* A, const double* B, const double beta, + double* C) { + // Note that cublas follows fortran order. + int lda = (TransA == CblasNoTrans) ? K : M; + int ldb = (TransB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + CUBLAS_CHECK(cublasDgemm(Caffe::cublas_handle(), cuTransB, cuTransA, + N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); +} + +template <> +void caffe_gpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, + const int N, const float alpha, const float* A, const float* x, + const float beta, float* y) { + cublasOperation_t cuTransA = + (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; + CUBLAS_CHECK(cublasSgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha, + A, N, x, 1, &beta, y, 1)); +} + +template <> +void caffe_gpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, + const int N, const double alpha, const double* A, const double* x, + const double beta, double* y) { + cublasOperation_t cuTransA = + (TransA == CblasNoTrans) ? 
CUBLAS_OP_T : CUBLAS_OP_N; + CUBLAS_CHECK(cublasDgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha, + A, N, x, 1, &beta, y, 1)); +} + +template <> +void caffe_gpu_axpy(const int N, const float alpha, const float* X, + float* Y) { + CUBLAS_CHECK(cublasSaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1)); +} + +template <> +void caffe_gpu_axpy(const int N, const double alpha, const double* X, + double* Y) { + CUBLAS_CHECK(cublasDaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1)); +} + +void caffe_gpu_memcpy(const size_t N, const void* X, void* Y) { + if (X != Y) { + CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault)); // NOLINT(caffe/alt_fn) + } +} + +template <> +void caffe_gpu_scal(const int N, const float alpha, float *X) { + CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), N, &alpha, X, 1)); +} + +template <> +void caffe_gpu_scal(const int N, const double alpha, double *X) { + CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), N, &alpha, X, 1)); +} + +template <> +void caffe_gpu_axpby(const int N, const float alpha, const float* X, + const float beta, float* Y) { + caffe_gpu_scal(N, beta, Y); + caffe_gpu_axpy(N, alpha, X, Y); +} + +template <> +void caffe_gpu_axpby(const int N, const double alpha, const double* X, + const double beta, double* Y) { + caffe_gpu_scal(N, beta, Y); + caffe_gpu_axpy(N, alpha, X, Y); +} + +template <> +void caffe_gpu_dot(const int n, const float* x, const float* y, + float* out) { + CUBLAS_CHECK(cublasSdot(Caffe::cublas_handle(), n, x, 1, y, 1, out)); +} + +template <> +void caffe_gpu_dot(const int n, const double* x, const double* y, + double * out) { + CUBLAS_CHECK(cublasDdot(Caffe::cublas_handle(), n, x, 1, y, 1, out)); +} + +template <> +void caffe_gpu_asum(const int n, const float* x, float* y) { + CUBLAS_CHECK(cublasSasum(Caffe::cublas_handle(), n, x, 1, y)); +} + +template <> +void caffe_gpu_asum(const int n, const double* x, double* y) { + CUBLAS_CHECK(cublasDasum(Caffe::cublas_handle(), n, x, 1, y)); +} + +template <> +void 
caffe_gpu_scale(const int n, const float alpha, const float *x, + float* y) { + CUBLAS_CHECK(cublasScopy(Caffe::cublas_handle(), n, x, 1, y, 1)); + CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), n, &alpha, y, 1)); +} + +template <> +void caffe_gpu_scale(const int n, const double alpha, const double *x, + double* y) { + CUBLAS_CHECK(cublasDcopy(Caffe::cublas_handle(), n, x, 1, y, 1)); + CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), n, &alpha, y, 1)); +} + +template +__global__ void set_kernel(const int n, const Dtype alpha, Dtype* y) { + CUDA_KERNEL_LOOP(index, n) { + y[index] = alpha; + } +} + +template +void caffe_gpu_set(const int N, const Dtype alpha, Dtype* Y) { + if (alpha == 0) { + CUDA_CHECK(cudaMemset(Y, 0, sizeof(Dtype) * N)); // NOLINT(caffe/alt_fn) + return; + } + // NOLINT_NEXT_LINE(whitespace/operators) + set_kernel<<>>( + N, alpha, Y); +} + +template void caffe_gpu_set(const int N, const int alpha, int* Y); +template void caffe_gpu_set(const int N, const float alpha, float* Y); +template void caffe_gpu_set(const int N, const double alpha, double* Y); + +template +__global__ void add_scalar_kernel(const int n, const Dtype alpha, Dtype* y) { + CUDA_KERNEL_LOOP(index, n) { + y[index] += alpha; + } +} + +template <> +void caffe_gpu_add_scalar(const int N, const float alpha, float* Y) { + // NOLINT_NEXT_LINE(whitespace/operators) + add_scalar_kernel<<>>( + N, alpha, Y); +} + +template <> +void caffe_gpu_add_scalar(const int N, const double alpha, double* Y) { + // NOLINT_NEXT_LINE(whitespace/operators) + add_scalar_kernel<<>>( + N, alpha, Y); +} + +template +__global__ void add_kernel(const int n, const Dtype* a, + const Dtype* b, Dtype* y) { + CUDA_KERNEL_LOOP(index, n) { + y[index] = a[index] + b[index]; + } +} + +template <> +void caffe_gpu_add(const int N, const float* a, const float* b, + float* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + add_kernel<<>>( + N, a, b, y); +} + +template <> +void caffe_gpu_add(const int N, const double* a, 
const double* b, + double* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + add_kernel<<>>( + N, a, b, y); +} + +template +__global__ void sub_kernel(const int n, const Dtype* a, + const Dtype* b, Dtype* y) { + CUDA_KERNEL_LOOP(index, n) { + y[index] = a[index] - b[index]; + } +} + +template <> +void caffe_gpu_sub(const int N, const float* a, const float* b, + float* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + sub_kernel<<>>( + N, a, b, y); +} + +template <> +void caffe_gpu_sub(const int N, const double* a, const double* b, + double* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + sub_kernel<<>>( + N, a, b, y); +} + +template +__global__ void mul_kernel(const int n, const Dtype* a, + const Dtype* b, Dtype* y) { + CUDA_KERNEL_LOOP(index, n) { + y[index] = a[index] * b[index]; + } +} + +template <> +void caffe_gpu_mul(const int N, const float* a, + const float* b, float* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + mul_kernel<<>>( + N, a, b, y); +} + +template <> +void caffe_gpu_mul(const int N, const double* a, + const double* b, double* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + mul_kernel<<>>( + N, a, b, y); +} + +template +__global__ void div_kernel(const int n, const Dtype* a, + const Dtype* b, Dtype* y) { + CUDA_KERNEL_LOOP(index, n) { + y[index] = a[index] / b[index]; + } +} + +template <> +void caffe_gpu_div(const int N, const float* a, + const float* b, float* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + div_kernel<<>>( + N, a, b, y); +} + +template <> +void caffe_gpu_div(const int N, const double* a, + const double* b, double* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + div_kernel<<>>( + N, a, b, y); +} + +template +__global__ void abs_kernel(const int n, const Dtype* a, Dtype* y) { + CUDA_KERNEL_LOOP(index, n) { + y[index] = abs(a[index]); + } +} + +template <> +void caffe_gpu_abs(const int N, const float* a, float* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + abs_kernel<<>>( + N, a, y); +} + +template <> +void 
caffe_gpu_abs<double>(const int N, const double* a, double* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + abs_kernel<double><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>( + N, a, y); +} + + +template <typename Dtype> +__global__ void exp_kernel(const int n, const Dtype* a, Dtype* y) { + CUDA_KERNEL_LOOP(index, n) { + y[index] = exp(a[index]); + } +} + +template <> +void caffe_gpu_exp<float>(const int N, const float* a, float* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + exp_kernel<float><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>( + N, a, y); +} + +template <> +void caffe_gpu_exp<double>(const int N, const double* a, double* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + exp_kernel<double><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>( + N, a, y); +} + +template <typename Dtype> +__global__ void powx_kernel(const int n, const Dtype* a, + const Dtype alpha, Dtype* y) { + CUDA_KERNEL_LOOP(index, n) { + y[index] = pow(a[index], alpha); + } +} + +template <> +void caffe_gpu_powx<float>(const int N, const float* a, + const float alpha, float* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + powx_kernel<float><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>( + N, a, alpha, y); +} + +template <> +void caffe_gpu_powx<double>(const int N, const double* a, + const double alpha, double* y) { + // NOLINT_NEXT_LINE(whitespace/operators) + powx_kernel<double><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>( + N, a, alpha, y); +} + +DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sign, y[index] = (Dtype(0) < x[index]) + - (x[index] < Dtype(0))); +DEFINE_AND_INSTANTIATE_GPU_UNARY_FUNC(sgnbit, y[index] = signbit(x[index])); + +__global__ void popc_kernel(const int n, const float* a, + const float* b, uint8_t* y) { + CUDA_KERNEL_LOOP(index, n) { + y[index] = __popc(static_cast<uint32_t>(a[index]) ^ + static_cast<uint32_t>(b[index])); + } +} + +__global__ void popcll_kernel(const int n, const double* a, + const double* b, uint8_t* y) { + CUDA_KERNEL_LOOP(index, n) { + y[index] = __popcll(static_cast<uint64_t>(a[index]) ^ + static_cast<uint64_t>(b[index])); + } +} + +template <> +uint32_t caffe_gpu_hamming_distance<float>(const int n, const float* x, + const float* y) { + // TODO: Fix caffe_gpu_hamming_distance (see failing unit test + // TestHammingDistanceGPU in test_math_functions.cpp). 
+ NOT_IMPLEMENTED; + thrust::device_vector<uint8_t> popcounts(n); + // NOLINT_NEXT_LINE(whitespace/operators) + popc_kernel<<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>( + n, x, y, thrust::raw_pointer_cast(popcounts.data())); + return thrust::reduce(popcounts.begin(), popcounts.end(), + (uint32_t) 0, thrust::plus<uint32_t>()); +} + +template <> +uint32_t caffe_gpu_hamming_distance<double>(const int n, const double* x, + const double* y) { + // TODO: Fix caffe_gpu_hamming_distance (see failing unit test + // TestHammingDistanceGPU in test_math_functions.cpp). + NOT_IMPLEMENTED; + thrust::device_vector<uint8_t> popcounts(n); + // NOLINT_NEXT_LINE(whitespace/operators) + popcll_kernel<<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>( + n, x, y, thrust::raw_pointer_cast(popcounts.data())); + return thrust::reduce(popcounts.begin(), popcounts.end(), + /* NOLINT_NEXT_LINE(build/include_what_you_use) */ + (uint32_t) 0, thrust::plus<uint32_t>()); +} + +void caffe_gpu_rng_uniform(const int n, unsigned int* r) { + CURAND_CHECK(curandGenerate(Caffe::curand_generator(), r, n)); +} + +template <> +void caffe_gpu_rng_uniform<float>(const int n, const float a, const float b, + float* r) { + CURAND_CHECK(curandGenerateUniform(Caffe::curand_generator(), r, n)); + const float range = b - a; + if (range != static_cast<float>(1)) { + caffe_gpu_scal(n, range, r); + } + if (a != static_cast<float>(0)) { + caffe_gpu_add_scalar(n, a, r); + } +} + +template <> +void caffe_gpu_rng_uniform<double>(const int n, const double a, const double b, + double* r) { + CURAND_CHECK(curandGenerateUniformDouble(Caffe::curand_generator(), r, n)); + const double range = b - a; + if (range != static_cast<double>(1)) { + caffe_gpu_scal(n, range, r); + } + if (a != static_cast<double>(0)) { + caffe_gpu_add_scalar(n, a, r); + } +} + +template <> +void caffe_gpu_rng_gaussian(const int n, const float mu, const float sigma, + float* r) { + CURAND_CHECK( + curandGenerateNormal(Caffe::curand_generator(), r, n, mu, sigma)); +} + +template <> +void caffe_gpu_rng_gaussian(const int n, const double mu, const double sigma, + double* r) { + CURAND_CHECK( + 
curandGenerateNormalDouble(Caffe::curand_generator(), r, n, mu, sigma)); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/caffe/util/modified_permutohedral.cpp b/caffe-crfrnn/src/caffe/util/modified_permutohedral.cpp new file mode 100755 index 00000000..5b6d4642 --- /dev/null +++ b/caffe-crfrnn/src/caffe/util/modified_permutohedral.cpp @@ -0,0 +1,744 @@ +//#include "stdafx.h" +#include "caffe/util/modified_permutohedral.hpp" + +#ifdef __SSE__ +// SSE Permutohedral lattice +# define SSE_PERMUTOHEDRAL +#endif + +#if defined(SSE_PERMUTOHEDRAL) +# include +# include +# ifdef __SSE4_1__ +# include +# endif +#endif + +namespace caffe { +/************************************************/ +/*** Hash Table ***/ +/************************************************/ + +class HashTableCopy{ +protected: + size_t key_size_, filled_, capacity_; + std::vector< short > keys_; + std::vector< int > table_; + void grow(){ + // Create the new memory and copy the values in + int old_capacity = capacity_; + capacity_ *= 2; + std::vector<short> old_keys( (old_capacity+10)*key_size_ ); + std::copy( keys_.begin(), keys_.end(), old_keys.begin() ); + std::vector<int> old_table( capacity_, -1 ); + + // Swap the memory + table_.swap( old_table ); + keys_.swap( old_keys ); + + // Reinsert each element + for( int i=0; i= 0){ + int e = old_table[i]; + size_t h = hash( getKey(e) ) % capacity_; + for(; table_[h] >= 0; h = h= capacity_) grow(); + // Get the hash value + size_t h = hash( k ) % capacity_; + // Find the element with the right key, using linear probing + while(1){ + int e = table_[h]; + if (e==-1){ + if (create){ + // Insert a new key and return the new id + for( size_t i=0; i0; j-- ){ + __m128 cf = f[j-1]*scale_factor[j-1]; + elevated[j] = sm - _mm_set1_ps(j)*cf; + sm += cf; + } + elevated[0] = sm; + + // Find the closest 0-colored simplex through rounding + __m128 sum = Zero; + for( int i=0; i<=d_; i++ ){ + __m128 v = invdplus1 * elevated[i]; +#ifdef __SSE4_1__ + v = _mm_round_ps( v, 
_MM_FROUND_TO_NEAREST_INT ); +#else + v = _mm_cvtepi32_ps( _mm_cvtps_epi32( v ) ); +#endif + rem0[i] = v*dplus1; + sum += v; + } + + // Find the simplex we are in and store it in rank (where rank describes what position coordinate i has in the sorted order of the feature values) + for( int i=0; i<=d_; i++ ) + rank[i] = Zero; + for( int i=0; i0; j-- ){ + float cf = f[j-1]*scale_factor[j-1]; + elevated[j] = sm - j*cf; + sm += cf; + } + elevated[0] = sm; + + // Find the closest 0-colored simplex through rounding + float down_factor = 1.0f / (d_+1); + float up_factor = (d_+1); + int sum = 0; + for( int i=0; i<=d_; i++ ){ + //int rd1 = round( down_factor * elevated[i]); + int rd2; + float v = down_factor * elevated[i]; + float up = ceilf(v)*up_factor; + float down = floorf(v)*up_factor; + if (up - elevated[i] < elevated[i] - down) rd2 = (short)up; + else rd2 = (short)down; + + //if(rd1!=rd2) + // break; + + rem0[i] = rd2; + sum += rd2*down_factor; + } + + // Find the simplex we are in and store it in rank (where rank describes what position coordinate i has in the sorted order of the feature values) + for( int i=0; i<=d_; i++ ) + rank[i] = 0; + for( int i=0; i d_ ){ + rank[i] -= d_+1; + rem0[i] -= d_+1; + } + } + + // Compute the barycentric coordinates (p.10 in [Adams et al. 2010]) + for( int i=0; i<=d_+1; i++ ) + barycentric[i] = 0; + for( int i=0; i<=d_; i++ ){ + float v = (elevated[i] - rem0[i])*down_factor; + barycentric[d_-rank[i] ] += v; + barycentric[d_-rank[i]+1] -= v; + } + // Wrap around + barycentric[0] += 1.0 + barycentric[d_+1]; + + // Compute all vertices and their offset + for( int remainder=0; remainder<=d_; remainder++ ){ + for( int i=0; i 0 (used for blurring) + float * values = new float[ (M_+2)*value_size ]; + float * new_values = new float[ (M_+2)*value_size ]; + + for( int i=0; i<(M_+2)*value_size; i++ ) + values[i] = new_values[i] = 0; + + // Splatting + for( int i=0; i=0; reverse?j--:j++ ){ + for( int i=0; i 0 (used for blurring) + float * values 
= new float[ (M_+2)*value_size ]; + float * new_values = new float[ (M_+2)*value_size ]; + + for( int i=0; i<(M_+2)*value_size; i++ ) + values[i] = new_values[i] = 0; + + // Splatting + for( int i=0; i(in[k*N_ + i]); + } + } + + for( int j=reverse?d_:0; j<=d_ && j>=0; reverse?j--:j++ ){ + for( int i=0; i 0 (used for blurring) + __m128 * sse_val = (__m128*) _mm_malloc( sse_value_size*sizeof(__m128), 16 ); + __m128 * values = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 ); + __m128 * new_values = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 ); + + __m128 Zero = _mm_set1_ps( 0 ); + + for( int i=0; i<(M_+2)*sse_value_size; i++ ) + values[i] = new_values[i] = Zero; + for( int i=0; i=0; reverse?j--:j++ ){ + for( int i=0; i 0 (used for blurring) + __m128 * sse_val = (__m128*) _mm_malloc( sse_value_size*sizeof(__m128), 16 ); + __m128 * values = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 ); + __m128 * new_values = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 ); + + __m128 Zero = _mm_set1_ps( 0 ); + + for( int i=0; i<(M_+2)*sse_value_size; i++ ) + values[i] = new_values[i] = Zero; + for( int i=0; i(in[s*N_ + i]); + } + memcpy(sse_val, sdp_temp, value_size*sizeof(float)); + + for( int j=0; j<=d_; j++ ){ + int o = offset_[i*(d_+1)+j]+1; + __m128 w = _mm_set1_ps( barycentric_[i*(d_+1)+j] ); + for( int k=0; k=0; reverse?j--:j++ ){ + for( int i=0; i +#include +#include + +#include +#include + +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/io.hpp" +#include "caffe/util/upgrade_proto.hpp" + +namespace caffe { + +bool NetNeedsUpgrade(const NetParameter& net_param) { + for (int i = 0; i < net_param.layers_size(); ++i) { + if (net_param.layers(i).has_layer()) { + return true; + } + } + return false; +} + +bool UpgradeV0Net(const NetParameter& v0_net_param_padding_layers, + NetParameter* net_param) { + // First upgrade padding layers to padded conv layers. 
+ NetParameter v0_net_param; + UpgradeV0PaddingLayers(v0_net_param_padding_layers, &v0_net_param); + // Now upgrade layer parameters. + bool is_fully_compatible = true; + net_param->Clear(); + if (v0_net_param.has_name()) { + net_param->set_name(v0_net_param.name()); + } + for (int i = 0; i < v0_net_param.layers_size(); ++i) { + is_fully_compatible &= UpgradeLayerParameter(v0_net_param.layers(i), + net_param->add_layers()); + } + for (int i = 0; i < v0_net_param.input_size(); ++i) { + net_param->add_input(v0_net_param.input(i)); + } + for (int i = 0; i < v0_net_param.input_dim_size(); ++i) { + net_param->add_input_dim(v0_net_param.input_dim(i)); + } + if (v0_net_param.has_force_backward()) { + net_param->set_force_backward(v0_net_param.force_backward()); + } + return is_fully_compatible; +} + +void UpgradeV0PaddingLayers(const NetParameter& param, + NetParameter* param_upgraded_pad) { + // Copy everything other than the layers from the original param. + param_upgraded_pad->Clear(); + param_upgraded_pad->CopyFrom(param); + param_upgraded_pad->clear_layers(); + // Figure out which layer each bottom blob comes from. + map<string, int> blob_name_to_last_top_idx; + for (int i = 0; i < param.input_size(); ++i) { + const string& blob_name = param.input(i); + blob_name_to_last_top_idx[blob_name] = -1; + } + for (int i = 0; i < param.layers_size(); ++i) { + const LayerParameter& layer_connection = param.layers(i); + const V0LayerParameter& layer_param = layer_connection.layer(); + // Add the layer to the new net, unless it's a padding layer. 
+ if (layer_param.type() != "padding") { + param_upgraded_pad->add_layers()->CopyFrom(layer_connection); + } + for (int j = 0; j < layer_connection.bottom_size(); ++j) { + const string& blob_name = layer_connection.bottom(j); + if (blob_name_to_last_top_idx.find(blob_name) == + blob_name_to_last_top_idx.end()) { + LOG(FATAL) << "Unknown blob input " << blob_name << " to layer " << j; + } + const int top_idx = blob_name_to_last_top_idx[blob_name]; + if (top_idx == -1) { + continue; + } + LayerParameter source_layer = param.layers(top_idx); + if (source_layer.layer().type() == "padding") { + // This layer has a padding layer as input -- check that it is a conv + // layer or a pooling layer and takes only one input. Also check that + // the padding layer input has only one input and one output. Other + // cases have undefined behavior in Caffe. + CHECK((layer_param.type() == "conv") || (layer_param.type() == "pool")) + << "Padding layer input to " + "non-convolutional / non-pooling layer type " + << layer_param.type(); + CHECK_EQ(layer_connection.bottom_size(), 1) + << "Conv Layer takes a single blob as input."; + CHECK_EQ(source_layer.bottom_size(), 1) + << "Padding Layer takes a single blob as input."; + CHECK_EQ(source_layer.top_size(), 1) + << "Padding Layer produces a single blob as output."; + int layer_index = param_upgraded_pad->layers_size() - 1; + param_upgraded_pad->mutable_layers(layer_index)->mutable_layer() + ->set_pad(source_layer.layer().pad()); + param_upgraded_pad->mutable_layers(layer_index) + ->set_bottom(j, source_layer.bottom(0)); + } + } + for (int j = 0; j < layer_connection.top_size(); ++j) { + const string& blob_name = layer_connection.top(j); + blob_name_to_last_top_idx[blob_name] = i; + } + } +} + +bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection, + LayerParameter* layer_param) { + bool is_fully_compatible = true; + layer_param->Clear(); + for (int i = 0; i < v0_layer_connection.bottom_size(); ++i) { + 
layer_param->add_bottom(v0_layer_connection.bottom(i)); + } + for (int i = 0; i < v0_layer_connection.top_size(); ++i) { + layer_param->add_top(v0_layer_connection.top(i)); + } + if (v0_layer_connection.has_layer()) { + const V0LayerParameter& v0_layer_param = v0_layer_connection.layer(); + if (v0_layer_param.has_name()) { + layer_param->set_name(v0_layer_param.name()); + } + const string& type = v0_layer_param.type(); + if (v0_layer_param.has_type()) { + layer_param->set_type(UpgradeV0LayerType(type)); + } + for (int i = 0; i < v0_layer_param.blobs_size(); ++i) { + layer_param->add_blobs()->CopyFrom(v0_layer_param.blobs(i)); + } + for (int i = 0; i < v0_layer_param.blobs_lr_size(); ++i) { + layer_param->add_blobs_lr(v0_layer_param.blobs_lr(i)); + } + for (int i = 0; i < v0_layer_param.weight_decay_size(); ++i) { + layer_param->add_weight_decay(v0_layer_param.weight_decay(i)); + } + if (v0_layer_param.has_num_output()) { + if (type == "conv") { + layer_param->mutable_convolution_param()->set_num_output( + v0_layer_param.num_output()); + } else if (type == "innerproduct") { + layer_param->mutable_inner_product_param()->set_num_output( + v0_layer_param.num_output()); + } else { + LOG(ERROR) << "Unknown parameter num_output for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_biasterm()) { + if (type == "conv") { + layer_param->mutable_convolution_param()->set_bias_term( + v0_layer_param.biasterm()); + } else if (type == "innerproduct") { + layer_param->mutable_inner_product_param()->set_bias_term( + v0_layer_param.biasterm()); + } else { + LOG(ERROR) << "Unknown parameter biasterm for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_weight_filler()) { + if (type == "conv") { + layer_param->mutable_convolution_param()-> + mutable_weight_filler()->CopyFrom(v0_layer_param.weight_filler()); + } else if (type == "innerproduct") { + layer_param->mutable_inner_product_param()-> + 
mutable_weight_filler()->CopyFrom(v0_layer_param.weight_filler()); + } else { + LOG(ERROR) << "Unknown parameter weight_filler for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_bias_filler()) { + if (type == "conv") { + layer_param->mutable_convolution_param()-> + mutable_bias_filler()->CopyFrom(v0_layer_param.bias_filler()); + } else if (type == "innerproduct") { + layer_param->mutable_inner_product_param()-> + mutable_bias_filler()->CopyFrom(v0_layer_param.bias_filler()); + } else { + LOG(ERROR) << "Unknown parameter bias_filler for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_pad()) { + if (type == "conv") { + layer_param->mutable_convolution_param()->set_pad(v0_layer_param.pad()); + } else if (type == "pool") { + layer_param->mutable_pooling_param()->set_pad(v0_layer_param.pad()); + } else { + LOG(ERROR) << "Unknown parameter pad for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_kernelsize()) { + if (type == "conv") { + layer_param->mutable_convolution_param()->set_kernel_size( + v0_layer_param.kernelsize()); + } else if (type == "pool") { + layer_param->mutable_pooling_param()->set_kernel_size( + v0_layer_param.kernelsize()); + } else { + LOG(ERROR) << "Unknown parameter kernelsize for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_group()) { + if (type == "conv") { + layer_param->mutable_convolution_param()->set_group( + v0_layer_param.group()); + } else { + LOG(ERROR) << "Unknown parameter group for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_stride()) { + if (type == "conv") { + layer_param->mutable_convolution_param()->set_stride( + v0_layer_param.stride()); + } else if (type == "pool") { + layer_param->mutable_pooling_param()->set_stride( + v0_layer_param.stride()); + } else { + LOG(ERROR) << "Unknown parameter stride for layer type " << type; + 
is_fully_compatible = false; + } + } + if (v0_layer_param.has_pool()) { + if (type == "pool") { + V0LayerParameter_PoolMethod pool = v0_layer_param.pool(); + switch (pool) { + case V0LayerParameter_PoolMethod_MAX: + layer_param->mutable_pooling_param()->set_pool( + PoolingParameter_PoolMethod_MAX); + break; + case V0LayerParameter_PoolMethod_AVE: + layer_param->mutable_pooling_param()->set_pool( + PoolingParameter_PoolMethod_AVE); + break; + case V0LayerParameter_PoolMethod_STOCHASTIC: + layer_param->mutable_pooling_param()->set_pool( + PoolingParameter_PoolMethod_STOCHASTIC); + break; + default: + LOG(ERROR) << "Unknown pool method " << pool; + is_fully_compatible = false; + } + } else { + LOG(ERROR) << "Unknown parameter pool for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_dropout_ratio()) { + if (type == "dropout") { + layer_param->mutable_dropout_param()->set_dropout_ratio( + v0_layer_param.dropout_ratio()); + } else { + LOG(ERROR) << "Unknown parameter dropout_ratio for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_local_size()) { + if (type == "lrn") { + layer_param->mutable_lrn_param()->set_local_size( + v0_layer_param.local_size()); + } else { + LOG(ERROR) << "Unknown parameter local_size for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_alpha()) { + if (type == "lrn") { + layer_param->mutable_lrn_param()->set_alpha(v0_layer_param.alpha()); + } else { + LOG(ERROR) << "Unknown parameter alpha for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_beta()) { + if (type == "lrn") { + layer_param->mutable_lrn_param()->set_beta(v0_layer_param.beta()); + } else { + LOG(ERROR) << "Unknown parameter beta for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_k()) { + if (type == "lrn") { + layer_param->mutable_lrn_param()->set_k(v0_layer_param.k()); + } else { + LOG(ERROR) << 
"Unknown parameter k for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_source()) { + if (type == "data") { + layer_param->mutable_data_param()->set_source(v0_layer_param.source()); + } else if (type == "hdf5_data") { + layer_param->mutable_hdf5_data_param()->set_source( + v0_layer_param.source()); + } else if (type == "images") { + layer_param->mutable_image_data_param()->set_source( + v0_layer_param.source()); + } else if (type == "window_data") { + layer_param->mutable_window_data_param()->set_source( + v0_layer_param.source()); + } else if (type == "infogain_loss") { + layer_param->mutable_infogain_loss_param()->set_source( + v0_layer_param.source()); + } else { + LOG(ERROR) << "Unknown parameter source for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_scale()) { + layer_param->mutable_transform_param()-> + set_scale(v0_layer_param.scale()); + } + if (v0_layer_param.has_meanfile()) { + layer_param->mutable_transform_param()-> + set_mean_file(v0_layer_param.meanfile()); + } + if (v0_layer_param.has_batchsize()) { + if (type == "data") { + layer_param->mutable_data_param()->set_batch_size( + v0_layer_param.batchsize()); + } else if (type == "hdf5_data") { + layer_param->mutable_hdf5_data_param()->set_batch_size( + v0_layer_param.batchsize()); + } else if (type == "images") { + layer_param->mutable_image_data_param()->set_batch_size( + v0_layer_param.batchsize()); + } else if (type == "window_data") { + layer_param->mutable_window_data_param()->set_batch_size( + v0_layer_param.batchsize()); + } else { + LOG(ERROR) << "Unknown parameter batchsize for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_cropsize()) { + layer_param->mutable_transform_param()-> + set_crop_size(v0_layer_param.cropsize()); + } + if (v0_layer_param.has_mirror()) { + layer_param->mutable_transform_param()-> + set_mirror(v0_layer_param.mirror()); + } + if 
(v0_layer_param.has_rand_skip()) { + if (type == "data") { + layer_param->mutable_data_param()->set_rand_skip( + v0_layer_param.rand_skip()); + } else if (type == "images") { + layer_param->mutable_image_data_param()->set_rand_skip( + v0_layer_param.rand_skip()); + } else { + LOG(ERROR) << "Unknown parameter rand_skip for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_shuffle_images()) { + if (type == "images") { + layer_param->mutable_image_data_param()->set_shuffle( + v0_layer_param.shuffle_images()); + } else { + LOG(ERROR) << "Unknown parameter shuffle for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_new_height()) { + if (type == "images") { + layer_param->mutable_image_data_param()->set_new_height( + v0_layer_param.new_height()); + } else { + LOG(ERROR) << "Unknown parameter new_height for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_new_width()) { + if (type == "images") { + layer_param->mutable_image_data_param()->set_new_width( + v0_layer_param.new_width()); + } else { + LOG(ERROR) << "Unknown parameter new_width for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_concat_dim()) { + if (type == "concat") { + layer_param->mutable_concat_param()->set_concat_dim( + v0_layer_param.concat_dim()); + } else { + LOG(ERROR) << "Unknown parameter concat_dim for layer type " << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_det_fg_threshold()) { + if (type == "window_data") { + layer_param->mutable_window_data_param()->set_fg_threshold( + v0_layer_param.det_fg_threshold()); + } else { + LOG(ERROR) << "Unknown parameter det_fg_threshold for layer type " + << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_det_bg_threshold()) { + if (type == "window_data") { + layer_param->mutable_window_data_param()->set_bg_threshold( + v0_layer_param.det_bg_threshold()); + } else { + 
LOG(ERROR) << "Unknown parameter det_bg_threshold for layer type " + << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_det_fg_fraction()) { + if (type == "window_data") { + layer_param->mutable_window_data_param()->set_fg_fraction( + v0_layer_param.det_fg_fraction()); + } else { + LOG(ERROR) << "Unknown parameter det_fg_fraction for layer type " + << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_det_context_pad()) { + if (type == "window_data") { + layer_param->mutable_window_data_param()->set_context_pad( + v0_layer_param.det_context_pad()); + } else { + LOG(ERROR) << "Unknown parameter det_context_pad for layer type " + << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_det_crop_mode()) { + if (type == "window_data") { + layer_param->mutable_window_data_param()->set_crop_mode( + v0_layer_param.det_crop_mode()); + } else { + LOG(ERROR) << "Unknown parameter det_crop_mode for layer type " + << type; + is_fully_compatible = false; + } + } + if (v0_layer_param.has_hdf5_output_param()) { + if (type == "hdf5_output") { + layer_param->mutable_hdf5_output_param()->CopyFrom( + v0_layer_param.hdf5_output_param()); + } else { + LOG(ERROR) << "Unknown parameter hdf5_output_param for layer type " + << type; + is_fully_compatible = false; + } + } + } + return is_fully_compatible; +} + +LayerParameter_LayerType UpgradeV0LayerType(const string& type) { + if (type == "accuracy") { + return LayerParameter_LayerType_ACCURACY; + } else if (type == "bnll") { + return LayerParameter_LayerType_BNLL; + } else if (type == "concat") { + return LayerParameter_LayerType_CONCAT; + } else if (type == "conv") { + return LayerParameter_LayerType_CONVOLUTION; + } else if (type == "data") { + return LayerParameter_LayerType_DATA; + } else if (type == "dropout") { + return LayerParameter_LayerType_DROPOUT; + } else if (type == "euclidean_loss") { + return LayerParameter_LayerType_EUCLIDEAN_LOSS; + } else if (type == "flatten") { 
+ return LayerParameter_LayerType_FLATTEN; + } else if (type == "hdf5_data") { + return LayerParameter_LayerType_HDF5_DATA; + } else if (type == "hdf5_output") { + return LayerParameter_LayerType_HDF5_OUTPUT; + } else if (type == "im2col") { + return LayerParameter_LayerType_IM2COL; + } else if (type == "images") { + return LayerParameter_LayerType_IMAGE_DATA; + } else if (type == "infogain_loss") { + return LayerParameter_LayerType_INFOGAIN_LOSS; + } else if (type == "innerproduct") { + return LayerParameter_LayerType_INNER_PRODUCT; + } else if (type == "lrn") { + return LayerParameter_LayerType_LRN; + } else if (type == "multinomial_logistic_loss") { + return LayerParameter_LayerType_MULTINOMIAL_LOGISTIC_LOSS; + } else if (type == "pool") { + return LayerParameter_LayerType_POOLING; + } else if (type == "relu") { + return LayerParameter_LayerType_RELU; + } else if (type == "sigmoid") { + return LayerParameter_LayerType_SIGMOID; + } else if (type == "softmax") { + return LayerParameter_LayerType_SOFTMAX; + } else if (type == "softmax_loss") { + return LayerParameter_LayerType_SOFTMAX_LOSS; + } else if (type == "split") { + return LayerParameter_LayerType_SPLIT; + } else if (type == "tanh") { + return LayerParameter_LayerType_TANH; + } else if (type == "window_data") { + return LayerParameter_LayerType_WINDOW_DATA; + } else { + LOG(FATAL) << "Unknown layer name: " << type; + return LayerParameter_LayerType_NONE; + } +} + +bool NetNeedsDataUpgrade(const NetParameter& net_param) { + for (int i = 0; i < net_param.layers_size(); ++i) { + if (net_param.layers(i).type() == LayerParameter_LayerType_DATA) { + DataParameter layer_param = net_param.layers(i).data_param(); + if (layer_param.has_scale()) { return true; } + if (layer_param.has_mean_file()) { return true; } + if (layer_param.has_crop_size()) { return true; } + if (layer_param.has_mirror()) { return true; } + } + if (net_param.layers(i).type() == LayerParameter_LayerType_IMAGE_DATA) { + ImageDataParameter 
layer_param = net_param.layers(i).image_data_param(); + if (layer_param.has_scale()) { return true; } + if (layer_param.has_mean_file()) { return true; } + if (layer_param.has_crop_size()) { return true; } + if (layer_param.has_mirror()) { return true; } + } + if (net_param.layers(i).type() == LayerParameter_LayerType_WINDOW_DATA) { + WindowDataParameter layer_param = net_param.layers(i).window_data_param(); + if (layer_param.has_scale()) { return true; } + if (layer_param.has_mean_file()) { return true; } + if (layer_param.has_crop_size()) { return true; } + if (layer_param.has_mirror()) { return true; } + } + } + return false; +} + +#define CONVERT_LAYER_TRANSFORM_PARAM(TYPE, Name, param_name) \ + do { \ + if (net_param->layers(i).type() == LayerParameter_LayerType_##TYPE) { \ + Name##Parameter* layer_param = \ + net_param->mutable_layers(i)->mutable_##param_name##_param(); \ + TransformationParameter* transform_param = \ + net_param->mutable_layers(i)->mutable_transform_param(); \ + if (layer_param->has_scale()) { \ + transform_param->set_scale(layer_param->scale()); \ + layer_param->clear_scale(); \ + } \ + if (layer_param->has_mean_file()) { \ + transform_param->set_mean_file(layer_param->mean_file()); \ + layer_param->clear_mean_file(); \ + } \ + if (layer_param->has_crop_size()) { \ + transform_param->set_crop_size(layer_param->crop_size()); \ + layer_param->clear_crop_size(); \ + } \ + if (layer_param->has_mirror()) { \ + transform_param->set_mirror(layer_param->mirror()); \ + layer_param->clear_mirror(); \ + } \ + } \ + } while (0) + +void UpgradeNetDataTransformation(NetParameter* net_param) { + for (int i = 0; i < net_param->layers_size(); ++i) { + CONVERT_LAYER_TRANSFORM_PARAM(DATA, Data, data); + CONVERT_LAYER_TRANSFORM_PARAM(IMAGE_DATA, ImageData, image_data); + CONVERT_LAYER_TRANSFORM_PARAM(WINDOW_DATA, WindowData, window_data); + } +} + +void NetParameterToPrettyPrint(const NetParameter& param, + NetParameterPrettyPrint* pretty_param) { + 
pretty_param->Clear(); + if (param.has_name()) { + pretty_param->set_name(param.name()); + } + if (param.has_force_backward()) { + pretty_param->set_force_backward(param.force_backward()); + } + for (int i = 0; i < param.input_size(); ++i) { + pretty_param->add_input(param.input(i)); + } + for (int i = 0; i < param.input_dim_size(); ++i) { + pretty_param->add_input_dim(param.input_dim(i)); + } + for (int i = 0; i < param.layers_size(); ++i) { + pretty_param->add_layers()->CopyFrom(param.layers(i)); + } +} + +void UpgradeNetAsNeeded(const string& param_file, NetParameter* param) { + if (NetNeedsUpgrade(*param)) { + // NetParameter was specified using the old style (V0LayerParameter); try to + // upgrade it. + LOG(ERROR) << "Attempting to upgrade input file specified using deprecated " + << "V0LayerParameter: " << param_file; + NetParameter original_param(*param); + if (!UpgradeV0Net(original_param, param)) { + LOG(ERROR) << "Warning: had one or more problems upgrading " + << "V0NetParameter to NetParameter (see above); continuing anyway."; + } else { + LOG(INFO) << "Successfully upgraded file specified using deprecated " + << "V0LayerParameter"; + } + LOG(ERROR) << "Note that future Caffe releases will not support " + << "V0NetParameter; use ./build/tools/upgrade_net_proto_text for " + << "prototxt and ./build/tools/upgrade_net_proto_binary for model " + << "weights upgrade this and any other net protos to the new format."; + } + // NetParameter uses old style data transformation fields; try to upgrade it. 
+ if (NetNeedsDataUpgrade(*param)) { + LOG(ERROR) << "Attempting to upgrade input file specified using deprecated " + << "transformation parameters: " << param_file; + UpgradeNetDataTransformation(param); + LOG(INFO) << "Successfully upgraded file specified using deprecated " + << "data transformation parameters."; + LOG(ERROR) << "Note that future Caffe releases will only support " + << "transform_param messages for transformation fields."; + } +} + +void ReadNetParamsFromTextFileOrDie(const string& param_file, + NetParameter* param) { + CHECK(ReadProtoFromTextFile(param_file, param)) + << "Failed to parse NetParameter file: " << param_file; + UpgradeNetAsNeeded(param_file, param); +} + +void ReadNetParamsFromBinaryFileOrDie(const string& param_file, + NetParameter* param) { + CHECK(ReadProtoFromBinaryFile(param_file, param)) + << "Failed to parse NetParameter file: " << param_file; + UpgradeNetAsNeeded(param_file, param); +} + +} // namespace caffe diff --git a/caffe-crfrnn/src/gtest/CMakeLists.txt b/caffe-crfrnn/src/gtest/CMakeLists.txt new file mode 100644 index 00000000..ef7ff7ed --- /dev/null +++ b/caffe-crfrnn/src/gtest/CMakeLists.txt @@ -0,0 +1,5 @@ +add_library(gtest STATIC EXCLUDE_FROM_ALL gtest.h gtest-all.cpp) +caffe_default_properties(gtest) + +#add_library(gtest_main gtest_main.cc) +#target_link_libraries(gtest_main gtest) diff --git a/caffe-crfrnn/src/gtest/gtest-all.cpp b/caffe-crfrnn/src/gtest/gtest-all.cpp new file mode 100644 index 00000000..92619741 --- /dev/null +++ b/caffe-crfrnn/src/gtest/gtest-all.cpp @@ -0,0 +1,9117 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: mheule@google.com (Markus Heule) +// +// Google C++ Testing Framework (Google Test) +// +// Sometimes it's desirable to build Google Test by compiling a single file. +// This file serves this purpose. + +// This line ensures that gtest.h can be compiled on its own, even +// when it's fused. +#include "gtest/gtest.h" + +// The following lines pull in the real gtest *.cc files. +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) +// +// The Google C++ Testing Framework (Google Test) + +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. 
nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) +// +// Utilities for testing Google Test itself and code that uses Google Test +// (e.g. frameworks built on top of Google Test). + +#ifndef GTEST_INCLUDE_GTEST_GTEST_SPI_H_ +#define GTEST_INCLUDE_GTEST_GTEST_SPI_H_ + + +namespace testing { + +// This helper class can be used to mock out Google Test failure reporting +// so that we can test Google Test or code that builds on Google Test. +// +// An object of this class appends a TestPartResult object to the +// TestPartResultArray object given in the constructor whenever a Google Test +// failure is reported. It can either intercept only failures that are +// generated in the same thread that created this object or it can intercept +// all generated failures. The scope of this mock object can be controlled with +// the second argument to the two arguments constructor. +class GTEST_API_ ScopedFakeTestPartResultReporter + : public TestPartResultReporterInterface { + public: + // The two possible mocking modes of this object. 
+ enum InterceptMode { + INTERCEPT_ONLY_CURRENT_THREAD, // Intercepts only thread local failures. + INTERCEPT_ALL_THREADS // Intercepts all failures. + }; + + // The c'tor sets this object as the test part result reporter used + // by Google Test. The 'result' parameter specifies where to report the + // results. This reporter will only catch failures generated in the current + // thread. DEPRECATED + explicit ScopedFakeTestPartResultReporter(TestPartResultArray* result); + + // Same as above, but you can choose the interception scope of this object. + ScopedFakeTestPartResultReporter(InterceptMode intercept_mode, + TestPartResultArray* result); + + // The d'tor restores the previous test part result reporter. + virtual ~ScopedFakeTestPartResultReporter(); + + // Appends the TestPartResult object to the TestPartResultArray + // received in the constructor. + // + // This method is from the TestPartResultReporterInterface + // interface. + virtual void ReportTestPartResult(const TestPartResult& result); + private: + void Init(); + + const InterceptMode intercept_mode_; + TestPartResultReporterInterface* old_reporter_; + TestPartResultArray* const result_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ScopedFakeTestPartResultReporter); +}; + +namespace internal { + +// A helper class for implementing EXPECT_FATAL_FAILURE() and +// EXPECT_NONFATAL_FAILURE(). Its destructor verifies that the given +// TestPartResultArray contains exactly one failure that has the given +// type and contains the given substring. If that's not the case, a +// non-fatal failure will be generated. +class GTEST_API_ SingleFailureChecker { + public: + // The constructor remembers the arguments. 
+ SingleFailureChecker(const TestPartResultArray* results, + TestPartResult::Type type, + const string& substr); + ~SingleFailureChecker(); + private: + const TestPartResultArray* const results_; + const TestPartResult::Type type_; + const string substr_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(SingleFailureChecker); +}; + +} // namespace internal + +} // namespace testing + +// A set of macros for testing Google Test assertions or code that's expected +// to generate Google Test fatal failures. It verifies that the given +// statement will cause exactly one fatal Google Test failure with 'substr' +// being part of the failure message. +// +// There are two different versions of this macro. EXPECT_FATAL_FAILURE only +// affects and considers failures generated in the current thread and +// EXPECT_FATAL_FAILURE_ON_ALL_THREADS does the same but for all threads. +// +// The verification of the assertion is done correctly even when the statement +// throws an exception or aborts the current function. +// +// Known restrictions: +// - 'statement' cannot reference local non-static variables or +// non-static members of the current object. +// - 'statement' cannot return a value. +// - You cannot stream a failure message to this macro. +// +// Note that even though the implementations of the following two +// macros are much alike, we cannot refactor them to use a common +// helper macro, due to some peculiarity in how the preprocessor +// works. The AcceptsMacroThatExpandsToUnprotectedComma test in +// gtest_unittest.cc will fail to compile if we do that. 
+#define EXPECT_FATAL_FAILURE(statement, substr) \ + do { \ + class GTestExpectFatalFailureHelper {\ + public:\ + static void Execute() { statement; }\ + };\ + ::testing::TestPartResultArray gtest_failures;\ + ::testing::internal::SingleFailureChecker gtest_checker(\ + &gtest_failures, ::testing::TestPartResult::kFatalFailure, (substr));\ + {\ + ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ + ::testing::ScopedFakeTestPartResultReporter:: \ + INTERCEPT_ONLY_CURRENT_THREAD, &gtest_failures);\ + GTestExpectFatalFailureHelper::Execute();\ + }\ + } while (::testing::internal::AlwaysFalse()) + +#define EXPECT_FATAL_FAILURE_ON_ALL_THREADS(statement, substr) \ + do { \ + class GTestExpectFatalFailureHelper {\ + public:\ + static void Execute() { statement; }\ + };\ + ::testing::TestPartResultArray gtest_failures;\ + ::testing::internal::SingleFailureChecker gtest_checker(\ + &gtest_failures, ::testing::TestPartResult::kFatalFailure, (substr));\ + {\ + ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ + ::testing::ScopedFakeTestPartResultReporter:: \ + INTERCEPT_ALL_THREADS, &gtest_failures);\ + GTestExpectFatalFailureHelper::Execute();\ + }\ + } while (::testing::internal::AlwaysFalse()) + +// A macro for testing Google Test assertions or code that's expected to +// generate Google Test non-fatal failures. It asserts that the given +// statement will cause exactly one non-fatal Google Test failure with 'substr' +// being part of the failure message. +// +// There are two different versions of this macro. EXPECT_NONFATAL_FAILURE only +// affects and considers failures generated in the current thread and +// EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS does the same but for all threads. +// +// 'statement' is allowed to reference local variables and members of +// the current object. +// +// The verification of the assertion is done correctly even when the statement +// throws an exception or aborts the current function.
+// +// Known restrictions: +// - You cannot stream a failure message to this macro. +// +// Note that even though the implementations of the following two +// macros are much alike, we cannot refactor them to use a common +// helper macro, due to some peculiarity in how the preprocessor +// works. If we do that, the code won't compile when the user gives +// EXPECT_NONFATAL_FAILURE() a statement that contains a macro that +// expands to code containing an unprotected comma. The +// AcceptsMacroThatExpandsToUnprotectedComma test in gtest_unittest.cc +// catches that. +// +// For the same reason, we have to write +// if (::testing::internal::AlwaysTrue()) { statement; } +// instead of +// GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement) +// to avoid an MSVC warning on unreachable code. +#define EXPECT_NONFATAL_FAILURE(statement, substr) \ + do {\ + ::testing::TestPartResultArray gtest_failures;\ + ::testing::internal::SingleFailureChecker gtest_checker(\ + &gtest_failures, ::testing::TestPartResult::kNonFatalFailure, \ + (substr));\ + {\ + ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ + ::testing::ScopedFakeTestPartResultReporter:: \ + INTERCEPT_ONLY_CURRENT_THREAD, &gtest_failures);\ + if (::testing::internal::AlwaysTrue()) { statement; }\ + }\ + } while (::testing::internal::AlwaysFalse()) + +#define EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS(statement, substr) \ + do {\ + ::testing::TestPartResultArray gtest_failures;\ + ::testing::internal::SingleFailureChecker gtest_checker(\ + &gtest_failures, ::testing::TestPartResult::kNonFatalFailure, \ + (substr));\ + {\ + ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ + ::testing::ScopedFakeTestPartResultReporter::INTERCEPT_ALL_THREADS,\ + &gtest_failures);\ + if (::testing::internal::AlwaysTrue()) { statement; }\ + }\ + } while (::testing::internal::AlwaysFalse()) + +#endif // GTEST_INCLUDE_GTEST_GTEST_SPI_H_ + +#include <ctype.h> +#include <math.h> +#include <stdarg.h> +#include <stdio.h> +#include <stdlib.h> +#include <wchar.h> +#include <wctype.h> + +#include <algorithm>
+#include <ostream> // NOLINT +#include <sstream> +#include <vector> + +#if GTEST_OS_LINUX + +// TODO(kenton@google.com): Use autoconf to detect availability of +// gettimeofday(). +# define GTEST_HAS_GETTIMEOFDAY_ 1 + +# include <fcntl.h> // NOLINT +# include <limits.h> // NOLINT +# include <sched.h> // NOLINT +// Declares vsnprintf(). This header is not available on Windows. +# include <strings.h> // NOLINT +# include <sys/mman.h> // NOLINT +# include <sys/time.h> // NOLINT +# include <unistd.h> // NOLINT +# include <string> + +#elif GTEST_OS_SYMBIAN +# define GTEST_HAS_GETTIMEOFDAY_ 1 +# include <sys/time.h> // NOLINT + +#elif GTEST_OS_ZOS +# define GTEST_HAS_GETTIMEOFDAY_ 1 +# include <sys/time.h> // NOLINT + +// On z/OS we additionally need strings.h for strcasecmp. +# include <strings.h> // NOLINT + +#elif GTEST_OS_WINDOWS_MOBILE // We are on Windows CE. + +# include <windows.h> // NOLINT + +#elif GTEST_OS_WINDOWS // We are on Windows proper. + +# include <io.h> // NOLINT +# include <sys/timeb.h> // NOLINT +# include <sys/types.h> // NOLINT +# include <sys/stat.h> // NOLINT + +# if GTEST_OS_WINDOWS_MINGW +// MinGW has gettimeofday() but not _ftime64(). +// TODO(kenton@google.com): Use autoconf to detect availability of +// gettimeofday(). +// TODO(kenton@google.com): There are other ways to get the time on +// Windows, like GetTickCount() or GetSystemTimeAsFileTime(). MinGW +// supports these. Consider using them instead. +# define GTEST_HAS_GETTIMEOFDAY_ 1 +# include <sys/time.h> // NOLINT +# endif // GTEST_OS_WINDOWS_MINGW + +// cpplint thinks that the header is already included, so we want to +// silence it. +# include <windows.h> // NOLINT + +#else + +// Assume other platforms have gettimeofday(). +// TODO(kenton@google.com): Use autoconf to detect availability of +// gettimeofday(). +# define GTEST_HAS_GETTIMEOFDAY_ 1 + +// cpplint thinks that the header is already included, so we want to +// silence it. +# include <sys/time.h> // NOLINT +# include <unistd.h> // NOLINT + +#endif // GTEST_OS_LINUX + +#if GTEST_HAS_EXCEPTIONS +# include <stdexcept> +#endif + +#if GTEST_CAN_STREAM_RESULTS_ +# include <arpa/inet.h> // NOLINT +# include <netdb.h> // NOLINT +#endif + +// Indicates that this translation unit is part of Google Test's +// implementation.
It must come before gtest-internal-inl.h is +// included, or there will be a compiler error. This trick is to +// prevent a user from accidentally including gtest-internal-inl.h in +// his code. +#define GTEST_IMPLEMENTATION_ 1 +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Utility functions and classes used by the Google C++ testing framework. +// +// Author: wan@google.com (Zhanyong Wan) +// +// This file contains purely Google Test's internal implementation. 
Please +// DO NOT #INCLUDE IT IN A USER PROGRAM. + +#ifndef GTEST_SRC_GTEST_INTERNAL_INL_H_ +#define GTEST_SRC_GTEST_INTERNAL_INL_H_ + +// GTEST_IMPLEMENTATION_ is defined to 1 iff the current translation unit is +// part of Google Test's implementation; otherwise it's undefined. +#if !GTEST_IMPLEMENTATION_ +// A user is trying to include this from his code - just say no. +# error "gtest-internal-inl.h is part of Google Test's internal implementation." +# error "It must not be included except by Google Test itself." +#endif // GTEST_IMPLEMENTATION_ + +#ifndef _WIN32_WCE +# include <errno.h> +#endif // !_WIN32_WCE +#include <stddef.h> +#include <stdlib.h> // For strtoll/_strtoul64/malloc/free. +#include <string.h> // For memmove. + +#include <algorithm> +#include <string> +#include <vector> + + +#if GTEST_OS_WINDOWS +# include <windows.h> // NOLINT +#endif // GTEST_OS_WINDOWS + + +namespace testing { + +// Declares the flags. +// +// We don't want the users to modify this flag in the code, but want +// Google Test's own unit tests to be able to access it. Therefore we +// declare it here as opposed to in gtest.h. +GTEST_DECLARE_bool_(death_test_use_fork); + +namespace internal { + +// The value of GetTestTypeId() as seen from within the Google Test +// library. This is solely for testing GetTestTypeId(). +GTEST_API_ extern const TypeId kTestTypeIdInGoogleTest; + +// Names of the flags (needed for parsing Google Test flags).
+const char kAlsoRunDisabledTestsFlag[] = "also_run_disabled_tests"; +const char kBreakOnFailureFlag[] = "break_on_failure"; +const char kCatchExceptionsFlag[] = "catch_exceptions"; +const char kColorFlag[] = "color"; +const char kFilterFlag[] = "filter"; +const char kListTestsFlag[] = "list_tests"; +const char kOutputFlag[] = "output"; +const char kPrintTimeFlag[] = "print_time"; +const char kRandomSeedFlag[] = "random_seed"; +const char kRepeatFlag[] = "repeat"; +const char kShuffleFlag[] = "shuffle"; +const char kStackTraceDepthFlag[] = "stack_trace_depth"; +const char kStreamResultToFlag[] = "stream_result_to"; +const char kThrowOnFailureFlag[] = "throw_on_failure"; + +// A valid random seed must be in [1, kMaxRandomSeed]. +const int kMaxRandomSeed = 99999; + +// g_help_flag is true iff the --help flag or an equivalent form is +// specified on the command line. +GTEST_API_ extern bool g_help_flag; + +// Returns the current time in milliseconds. +GTEST_API_ TimeInMillis GetTimeInMillis(); + +// Returns true iff Google Test should use colors in the output. +GTEST_API_ bool ShouldUseColor(bool stdout_is_tty); + +// Formats the given time in milliseconds as seconds. +GTEST_API_ std::string FormatTimeInMillisAsSeconds(TimeInMillis ms); + +// Parses a string for an Int32 flag, in the form of "--flag=value". +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +GTEST_API_ bool ParseInt32Flag( + const char* str, const char* flag, Int32* value); + +// Returns a random seed in range [1, kMaxRandomSeed] based on the +// given --gtest_random_seed flag value. +inline int GetRandomSeedFromFlag(Int32 random_seed_flag) { + const unsigned int raw_seed = (random_seed_flag == 0) ? + static_cast<unsigned int>(GetTimeInMillis()) : + static_cast<unsigned int>(random_seed_flag); + + // Normalizes the actual seed to range [1, kMaxRandomSeed] such that + // it's easy to type.
+ const int normalized_seed = + static_cast<int>((raw_seed - 1U) % + static_cast<unsigned int>(kMaxRandomSeed)) + 1; + return normalized_seed; +} + +// Returns the first valid random seed after 'seed'. The behavior is +// undefined if 'seed' is invalid. The seed after kMaxRandomSeed is +// considered to be 1. +inline int GetNextRandomSeed(int seed) { + GTEST_CHECK_(1 <= seed && seed <= kMaxRandomSeed) + << "Invalid random seed " << seed << " - must be in [1, " + << kMaxRandomSeed << "]."; + const int next_seed = seed + 1; + return (next_seed > kMaxRandomSeed) ? 1 : next_seed; +} + +// This class saves the values of all Google Test flags in its c'tor, and +// restores them in its d'tor. +class GTestFlagSaver { + public: + // The c'tor. + GTestFlagSaver() { + also_run_disabled_tests_ = GTEST_FLAG(also_run_disabled_tests); + break_on_failure_ = GTEST_FLAG(break_on_failure); + catch_exceptions_ = GTEST_FLAG(catch_exceptions); + color_ = GTEST_FLAG(color); + death_test_style_ = GTEST_FLAG(death_test_style); + death_test_use_fork_ = GTEST_FLAG(death_test_use_fork); + filter_ = GTEST_FLAG(filter); + internal_run_death_test_ = GTEST_FLAG(internal_run_death_test); + list_tests_ = GTEST_FLAG(list_tests); + output_ = GTEST_FLAG(output); + print_time_ = GTEST_FLAG(print_time); + random_seed_ = GTEST_FLAG(random_seed); + repeat_ = GTEST_FLAG(repeat); + shuffle_ = GTEST_FLAG(shuffle); + stack_trace_depth_ = GTEST_FLAG(stack_trace_depth); + stream_result_to_ = GTEST_FLAG(stream_result_to); + throw_on_failure_ = GTEST_FLAG(throw_on_failure); + } + + // The d'tor is not virtual. DO NOT INHERIT FROM THIS CLASS.
+ ~GTestFlagSaver() { + GTEST_FLAG(also_run_disabled_tests) = also_run_disabled_tests_; + GTEST_FLAG(break_on_failure) = break_on_failure_; + GTEST_FLAG(catch_exceptions) = catch_exceptions_; + GTEST_FLAG(color) = color_; + GTEST_FLAG(death_test_style) = death_test_style_; + GTEST_FLAG(death_test_use_fork) = death_test_use_fork_; + GTEST_FLAG(filter) = filter_; + GTEST_FLAG(internal_run_death_test) = internal_run_death_test_; + GTEST_FLAG(list_tests) = list_tests_; + GTEST_FLAG(output) = output_; + GTEST_FLAG(print_time) = print_time_; + GTEST_FLAG(random_seed) = random_seed_; + GTEST_FLAG(repeat) = repeat_; + GTEST_FLAG(shuffle) = shuffle_; + GTEST_FLAG(stack_trace_depth) = stack_trace_depth_; + GTEST_FLAG(stream_result_to) = stream_result_to_; + GTEST_FLAG(throw_on_failure) = throw_on_failure_; + } + private: + // Fields for saving the original values of flags. + bool also_run_disabled_tests_; + bool break_on_failure_; + bool catch_exceptions_; + String color_; + String death_test_style_; + bool death_test_use_fork_; + String filter_; + String internal_run_death_test_; + bool list_tests_; + String output_; + bool print_time_; + internal::Int32 random_seed_; + internal::Int32 repeat_; + bool shuffle_; + internal::Int32 stack_trace_depth_; + String stream_result_to_; + bool throw_on_failure_; +} GTEST_ATTRIBUTE_UNUSED_; + +// Converts a Unicode code point to a narrow string in UTF-8 encoding. +// code_point parameter is of type UInt32 because wchar_t may not be +// wide enough to contain a code point. +// The output buffer str must contain at least 32 characters. +// The function returns the address of the output buffer. +// If the code_point is not a valid Unicode code point +// (i.e. outside of Unicode range U+0 to U+10FFFF) it will be output +// as '(Invalid Unicode 0xXXXXXXXX)'. +GTEST_API_ char* CodePointToUtf8(UInt32 code_point, char* str); + +// Converts a wide string to a narrow string in UTF-8 encoding.
+// The wide string is assumed to have the following encoding: +// UTF-16 if sizeof(wchar_t) == 2 (on Windows, Cygwin, Symbian OS) +// UTF-32 if sizeof(wchar_t) == 4 (on Linux) +// Parameter str points to a null-terminated wide string. +// Parameter num_chars may additionally limit the number +// of wchar_t characters processed. -1 is used when the entire string +// should be processed. +// If the string contains code points that are not valid Unicode code points +// (i.e. outside of Unicode range U+0 to U+10FFFF) they will be output +// as '(Invalid Unicode 0xXXXXXXXX)'. If the string is in UTF-16 encoding +// and contains invalid UTF-16 surrogate pairs, values in those pairs +// will be encoded as individual Unicode characters from Basic Multilingual Plane. +GTEST_API_ String WideStringToUtf8(const wchar_t* str, int num_chars); + +// Reads the GTEST_SHARD_STATUS_FILE environment variable, and creates the file +// if the variable is present. If a file already exists at this location, this +// function will write over it. If the variable is present, but the file cannot +// be created, prints an error and exits. +void WriteToShardStatusFileIfNeeded(); + +// Checks whether sharding is enabled by examining the relevant +// environment variable values. If the variables are present, +// but inconsistent (e.g., shard_index >= total_shards), prints +// an error and exits. If in_subprocess_for_death_test, sharding is +// disabled because it must only be applied to the original test +// process. Otherwise, we could filter out death tests we intended to execute. +GTEST_API_ bool ShouldShard(const char* total_shards_str, + const char* shard_index_str, + bool in_subprocess_for_death_test); + +// Parses the environment variable var as an Int32. If it is unset, +// returns default_val. If it is not an Int32, prints an error +// and aborts.
+GTEST_API_ Int32 Int32FromEnvOrDie(const char* env_var, Int32 default_val); + +// Given the total number of shards, the shard index, and the test id, +// returns true iff the test should be run on this shard. The test id is +// some arbitrary but unique non-negative integer assigned to each test +// method. Assumes that 0 <= shard_index < total_shards. +GTEST_API_ bool ShouldRunTestOnShard( + int total_shards, int shard_index, int test_id); + +// STL container utilities. + +// Returns the number of elements in the given container that satisfy +// the given predicate. +template <class Container, typename Predicate> +inline int CountIf(const Container& c, Predicate predicate) { + // Implemented as an explicit loop since std::count_if() in libCstd on + // Solaris has a non-standard signature. + int count = 0; + for (typename Container::const_iterator it = c.begin(); it != c.end(); ++it) { + if (predicate(*it)) + ++count; + } + return count; +} + +// Applies a function/functor to each element in the container. +template <class Container, typename Functor> +void ForEach(const Container& c, Functor functor) { + std::for_each(c.begin(), c.end(), functor); +} + +// Returns the i-th element of the vector, or default_value if i is not +// in range [0, v.size()). +template <typename E> +inline E GetElementOr(const std::vector<E>& v, int i, E default_value) { + return (i < 0 || i >= static_cast<int>(v.size())) ? default_value : v[i]; +} + +// Performs an in-place shuffle of a range of the vector's elements. +// 'begin' and 'end' are element indices as an STL-style range; +// i.e. [begin, end) are shuffled, where 'end' == size() means to +// shuffle to the end of the vector.
+template <typename E> +void ShuffleRange(internal::Random* random, int begin, int end, + std::vector<E>* v) { + const int size = static_cast<int>(v->size()); + GTEST_CHECK_(0 <= begin && begin <= size) + << "Invalid shuffle range start " << begin << ": must be in range [0, " + << size << "]."; + GTEST_CHECK_(begin <= end && end <= size) + << "Invalid shuffle range finish " << end << ": must be in range [" + << begin << ", " << size << "]."; + + // Fisher-Yates shuffle, from + // http://en.wikipedia.org/wiki/Fisher-Yates_shuffle + for (int range_width = end - begin; range_width >= 2; range_width--) { + const int last_in_range = begin + range_width - 1; + const int selected = begin + random->Generate(range_width); + std::swap((*v)[selected], (*v)[last_in_range]); + } +} + +// Performs an in-place shuffle of the vector's elements. +template <typename E> +inline void Shuffle(internal::Random* random, std::vector<E>* v) { + ShuffleRange(random, 0, static_cast<int>(v->size()), v); +} + +// A function for deleting an object. Handy for being used as a +// functor. +template <typename T> +static void Delete(T* x) { + delete x; +} + +// A predicate that checks the key of a TestProperty against a known key. +// +// TestPropertyKeyIs is copyable. +class TestPropertyKeyIs { + public: + // Constructor. + // + // TestPropertyKeyIs has NO default constructor. + explicit TestPropertyKeyIs(const char* key) + : key_(key) {} + + // Returns true iff the test name of test property matches on key_. + bool operator()(const TestProperty& test_property) const { + return String(test_property.key()).Compare(key_) == 0; + } + + private: + String key_; +}; + +// Class UnitTestOptions. +// +// This class contains functions for processing options the user +// specifies when running the tests. It has only static members. +// +// In most cases, the user can specify an option using either an +// environment variable or a command line flag. E.g. you can set the +// test filter using either GTEST_FILTER or --gtest_filter.
If both +// the variable and the flag are present, the latter overrides the +// former. +class GTEST_API_ UnitTestOptions { + public: + // Functions for processing the gtest_output flag. + + // Returns the output format, or "" for normal printed output. + static String GetOutputFormat(); + + // Returns the absolute path of the requested output file, or the + // default (test_detail.xml in the original working directory) if + // none was explicitly specified. + static String GetAbsolutePathToOutputFile(); + + // Functions for processing the gtest_filter flag. + + // Returns true iff the wildcard pattern matches the string. The + // first ':' or '\0' character in pattern marks the end of it. + // + // This recursive algorithm isn't very efficient, but is clear and + // works well enough for matching test names, which are short. + static bool PatternMatchesString(const char *pattern, const char *str); + + // Returns true iff the user-specified filter matches the test case + // name and the test name. + static bool FilterMatchesTest(const String &test_case_name, + const String &test_name); + +#if GTEST_OS_WINDOWS + // Function for supporting the gtest_catch_exception flag. + + // Returns EXCEPTION_EXECUTE_HANDLER if Google Test should handle the + // given SEH exception, or EXCEPTION_CONTINUE_SEARCH otherwise. + // This function is useful as an __except condition. + static int GTestShouldProcessSEH(DWORD exception_code); +#endif // GTEST_OS_WINDOWS + + // Returns true if "name" matches the ':' separated list of glob-style + // filters in "filter". + static bool MatchesFilter(const String& name, const char* filter); +}; + +// Returns the current application's name, removing directory path if that +// is present. Used by UnitTestOptions::GetOutputFile. +GTEST_API_ FilePath GetCurrentExecutableName(); + +// The role interface for getting the OS stack trace as a string. 
+class OsStackTraceGetterInterface { + public: + OsStackTraceGetterInterface() {} + virtual ~OsStackTraceGetterInterface() {} + + // Returns the current OS stack trace as a String. Parameters: + // + // max_depth - the maximum number of stack frames to be included + // in the trace. + // skip_count - the number of top frames to be skipped; doesn't count + // against max_depth. + virtual String CurrentStackTrace(int max_depth, int skip_count) = 0; + + // UponLeavingGTest() should be called immediately before Google Test calls + // user code. It saves some information about the current stack that + // CurrentStackTrace() will use to find and hide Google Test stack frames. + virtual void UponLeavingGTest() = 0; + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(OsStackTraceGetterInterface); +}; + +// A working implementation of the OsStackTraceGetterInterface interface. +class OsStackTraceGetter : public OsStackTraceGetterInterface { + public: + OsStackTraceGetter() : caller_frame_(NULL) {} + virtual String CurrentStackTrace(int max_depth, int skip_count); + virtual void UponLeavingGTest(); + + // This string is inserted in place of stack frames that are part of + // Google Test's implementation. + static const char* const kElidedFramesMarker; + + private: + Mutex mutex_; // protects all internal state + + // We save the stack frame below the frame that calls user code. + // We do this because the address of the frame immediately below + // the user code changes between the call to UponLeavingGTest() + // and any calls to CurrentStackTrace() from within the user code. + void* caller_frame_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(OsStackTraceGetter); +}; + +// Information about a Google Test trace point. +struct TraceInfo { + const char* file; + int line; + String message; +}; + +// This is the default global test part result reporter used in UnitTestImpl. +// This class should only be used by UnitTestImpl. 
+class DefaultGlobalTestPartResultReporter + : public TestPartResultReporterInterface { + public: + explicit DefaultGlobalTestPartResultReporter(UnitTestImpl* unit_test); + // Implements the TestPartResultReporterInterface. Reports the test part + // result in the current test. + virtual void ReportTestPartResult(const TestPartResult& result); + + private: + UnitTestImpl* const unit_test_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(DefaultGlobalTestPartResultReporter); +}; + +// This is the default per thread test part result reporter used in +// UnitTestImpl. This class should only be used by UnitTestImpl. +class DefaultPerThreadTestPartResultReporter + : public TestPartResultReporterInterface { + public: + explicit DefaultPerThreadTestPartResultReporter(UnitTestImpl* unit_test); + // Implements the TestPartResultReporterInterface. The implementation just + // delegates to the current global test part result reporter of *unit_test_. + virtual void ReportTestPartResult(const TestPartResult& result); + + private: + UnitTestImpl* const unit_test_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(DefaultPerThreadTestPartResultReporter); +}; + +// The private implementation of the UnitTest class. We don't protect +// the methods under a mutex, as this class is not accessible by a +// user and the UnitTest class that delegates work to this class does +// proper locking. +class GTEST_API_ UnitTestImpl { + public: + explicit UnitTestImpl(UnitTest* parent); + virtual ~UnitTestImpl(); + + // There are two different ways to register your own TestPartResultReporter. + // You can register your own reporter to listen either only for test results + // from the current thread or for results from all threads. + // By default, each per-thread test result reporter just passes a new + // TestPartResult to the global test result reporter, which registers the + // test part result for the currently running test. + + // Returns the global test part result reporter. 
+ TestPartResultReporterInterface* GetGlobalTestPartResultReporter(); + + // Sets the global test part result reporter. + void SetGlobalTestPartResultReporter( + TestPartResultReporterInterface* reporter); + + // Returns the test part result reporter for the current thread. + TestPartResultReporterInterface* GetTestPartResultReporterForCurrentThread(); + + // Sets the test part result reporter for the current thread. + void SetTestPartResultReporterForCurrentThread( + TestPartResultReporterInterface* reporter); + + // Gets the number of successful test cases. + int successful_test_case_count() const; + + // Gets the number of failed test cases. + int failed_test_case_count() const; + + // Gets the number of all test cases. + int total_test_case_count() const; + + // Gets the number of all test cases that contain at least one test + // that should run. + int test_case_to_run_count() const; + + // Gets the number of successful tests. + int successful_test_count() const; + + // Gets the number of failed tests. + int failed_test_count() const; + + // Gets the number of disabled tests. + int disabled_test_count() const; + + // Gets the number of all tests. + int total_test_count() const; + + // Gets the number of tests that should run. + int test_to_run_count() const; + + // Gets the elapsed time, in milliseconds. + TimeInMillis elapsed_time() const { return elapsed_time_; } + + // Returns true iff the unit test passed (i.e. all test cases passed). + bool Passed() const { return !Failed(); } + + // Returns true iff the unit test failed (i.e. some test case failed + // or something outside of all tests failed). + bool Failed() const { + return failed_test_case_count() > 0 || ad_hoc_test_result()->Failed(); + } + + // Gets the i-th test case among all the test cases. i can range from 0 to + // total_test_case_count() - 1. If i is not in that range, returns NULL. 
+ const TestCase* GetTestCase(int i) const { + const int index = GetElementOr(test_case_indices_, i, -1); + return index < 0 ? NULL : test_cases_[index]; + } + + // Gets the i-th test case among all the test cases. i can range from 0 to + // total_test_case_count() - 1. If i is not in that range, returns NULL. + TestCase* GetMutableTestCase(int i) { + const int index = GetElementOr(test_case_indices_, i, -1); + return index < 0 ? NULL : test_cases_[index]; + } + + // Provides access to the event listener list. + TestEventListeners* listeners() { return &listeners_; } + + // Returns the TestResult for the test that's currently running, or + // the TestResult for the ad hoc test if no test is running. + TestResult* current_test_result(); + + // Returns the TestResult for the ad hoc test. + const TestResult* ad_hoc_test_result() const { return &ad_hoc_test_result_; } + + // Sets the OS stack trace getter. + // + // Does nothing if the input and the current OS stack trace getter + // are the same; otherwise, deletes the old getter and makes the + // input the current getter. + void set_os_stack_trace_getter(OsStackTraceGetterInterface* getter); + + // Returns the current OS stack trace getter if it is not NULL; + // otherwise, creates an OsStackTraceGetter, makes it the current + // getter, and returns it. + OsStackTraceGetterInterface* os_stack_trace_getter(); + + // Returns the current OS stack trace as a String. + // + // The maximum number of stack frames to be included is specified by + // the gtest_stack_trace_depth flag. The skip_count parameter + // specifies the number of top frames to be skipped, which doesn't + // count against the number of frames to be included. + // + // For example, if Foo() calls Bar(), which in turn calls + // CurrentOsStackTraceExceptTop(1), Foo() will be included in the + // trace but Bar() and CurrentOsStackTraceExceptTop() won't. 
+ String CurrentOsStackTraceExceptTop(int skip_count); + + // Finds and returns a TestCase with the given name. If one doesn't + // exist, creates one and returns it. + // + // Arguments: + // + // test_case_name: name of the test case + // type_param: the name of the test's type parameter, or NULL if + // this is not a typed or a type-parameterized test. + // set_up_tc: pointer to the function that sets up the test case + // tear_down_tc: pointer to the function that tears down the test case + TestCase* GetTestCase(const char* test_case_name, + const char* type_param, + Test::SetUpTestCaseFunc set_up_tc, + Test::TearDownTestCaseFunc tear_down_tc); + + // Adds a TestInfo to the unit test. + // + // Arguments: + // + // set_up_tc: pointer to the function that sets up the test case + // tear_down_tc: pointer to the function that tears down the test case + // test_info: the TestInfo object + void AddTestInfo(Test::SetUpTestCaseFunc set_up_tc, + Test::TearDownTestCaseFunc tear_down_tc, + TestInfo* test_info) { + // In order to support thread-safe death tests, we need to + // remember the original working directory when the test program + // was first invoked. We cannot do this in RUN_ALL_TESTS(), as + // the user may have changed the current directory before calling + // RUN_ALL_TESTS(). Therefore we capture the current directory in + // AddTestInfo(), which is called to register a TEST or TEST_F + // before main() is reached. + if (original_working_dir_.IsEmpty()) { + original_working_dir_.Set(FilePath::GetCurrentDir()); + GTEST_CHECK_(!original_working_dir_.IsEmpty()) + << "Failed to get the current working directory."; + } + + GetTestCase(test_info->test_case_name(), + test_info->type_param(), + set_up_tc, + tear_down_tc)->AddTestInfo(test_info); + } + +#if GTEST_HAS_PARAM_TEST + // Returns ParameterizedTestCaseRegistry object used to keep track of + // value-parameterized tests and instantiate and register them. 
+ internal::ParameterizedTestCaseRegistry& parameterized_test_registry() { + return parameterized_test_registry_; + } +#endif // GTEST_HAS_PARAM_TEST + + // Sets the TestCase object for the test that's currently running. + void set_current_test_case(TestCase* a_current_test_case) { + current_test_case_ = a_current_test_case; + } + + // Sets the TestInfo object for the test that's currently running. If + // current_test_info is NULL, the assertion results will be stored in + // ad_hoc_test_result_. + void set_current_test_info(TestInfo* a_current_test_info) { + current_test_info_ = a_current_test_info; + } + + // Registers all parameterized tests defined using TEST_P and + // INSTANTIATE_TEST_CASE_P, creating regular tests for each test/parameter + // combination. This method can be called more than once; it has guards + // protecting from registering the tests more than once. If + // value-parameterized tests are disabled, RegisterParameterizedTests is + // present but does nothing. + void RegisterParameterizedTests(); + + // Runs all tests in this UnitTest object, prints the result, and + // returns true if all tests are successful. If any exception is + // thrown during a test, this test is considered to be failed, but + // the rest of the tests will still be run. + bool RunAllTests(); + + // Clears the results of all tests, except the ad hoc tests. + void ClearNonAdHocTestResult() { + ForEach(test_cases_, TestCase::ClearTestCaseResult); + } + + // Clears the results of ad-hoc test assertions. + void ClearAdHocTestResult() { + ad_hoc_test_result_.Clear(); + } + + enum ReactionToSharding { + HONOR_SHARDING_PROTOCOL, + IGNORE_SHARDING_PROTOCOL + }; + + // Matches the full name of each test against the user-specified + // filter to decide whether the test should run, then records the + // result in each TestCase and TestInfo object. + // If shard_tests == HONOR_SHARDING_PROTOCOL, further filters tests + // based on sharding variables in the environment. 
+ // Returns the number of tests that should run. + int FilterTests(ReactionToSharding shard_tests); + + // Prints the names of the tests matching the user-specified filter flag. + void ListTestsMatchingFilter(); + + const TestCase* current_test_case() const { return current_test_case_; } + TestInfo* current_test_info() { return current_test_info_; } + const TestInfo* current_test_info() const { return current_test_info_; } + + // Returns the vector of environments that need to be set-up/torn-down + // before/after the tests are run. + std::vector<Environment*>& environments() { return environments_; } + + // Getters for the per-thread Google Test trace stack. + std::vector<TraceInfo>& gtest_trace_stack() { + return *(gtest_trace_stack_.pointer()); + } + const std::vector<TraceInfo>& gtest_trace_stack() const { + return gtest_trace_stack_.get(); + } + +#if GTEST_HAS_DEATH_TEST + void InitDeathTestSubprocessControlInfo() { + internal_run_death_test_flag_.reset(ParseInternalRunDeathTestFlag()); + } + // Returns a pointer to the parsed --gtest_internal_run_death_test + // flag, or NULL if that flag was not specified. + // This information is useful only in a death test child process. + // Must not be called before a call to InitGoogleTest. + const InternalRunDeathTestFlag* internal_run_death_test_flag() const { + return internal_run_death_test_flag_.get(); + } + + // Returns a pointer to the current death test factory. + internal::DeathTestFactory* death_test_factory() { + return death_test_factory_.get(); + } + + void SuppressTestEventsIfInSubprocess(); + + friend class ReplaceDeathTestFactory; +#endif // GTEST_HAS_DEATH_TEST + + // Initializes the event listener performing XML output as specified by + // UnitTestOptions. Must not be called before InitGoogleTest. + void ConfigureXmlOutput(); + +#if GTEST_CAN_STREAM_RESULTS_ + // Initializes the event listener for streaming test results to a socket. + // Must not be called before InitGoogleTest. 
+ void ConfigureStreamingOutput(); +#endif + + // Performs initialization dependent upon flag values obtained in + // ParseGoogleTestFlagsOnly. Is called from InitGoogleTest after the call to + // ParseGoogleTestFlagsOnly. In case a user neglects to call InitGoogleTest + // this function is also called from RunAllTests. Since this function can be + // called more than once, it has to be idempotent. + void PostFlagParsingInit(); + + // Gets the random seed used at the start of the current test iteration. + int random_seed() const { return random_seed_; } + + // Gets the random number generator. + internal::Random* random() { return &random_; } + + // Shuffles all test cases, and the tests within each test case, + // making sure that death tests are still run first. + void ShuffleTests(); + + // Restores the test cases and tests to their order before the first shuffle. + void UnshuffleTests(); + + // Returns the value of GTEST_FLAG(catch_exceptions) at the moment + // UnitTest::Run() starts. + bool catch_exceptions() const { return catch_exceptions_; } + + private: + friend class ::testing::UnitTest; + + // Used by UnitTest::Run() to capture the state of + // GTEST_FLAG(catch_exceptions) at the moment it starts. + void set_catch_exceptions(bool value) { catch_exceptions_ = value; } + + // The UnitTest object that owns this implementation object. + UnitTest* const parent_; + + // The working directory when the first TEST() or TEST_F() was + // executed. + internal::FilePath original_working_dir_; + + // The default test part result reporters. + DefaultGlobalTestPartResultReporter default_global_test_part_result_reporter_; + DefaultPerThreadTestPartResultReporter + default_per_thread_test_part_result_reporter_; + + // Points to (but doesn't own) the global test part result reporter. + TestPartResultReporterInterface* global_test_part_result_repoter_; + + // Protects read and write access to global_test_part_result_reporter_. 
+ internal::Mutex global_test_part_result_reporter_mutex_; + + // Points to (but doesn't own) the per-thread test part result reporter. + internal::ThreadLocal<TestPartResultReporterInterface*> + per_thread_test_part_result_reporter_; + + // The vector of environments that need to be set-up/torn-down + // before/after the tests are run. + std::vector<Environment*> environments_; + + // The vector of TestCases in their original order. It owns the + // elements in the vector. + std::vector<TestCase*> test_cases_; + + // Provides a level of indirection for the test case list to allow + // easy shuffling and restoring the test case order. The i-th + // element of this vector is the index of the i-th test case in the + // shuffled order. + std::vector<int> test_case_indices_; + +#if GTEST_HAS_PARAM_TEST + // ParameterizedTestRegistry object used to register value-parameterized + // tests. + internal::ParameterizedTestCaseRegistry parameterized_test_registry_; + + // Indicates whether RegisterParameterizedTests() has been called already. + bool parameterized_tests_registered_; +#endif // GTEST_HAS_PARAM_TEST + + // Index of the last death test case registered. Initially -1. + int last_death_test_case_; + + // This points to the TestCase for the currently running test. It + // changes as Google Test goes through one test case after another. + // When no test is running, this is set to NULL and Google Test + // stores assertion results in ad_hoc_test_result_. Initially NULL. + TestCase* current_test_case_; + + // This points to the TestInfo for the currently running test. It + // changes as Google Test goes through one test after another. When + // no test is running, this is set to NULL and Google Test stores + // assertion results in ad_hoc_test_result_. Initially NULL. + TestInfo* current_test_info_; + + // Normally, a user only writes assertions inside a TEST or TEST_F, + // or inside a function called by a TEST or TEST_F. 
Since Google + // Test keeps track of which test is currently running, it can + // associate such an assertion with the test it belongs to. + // + // If an assertion is encountered when no TEST or TEST_F is running, + // Google Test attributes the assertion result to an imaginary "ad hoc" + // test, and records the result in ad_hoc_test_result_. + TestResult ad_hoc_test_result_; + + // The list of event listeners that can be used to track events inside + // Google Test. + TestEventListeners listeners_; + + // The OS stack trace getter. Will be deleted when the UnitTest + // object is destructed. By default, an OsStackTraceGetter is used, + // but the user can set this field to use a custom getter if that is + // desired. + OsStackTraceGetterInterface* os_stack_trace_getter_; + + // True iff PostFlagParsingInit() has been called. + bool post_flag_parse_init_performed_; + + // The random number seed used at the beginning of the test run. + int random_seed_; + + // Our random number generator. + internal::Random random_; + + // How long the test took to run, in milliseconds. + TimeInMillis elapsed_time_; + +#if GTEST_HAS_DEATH_TEST + // The decomposed components of the gtest_internal_run_death_test flag, + // parsed when RUN_ALL_TESTS is called. + internal::scoped_ptr<InternalRunDeathTestFlag> internal_run_death_test_flag_; + internal::scoped_ptr<internal::DeathTestFactory> death_test_factory_; +#endif // GTEST_HAS_DEATH_TEST + + // A per-thread stack of traces created by the SCOPED_TRACE() macro. + internal::ThreadLocal<std::vector<TraceInfo> > gtest_trace_stack_; + + // The value of GTEST_FLAG(catch_exceptions) at the moment RunAllTests() + // starts. + bool catch_exceptions_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(UnitTestImpl); +}; // class UnitTestImpl + +// Convenience function for accessing the global UnitTest +// implementation object. +inline UnitTestImpl* GetUnitTestImpl() { + return UnitTest::GetInstance()->impl(); +} + +#if GTEST_USES_SIMPLE_RE + +// Internal helper functions for implementing the simple regular +// expression matcher. 
+GTEST_API_ bool IsInSet(char ch, const char* str); +GTEST_API_ bool IsAsciiDigit(char ch); +GTEST_API_ bool IsAsciiPunct(char ch); +GTEST_API_ bool IsRepeat(char ch); +GTEST_API_ bool IsAsciiWhiteSpace(char ch); +GTEST_API_ bool IsAsciiWordChar(char ch); +GTEST_API_ bool IsValidEscape(char ch); +GTEST_API_ bool AtomMatchesChar(bool escaped, char pattern, char ch); +GTEST_API_ bool ValidateRegex(const char* regex); +GTEST_API_ bool MatchRegexAtHead(const char* regex, const char* str); +GTEST_API_ bool MatchRepetitionAndRegexAtHead( + bool escaped, char ch, char repeat, const char* regex, const char* str); +GTEST_API_ bool MatchRegexAnywhere(const char* regex, const char* str); + +#endif // GTEST_USES_SIMPLE_RE + +// Parses the command line for Google Test flags, without initializing +// other parts of Google Test. +GTEST_API_ void ParseGoogleTestFlagsOnly(int* argc, char** argv); +GTEST_API_ void ParseGoogleTestFlagsOnly(int* argc, wchar_t** argv); + +#if GTEST_HAS_DEATH_TEST + +// Returns the message describing the last system error, regardless of the +// platform. +GTEST_API_ String GetLastErrnoDescription(); + +# if GTEST_OS_WINDOWS +// Provides leak-safe Windows kernel handle ownership. +class AutoHandle { + public: + AutoHandle() : handle_(INVALID_HANDLE_VALUE) {} + explicit AutoHandle(HANDLE handle) : handle_(handle) {} + + ~AutoHandle() { Reset(); } + + HANDLE Get() const { return handle_; } + void Reset() { Reset(INVALID_HANDLE_VALUE); } + void Reset(HANDLE handle) { + if (handle != handle_) { + if (handle_ != INVALID_HANDLE_VALUE) + ::CloseHandle(handle_); + handle_ = handle; + } + } + + private: + HANDLE handle_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(AutoHandle); +}; +# endif // GTEST_OS_WINDOWS + +// Attempts to parse a string into a positive integer pointed to by the +// number parameter. Returns true if that is possible. +// GTEST_HAS_DEATH_TEST implies that we have ::std::string, so we can use +// it here. 
+template <typename Integer> +bool ParseNaturalNumber(const ::std::string& str, Integer* number) { + // Fail fast if the given string does not begin with a digit; + // this bypasses strtoXXX's "optional leading whitespace and plus + // or minus sign" semantics, which are undesirable here. + if (str.empty() || !IsDigit(str[0])) { + return false; + } + errno = 0; + + char* end; + // BiggestConvertible is the largest integer type that system-provided + // string-to-number conversion routines can return. + +# if GTEST_OS_WINDOWS && !defined(__GNUC__) + + // MSVC and C++ Builder define __int64 instead of the standard long long. + typedef unsigned __int64 BiggestConvertible; + const BiggestConvertible parsed = _strtoui64(str.c_str(), &end, 10); + +# else + + typedef unsigned long long BiggestConvertible; // NOLINT + const BiggestConvertible parsed = strtoull(str.c_str(), &end, 10); + +# endif // GTEST_OS_WINDOWS && !defined(__GNUC__) + + const bool parse_success = *end == '\0' && errno == 0; + + // TODO(vladl@google.com): Convert this to compile time assertion when it is + // available. + GTEST_CHECK_(sizeof(Integer) <= sizeof(parsed)); + + const Integer result = static_cast<Integer>(parsed); + if (parse_success && static_cast<BiggestConvertible>(result) == parsed) { + *number = result; + return true; + } + return false; +} +#endif // GTEST_HAS_DEATH_TEST + +// TestResult contains some private methods that should be hidden from +// Google Test user but are required for testing. This class allows our tests +// to access them. +// +// This class is supplied only for the purpose of testing Google Test's own +// constructs. Do not use it in user tests, either directly or indirectly. 
+class TestResultAccessor { + public: + static void RecordProperty(TestResult* test_result, + const TestProperty& property) { + test_result->RecordProperty(property); + } + + static void ClearTestPartResults(TestResult* test_result) { + test_result->ClearTestPartResults(); + } + + static const std::vector<testing::TestPartResult>& test_part_results( + const TestResult& test_result) { + return test_result.test_part_results(); + } +}; + +} // namespace internal +} // namespace testing + +#endif // GTEST_SRC_GTEST_INTERNAL_INL_H_ +#undef GTEST_IMPLEMENTATION_ + +#if GTEST_OS_WINDOWS +# define vsnprintf _vsnprintf +#endif // GTEST_OS_WINDOWS + +namespace testing { + +using internal::CountIf; +using internal::ForEach; +using internal::GetElementOr; +using internal::Shuffle; + +// Constants. + +// A test whose test case name or test name matches this filter is +// disabled and not run. +static const char kDisableTestFilter[] = "DISABLED_*:*/DISABLED_*"; + +// A test case whose name matches this filter is considered a death +// test case and will be run before test cases whose name doesn't +// match this filter. +static const char kDeathTestCaseFilter[] = "*DeathTest:*DeathTest/*"; + +// A test filter that matches everything. +static const char kUniversalFilter[] = "*"; + +// The default output file for XML output. +static const char kDefaultOutputFile[] = "test_detail.xml"; + +// The environment variable name for the test shard index. +static const char kTestShardIndex[] = "GTEST_SHARD_INDEX"; +// The environment variable name for the total number of test shards. +static const char kTestTotalShards[] = "GTEST_TOTAL_SHARDS"; +// The environment variable name for the test shard status file. +static const char kTestShardStatusFile[] = "GTEST_SHARD_STATUS_FILE"; + +namespace internal { + +// The text used in failure messages to indicate the start of the +// stack trace. 
+const char kStackTraceMarker[] = "\nStack trace:\n"; + +// g_help_flag is true iff the --help flag or an equivalent form is +// specified on the command line. +bool g_help_flag = false; + +} // namespace internal + +GTEST_DEFINE_bool_( + also_run_disabled_tests, + internal::BoolFromGTestEnv("also_run_disabled_tests", false), + "Run disabled tests too, in addition to the tests normally being run."); + +GTEST_DEFINE_bool_( + break_on_failure, + internal::BoolFromGTestEnv("break_on_failure", false), + "True iff a failed assertion should be a debugger break-point."); + +GTEST_DEFINE_bool_( + catch_exceptions, + internal::BoolFromGTestEnv("catch_exceptions", true), + "True iff " GTEST_NAME_ + " should catch exceptions and treat them as test failures."); + +GTEST_DEFINE_string_( + color, + internal::StringFromGTestEnv("color", "auto"), + "Whether to use colors in the output. Valid values: yes, no, " + "and auto. 'auto' means to use colors if the output is " + "being sent to a terminal and the TERM environment variable " + "is set to xterm, xterm-color, xterm-256color, linux or cygwin."); + +GTEST_DEFINE_string_( + filter, + internal::StringFromGTestEnv("filter", kUniversalFilter), + "A colon-separated list of glob (not regex) patterns " + "for filtering the tests to run, optionally followed by a " + "'-' and a : separated list of negative patterns (tests to " + "exclude). A test is run if it matches one of the positive " + "patterns and does not match any of the negative patterns."); + +GTEST_DEFINE_bool_(list_tests, false, + "List all tests without running them."); + +GTEST_DEFINE_string_( + output, + internal::StringFromGTestEnv("output", ""), + "A format (currently must be \"xml\"), optionally followed " + "by a colon and an output file name or directory. A directory " + "is indicated by a trailing pathname separator. " + "Examples: \"xml:filename.xml\", \"xml::directoryname/\". 
" + "If a directory is specified, output files will be created " + "within that directory, with file-names based on the test " + "executable's name and, if necessary, made unique by adding " + "digits."); + +GTEST_DEFINE_bool_( + print_time, + internal::BoolFromGTestEnv("print_time", true), + "True iff " GTEST_NAME_ + " should display elapsed time in text output."); + +GTEST_DEFINE_int32_( + random_seed, + internal::Int32FromGTestEnv("random_seed", 0), + "Random number seed to use when shuffling test orders. Must be in range " + "[1, 99999], or 0 to use a seed based on the current time."); + +GTEST_DEFINE_int32_( + repeat, + internal::Int32FromGTestEnv("repeat", 1), + "How many times to repeat each test. Specify a negative number " + "for repeating forever. Useful for shaking out flaky tests."); + +GTEST_DEFINE_bool_( + show_internal_stack_frames, false, + "True iff " GTEST_NAME_ " should include internal stack frames when " + "printing test failure stack traces."); + +GTEST_DEFINE_bool_( + shuffle, + internal::BoolFromGTestEnv("shuffle", false), + "True iff " GTEST_NAME_ + " should randomize tests' order on every run."); + +GTEST_DEFINE_int32_( + stack_trace_depth, + internal::Int32FromGTestEnv("stack_trace_depth", kMaxStackTraceDepth), + "The maximum number of stack frames to print when an " + "assertion fails. The valid range is 0 through 100, inclusive."); + +GTEST_DEFINE_string_( + stream_result_to, + internal::StringFromGTestEnv("stream_result_to", ""), + "This flag specifies the host name and the port number on which to stream " + "test results. Example: \"localhost:555\". 
The flag is effective only on " + "Linux."); + +GTEST_DEFINE_bool_( + throw_on_failure, + internal::BoolFromGTestEnv("throw_on_failure", false), + "When this flag is specified, a failed assertion will throw an exception " + "if exceptions are enabled or exit the program with a non-zero code " + "otherwise."); + +namespace internal { + +// Generates a random number from [0, range), using a Linear +// Congruential Generator (LCG). Crashes if 'range' is 0 or greater +// than kMaxRange. +UInt32 Random::Generate(UInt32 range) { + // These constants are the same as are used in glibc's rand(3). + state_ = (1103515245U*state_ + 12345U) % kMaxRange; + + GTEST_CHECK_(range > 0) + << "Cannot generate a number in the range [0, 0)."; + GTEST_CHECK_(range <= kMaxRange) + << "Generation of a number in [0, " << range << ") was requested, " + << "but this can only generate numbers in [0, " << kMaxRange << ")."; + + // Converting via modulus introduces a bit of downward bias, but + // it's simple, and a linear congruential generator isn't too good + // to begin with. + return state_ % range; +} + +// GTestIsInitialized() returns true iff the user has initialized +// Google Test. Useful for catching the user mistake of not initializing +// Google Test before calling RUN_ALL_TESTS(). +// +// A user must call testing::InitGoogleTest() to initialize Google +// Test. g_init_gtest_count is set to the number of times +// InitGoogleTest() has been called. We don't protect this variable +// under a mutex as it is only accessed in the main thread. +int g_init_gtest_count = 0; +static bool GTestIsInitialized() { return g_init_gtest_count != 0; } + +// Iterates over a vector of TestCases, keeping a running sum of the +// results of calling a given int-returning method on each. +// Returns the sum. 
+static int SumOverTestCaseList(const std::vector<TestCase*>& case_list, + int (TestCase::*method)() const) { + int sum = 0; + for (size_t i = 0; i < case_list.size(); i++) { + sum += (case_list[i]->*method)(); + } + return sum; +} + +// Returns true iff the test case passed. +static bool TestCasePassed(const TestCase* test_case) { + return test_case->should_run() && test_case->Passed(); +} + +// Returns true iff the test case failed. +static bool TestCaseFailed(const TestCase* test_case) { + return test_case->should_run() && test_case->Failed(); +} + +// Returns true iff test_case contains at least one test that should +// run. +static bool ShouldRunTestCase(const TestCase* test_case) { + return test_case->should_run(); +} + +// AssertHelper constructor. +AssertHelper::AssertHelper(TestPartResult::Type type, + const char* file, + int line, + const char* message) + : data_(new AssertHelperData(type, file, line, message)) { +} + +AssertHelper::~AssertHelper() { + delete data_; +} + +// Message assignment, for assertion streaming support. +void AssertHelper::operator=(const Message& message) const { + UnitTest::GetInstance()-> + AddTestPartResult(data_->type, data_->file, data_->line, + AppendUserMessage(data_->message, message), + UnitTest::GetInstance()->impl() + ->CurrentOsStackTraceExceptTop(1) + // Skips the stack frame for this function itself. + ); // NOLINT +} + +// Mutex for linked pointers. +GTEST_DEFINE_STATIC_MUTEX_(g_linked_ptr_mutex); + +// Application pathname gotten in InitGoogleTest. +String g_executable_path; + +// Returns the current application's name, removing directory path if that +// is present. +FilePath GetCurrentExecutableName() { + FilePath result; + +#if GTEST_OS_WINDOWS + result.Set(FilePath(g_executable_path).RemoveExtension("exe")); +#else + result.Set(FilePath(g_executable_path)); +#endif // GTEST_OS_WINDOWS + + return result.RemoveDirectoryName(); +} + +// Functions for processing the gtest_output flag. 
+ +// Returns the output format, or "" for normal printed output. +String UnitTestOptions::GetOutputFormat() { + const char* const gtest_output_flag = GTEST_FLAG(output).c_str(); + if (gtest_output_flag == NULL) return String(""); + + const char* const colon = strchr(gtest_output_flag, ':'); + return (colon == NULL) ? + String(gtest_output_flag) : + String(gtest_output_flag, colon - gtest_output_flag); +} + +// Returns the name of the requested output file, or the default if none +// was explicitly specified. +String UnitTestOptions::GetAbsolutePathToOutputFile() { + const char* const gtest_output_flag = GTEST_FLAG(output).c_str(); + if (gtest_output_flag == NULL) + return String(""); + + const char* const colon = strchr(gtest_output_flag, ':'); + if (colon == NULL) + return String(internal::FilePath::ConcatPaths( + internal::FilePath( + UnitTest::GetInstance()->original_working_dir()), + internal::FilePath(kDefaultOutputFile)).ToString() ); + + internal::FilePath output_name(colon + 1); + if (!output_name.IsAbsolutePath()) + // TODO(wan@google.com): on Windows \some\path is not an absolute + // path (as its meaning depends on the current drive), yet the + // following logic for turning it into an absolute path is wrong. + // Fix it. + output_name = internal::FilePath::ConcatPaths( + internal::FilePath(UnitTest::GetInstance()->original_working_dir()), + internal::FilePath(colon + 1)); + + if (!output_name.IsDirectory()) + return output_name.ToString(); + + internal::FilePath result(internal::FilePath::GenerateUniqueFileName( + output_name, internal::GetCurrentExecutableName(), + GetOutputFormat().c_str())); + return result.ToString(); +} + +// Returns true iff the wildcard pattern matches the string. The +// first ':' or '\0' character in pattern marks the end of it. +// +// This recursive algorithm isn't very efficient, but is clear and +// works well enough for matching test names, which are short. 
+bool UnitTestOptions::PatternMatchesString(const char *pattern, + const char *str) { + switch (*pattern) { + case '\0': + case ':': // Either ':' or '\0' marks the end of the pattern. + return *str == '\0'; + case '?': // Matches any single character. + return *str != '\0' && PatternMatchesString(pattern + 1, str + 1); + case '*': // Matches any string (possibly empty) of characters. + return (*str != '\0' && PatternMatchesString(pattern, str + 1)) || + PatternMatchesString(pattern + 1, str); + default: // Non-special character. Matches itself. + return *pattern == *str && + PatternMatchesString(pattern + 1, str + 1); + } +} + +bool UnitTestOptions::MatchesFilter(const String& name, const char* filter) { + const char *cur_pattern = filter; + for (;;) { + if (PatternMatchesString(cur_pattern, name.c_str())) { + return true; + } + + // Finds the next pattern in the filter. + cur_pattern = strchr(cur_pattern, ':'); + + // Returns if no more patterns can be found. + if (cur_pattern == NULL) { + return false; + } + + // Skips the pattern separator (the ':' character). + cur_pattern++; + } +} + +// TODO(keithray): move String function implementations to gtest-string.cc. + +// Returns true iff the user-specified filter matches the test case +// name and the test name. 
+bool UnitTestOptions::FilterMatchesTest(const String &test_case_name, + const String &test_name) { + const String& full_name = String::Format("%s.%s", + test_case_name.c_str(), + test_name.c_str()); + + // Split --gtest_filter at '-', if there is one, to separate into + // positive filter and negative filter portions + const char* const p = GTEST_FLAG(filter).c_str(); + const char* const dash = strchr(p, '-'); + String positive; + String negative; + if (dash == NULL) { + positive = GTEST_FLAG(filter).c_str(); // Whole string is a positive filter + negative = String(""); + } else { + positive = String(p, dash - p); // Everything up to the dash + negative = String(dash+1); // Everything after the dash + if (positive.empty()) { + // Treat '-test1' as the same as '*-test1' + positive = kUniversalFilter; + } + } + + // A filter is a colon-separated list of patterns. It matches a + // test if any pattern in it matches the test. + return (MatchesFilter(full_name, positive.c_str()) && + !MatchesFilter(full_name, negative.c_str())); +} + +#if GTEST_HAS_SEH +// Returns EXCEPTION_EXECUTE_HANDLER if Google Test should handle the +// given SEH exception, or EXCEPTION_CONTINUE_SEARCH otherwise. +// This function is useful as an __except condition. +int UnitTestOptions::GTestShouldProcessSEH(DWORD exception_code) { + // Google Test should handle a SEH exception if: + // 1. the user wants it to, AND + // 2. this is not a breakpoint exception, AND + // 3. this is not a C++ exception (VC++ implements them via SEH, + // apparently). + // + // SEH exception code for C++ exceptions. + // (see http://support.microsoft.com/kb/185294 for more information). + const DWORD kCxxExceptionCode = 0xe06d7363; + + bool should_handle = true; + + if (!GTEST_FLAG(catch_exceptions)) + should_handle = false; + else if (exception_code == EXCEPTION_BREAKPOINT) + should_handle = false; + else if (exception_code == kCxxExceptionCode) + should_handle = false; + + return should_handle ? 
EXCEPTION_EXECUTE_HANDLER : EXCEPTION_CONTINUE_SEARCH; +} +#endif // GTEST_HAS_SEH + +} // namespace internal + +// The c'tor sets this object as the test part result reporter used by +// Google Test. The 'result' parameter specifies where to report the +// results. Intercepts only failures from the current thread. +ScopedFakeTestPartResultReporter::ScopedFakeTestPartResultReporter( + TestPartResultArray* result) + : intercept_mode_(INTERCEPT_ONLY_CURRENT_THREAD), + result_(result) { + Init(); +} + +// The c'tor sets this object as the test part result reporter used by +// Google Test. The 'result' parameter specifies where to report the +// results. +ScopedFakeTestPartResultReporter::ScopedFakeTestPartResultReporter( + InterceptMode intercept_mode, TestPartResultArray* result) + : intercept_mode_(intercept_mode), + result_(result) { + Init(); +} + +void ScopedFakeTestPartResultReporter::Init() { + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + if (intercept_mode_ == INTERCEPT_ALL_THREADS) { + old_reporter_ = impl->GetGlobalTestPartResultReporter(); + impl->SetGlobalTestPartResultReporter(this); + } else { + old_reporter_ = impl->GetTestPartResultReporterForCurrentThread(); + impl->SetTestPartResultReporterForCurrentThread(this); + } +} + +// The d'tor restores the test part result reporter used by Google Test +// before. +ScopedFakeTestPartResultReporter::~ScopedFakeTestPartResultReporter() { + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + if (intercept_mode_ == INTERCEPT_ALL_THREADS) { + impl->SetGlobalTestPartResultReporter(old_reporter_); + } else { + impl->SetTestPartResultReporterForCurrentThread(old_reporter_); + } +} + +// Increments the test part result count and remembers the result. +// This method is from the TestPartResultReporterInterface interface. 
+void ScopedFakeTestPartResultReporter::ReportTestPartResult( + const TestPartResult& result) { + result_->Append(result); +} + +namespace internal { + +// Returns the type ID of ::testing::Test. We should always call this +// instead of GetTypeId< ::testing::Test>() to get the type ID of +// testing::Test. This is to work around a suspected linker bug when +// using Google Test as a framework on Mac OS X. The bug causes +// GetTypeId< ::testing::Test>() to return different values depending +// on whether the call is from the Google Test framework itself or +// from user test code. GetTestTypeId() is guaranteed to always +// return the same value, as it always calls GetTypeId<>() from +// gtest.cc, which is within the Google Test framework. +TypeId GetTestTypeId() { + return GetTypeId<Test>(); +} + +// The value of GetTestTypeId() as seen from within the Google Test +// library. This is solely for testing GetTestTypeId(). +extern const TypeId kTestTypeIdInGoogleTest = GetTestTypeId(); + +// This predicate-formatter checks that 'results' contains a test part +// failure of the given type and that the failure message contains the +// given substring. +AssertionResult HasOneFailure(const char* /* results_expr */, + const char* /* type_expr */, + const char* /* substr_expr */, + const TestPartResultArray& results, + TestPartResult::Type type, + const string& substr) { + const String expected(type == TestPartResult::kFatalFailure ? 
+ "1 fatal failure" : + "1 non-fatal failure"); + Message msg; + if (results.size() != 1) { + msg << "Expected: " << expected << "\n" + << " Actual: " << results.size() << " failures"; + for (int i = 0; i < results.size(); i++) { + msg << "\n" << results.GetTestPartResult(i); + } + return AssertionFailure() << msg; + } + + const TestPartResult& r = results.GetTestPartResult(0); + if (r.type() != type) { + return AssertionFailure() << "Expected: " << expected << "\n" + << " Actual:\n" + << r; + } + + if (strstr(r.message(), substr.c_str()) == NULL) { + return AssertionFailure() << "Expected: " << expected << " containing \"" + << substr << "\"\n" + << " Actual:\n" + << r; + } + + return AssertionSuccess(); +} + +// The constructor of SingleFailureChecker remembers where to look up +// test part results, what type of failure we expect, and what +// substring the failure message should contain. +SingleFailureChecker:: SingleFailureChecker( + const TestPartResultArray* results, + TestPartResult::Type type, + const string& substr) + : results_(results), + type_(type), + substr_(substr) {} + +// The destructor of SingleFailureChecker verifies that the given +// TestPartResultArray contains exactly one failure that has the given +// type and contains the given substring. If that's not the case, a +// non-fatal failure will be generated. 
+SingleFailureChecker::~SingleFailureChecker() { + EXPECT_PRED_FORMAT3(HasOneFailure, *results_, type_, substr_); +} + +DefaultGlobalTestPartResultReporter::DefaultGlobalTestPartResultReporter( + UnitTestImpl* unit_test) : unit_test_(unit_test) {} + +void DefaultGlobalTestPartResultReporter::ReportTestPartResult( + const TestPartResult& result) { + unit_test_->current_test_result()->AddTestPartResult(result); + unit_test_->listeners()->repeater()->OnTestPartResult(result); +} + +DefaultPerThreadTestPartResultReporter::DefaultPerThreadTestPartResultReporter( + UnitTestImpl* unit_test) : unit_test_(unit_test) {} + +void DefaultPerThreadTestPartResultReporter::ReportTestPartResult( + const TestPartResult& result) { + unit_test_->GetGlobalTestPartResultReporter()->ReportTestPartResult(result); +} + +// Returns the global test part result reporter. +TestPartResultReporterInterface* +UnitTestImpl::GetGlobalTestPartResultReporter() { + internal::MutexLock lock(&global_test_part_result_reporter_mutex_); + return global_test_part_result_repoter_; +} + +// Sets the global test part result reporter. +void UnitTestImpl::SetGlobalTestPartResultReporter( + TestPartResultReporterInterface* reporter) { + internal::MutexLock lock(&global_test_part_result_reporter_mutex_); + global_test_part_result_repoter_ = reporter; +} + +// Returns the test part result reporter for the current thread. +TestPartResultReporterInterface* +UnitTestImpl::GetTestPartResultReporterForCurrentThread() { + return per_thread_test_part_result_reporter_.get(); +} + +// Sets the test part result reporter for the current thread. +void UnitTestImpl::SetTestPartResultReporterForCurrentThread( + TestPartResultReporterInterface* reporter) { + per_thread_test_part_result_reporter_.set(reporter); +} + +// Gets the number of successful test cases. +int UnitTestImpl::successful_test_case_count() const { + return CountIf(test_cases_, TestCasePassed); +} + +// Gets the number of failed test cases. 
+int UnitTestImpl::failed_test_case_count() const { + return CountIf(test_cases_, TestCaseFailed); +} + +// Gets the number of all test cases. +int UnitTestImpl::total_test_case_count() const { + return static_cast<int>(test_cases_.size()); +} + +// Gets the number of all test cases that contain at least one test +// that should run. +int UnitTestImpl::test_case_to_run_count() const { + return CountIf(test_cases_, ShouldRunTestCase); +} + +// Gets the number of successful tests. +int UnitTestImpl::successful_test_count() const { + return SumOverTestCaseList(test_cases_, &TestCase::successful_test_count); +} + +// Gets the number of failed tests. +int UnitTestImpl::failed_test_count() const { + return SumOverTestCaseList(test_cases_, &TestCase::failed_test_count); +} + +// Gets the number of disabled tests. +int UnitTestImpl::disabled_test_count() const { + return SumOverTestCaseList(test_cases_, &TestCase::disabled_test_count); +} + +// Gets the number of all tests. +int UnitTestImpl::total_test_count() const { + return SumOverTestCaseList(test_cases_, &TestCase::total_test_count); +} + +// Gets the number of tests that should run. +int UnitTestImpl::test_to_run_count() const { + return SumOverTestCaseList(test_cases_, &TestCase::test_to_run_count); +} + +// Returns the current OS stack trace as a String. +// +// The maximum number of stack frames to be included is specified by +// the gtest_stack_trace_depth flag. The skip_count parameter +// specifies the number of top frames to be skipped, which doesn't +// count against the number of frames to be included. +// +// For example, if Foo() calls Bar(), which in turn calls +// CurrentOsStackTraceExceptTop(1), Foo() will be included in the +// trace but Bar() and CurrentOsStackTraceExceptTop() won't. +String UnitTestImpl::CurrentOsStackTraceExceptTop(int skip_count) { + (void)skip_count; + return String(""); +} + +// Returns the current time in milliseconds. 
+TimeInMillis GetTimeInMillis() { +#if GTEST_OS_WINDOWS_MOBILE || defined(__BORLANDC__) + // Difference between 1970-01-01 and 1601-01-01 in milliseconds. + // http://analogous.blogspot.com/2005/04/epoch.html + const TimeInMillis kJavaEpochToWinFileTimeDelta = + static_cast(116444736UL) * 100000UL; + const DWORD kTenthMicrosInMilliSecond = 10000; + + SYSTEMTIME now_systime; + FILETIME now_filetime; + ULARGE_INTEGER now_int64; + // TODO(kenton@google.com): Shouldn't this just use + // GetSystemTimeAsFileTime()? + GetSystemTime(&now_systime); + if (SystemTimeToFileTime(&now_systime, &now_filetime)) { + now_int64.LowPart = now_filetime.dwLowDateTime; + now_int64.HighPart = now_filetime.dwHighDateTime; + now_int64.QuadPart = (now_int64.QuadPart / kTenthMicrosInMilliSecond) - + kJavaEpochToWinFileTimeDelta; + return now_int64.QuadPart; + } + return 0; +#elif GTEST_OS_WINDOWS && !GTEST_HAS_GETTIMEOFDAY_ + __timeb64 now; + +# ifdef _MSC_VER + + // MSVC 8 deprecates _ftime64(), so we want to suppress warning 4996 + // (deprecated function) there. + // TODO(kenton@google.com): Use GetTickCount()? Or use + // SystemTimeToFileTime() +# pragma warning(push) // Saves the current warning state. +# pragma warning(disable:4996) // Temporarily disables warning 4996. + _ftime64(&now); +# pragma warning(pop) // Restores the warning state. +# else + + _ftime64(&now); + +# endif // _MSC_VER + + return static_cast(now.time) * 1000 + now.millitm; +#elif GTEST_HAS_GETTIMEOFDAY_ + struct timeval now; + gettimeofday(&now, NULL); + return static_cast(now.tv_sec) * 1000 + now.tv_usec / 1000; +#else +# error "Don't know how to get the current time on your system." +#endif +} + +// Utilities + +// class String + +// Returns the input enclosed in double quotes if it's not NULL; +// otherwise returns "(null)". For example, "\"Hello\"" is returned +// for input "Hello". +// +// This is useful for printing a C string in the syntax of a literal. 
+// +// Known issue: escape sequences are not handled yet. +String String::ShowCStringQuoted(const char* c_str) { + return c_str ? String::Format("\"%s\"", c_str) : String("(null)"); +} + +// Copies at most length characters from str into a newly-allocated +// piece of memory of size length+1. The memory is allocated with new[]. +// A terminating null byte is written to the memory, and a pointer to it +// is returned. If str is NULL, NULL is returned. +static char* CloneString(const char* str, size_t length) { + if (str == NULL) { + return NULL; + } else { + char* const clone = new char[length + 1]; + posix::StrNCpy(clone, str, length); + clone[length] = '\0'; + return clone; + } +} + +// Clones a 0-terminated C string, allocating memory using new. The +// caller is responsible for deleting[] the return value. Returns the +// cloned string, or NULL if the input is NULL. +const char * String::CloneCString(const char* c_str) { + return (c_str == NULL) ? + NULL : CloneString(c_str, strlen(c_str)); +} + +#if GTEST_OS_WINDOWS_MOBILE +// Creates a UTF-16 wide string from the given ANSI string, allocating +// memory using new. The caller is responsible for deleting the return +// value using delete[]. Returns the wide string, or NULL if the +// input is NULL. +LPCWSTR String::AnsiToUtf16(const char* ansi) { + if (!ansi) return NULL; + const int length = strlen(ansi); + const int unicode_length = + MultiByteToWideChar(CP_ACP, 0, ansi, length, + NULL, 0); + WCHAR* unicode = new WCHAR[unicode_length + 1]; + MultiByteToWideChar(CP_ACP, 0, ansi, length, + unicode, unicode_length); + unicode[unicode_length] = 0; + return unicode; +} + +// Creates an ANSI string from the given wide string, allocating +// memory using new. The caller is responsible for deleting the return +// value using delete[]. Returns the ANSI string, or NULL if the +// input is NULL. 
+const char* String::Utf16ToAnsi(LPCWSTR utf16_str) { + if (!utf16_str) return NULL; + const int ansi_length = + WideCharToMultiByte(CP_ACP, 0, utf16_str, -1, + NULL, 0, NULL, NULL); + char* ansi = new char[ansi_length + 1]; + WideCharToMultiByte(CP_ACP, 0, utf16_str, -1, + ansi, ansi_length, NULL, NULL); + ansi[ansi_length] = 0; + return ansi; +} + +#endif // GTEST_OS_WINDOWS_MOBILE + +// Compares two C strings. Returns true iff they have the same content. +// +// Unlike strcmp(), this function can handle NULL argument(s). A NULL +// C string is considered different to any non-NULL C string, +// including the empty string. +bool String::CStringEquals(const char * lhs, const char * rhs) { + if ( lhs == NULL ) return rhs == NULL; + + if ( rhs == NULL ) return false; + + return strcmp(lhs, rhs) == 0; +} + +#if GTEST_HAS_STD_WSTRING || GTEST_HAS_GLOBAL_WSTRING + +// Converts an array of wide chars to a narrow string using the UTF-8 +// encoding, and streams the result to the given Message object. +static void StreamWideCharsToMessage(const wchar_t* wstr, size_t length, + Message* msg) { + // TODO(wan): consider allowing a testing::String object to + // contain '\0'. This will make it behave more like std::string, + // and will allow ToUtf8String() to return the correct encoding + // for '\0' s.t. we can get rid of the conditional here (and in + // several other places). + for (size_t i = 0; i != length; ) { // NOLINT + if (wstr[i] != L'\0') { + *msg << WideStringToUtf8(wstr + i, static_cast(length - i)); + while (i != length && wstr[i] != L'\0') + i++; + } else { + *msg << '\0'; + i++; + } + } +} + +#endif // GTEST_HAS_STD_WSTRING || GTEST_HAS_GLOBAL_WSTRING + +} // namespace internal + +#if GTEST_HAS_STD_WSTRING +// Converts the given wide string to a narrow string using the UTF-8 +// encoding, and streams the result to this Message object. 
+Message& Message::operator <<(const ::std::wstring& wstr) { + internal::StreamWideCharsToMessage(wstr.c_str(), wstr.length(), this); + return *this; +} +#endif // GTEST_HAS_STD_WSTRING + +#if GTEST_HAS_GLOBAL_WSTRING +// Converts the given wide string to a narrow string using the UTF-8 +// encoding, and streams the result to this Message object. +Message& Message::operator <<(const ::wstring& wstr) { + internal::StreamWideCharsToMessage(wstr.c_str(), wstr.length(), this); + return *this; +} +#endif // GTEST_HAS_GLOBAL_WSTRING + +// AssertionResult constructors. +// Used in EXPECT_TRUE/FALSE(assertion_result). +AssertionResult::AssertionResult(const AssertionResult& other) + : success_(other.success_), + message_(other.message_.get() != NULL ? + new ::std::string(*other.message_) : + static_cast< ::std::string*>(NULL)) { +} + +// Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. +AssertionResult AssertionResult::operator!() const { + AssertionResult negation(!success_); + if (message_.get() != NULL) + negation << *message_; + return negation; +} + +// Makes a successful assertion result. +AssertionResult AssertionSuccess() { + return AssertionResult(true); +} + +// Makes a failed assertion result. +AssertionResult AssertionFailure() { + return AssertionResult(false); +} + +// Makes a failed assertion result with the given failure message. +// Deprecated; use AssertionFailure() << message. +AssertionResult AssertionFailure(const Message& message) { + return AssertionFailure() << message; +} + +namespace internal { + +// Constructs and returns the message for an equality assertion +// (e.g. ASSERT_EQ, EXPECT_STREQ, etc) failure. +// +// The first four parameters are the expressions used in the assertion +// and their values, as strings. 
For example, for ASSERT_EQ(foo, bar) +// where foo is 5 and bar is 6, we have: +// +// expected_expression: "foo" +// actual_expression: "bar" +// expected_value: "5" +// actual_value: "6" +// +// The ignoring_case parameter is true iff the assertion is a +// *_STRCASEEQ*. When it's true, the string " (ignoring case)" will +// be inserted into the message. +AssertionResult EqFailure(const char* expected_expression, + const char* actual_expression, + const String& expected_value, + const String& actual_value, + bool ignoring_case) { + Message msg; + msg << "Value of: " << actual_expression; + if (actual_value != actual_expression) { + msg << "\n Actual: " << actual_value; + } + + msg << "\nExpected: " << expected_expression; + if (ignoring_case) { + msg << " (ignoring case)"; + } + if (expected_value != expected_expression) { + msg << "\nWhich is: " << expected_value; + } + + return AssertionFailure() << msg; +} + +// Constructs a failure message for Boolean assertions such as EXPECT_TRUE. +String GetBoolAssertionFailureMessage(const AssertionResult& assertion_result, + const char* expression_text, + const char* actual_predicate_value, + const char* expected_predicate_value) { + const char* actual_message = assertion_result.message(); + Message msg; + msg << "Value of: " << expression_text + << "\n Actual: " << actual_predicate_value; + if (actual_message[0] != '\0') + msg << " (" << actual_message << ")"; + msg << "\nExpected: " << expected_predicate_value; + return msg.GetString(); +} + +// Helper function for implementing ASSERT_NEAR. +AssertionResult DoubleNearPredFormat(const char* expr1, + const char* expr2, + const char* abs_error_expr, + double val1, + double val2, + double abs_error) { + const double diff = fabs(val1 - val2); + if (diff <= abs_error) return AssertionSuccess(); + + // TODO(wan): do not print the value of an expression if it's + // already a literal. 
+ return AssertionFailure() + << "The difference between " << expr1 << " and " << expr2 + << " is " << diff << ", which exceeds " << abs_error_expr << ", where\n" + << expr1 << " evaluates to " << val1 << ",\n" + << expr2 << " evaluates to " << val2 << ", and\n" + << abs_error_expr << " evaluates to " << abs_error << "."; +} + + +// Helper template for implementing FloatLE() and DoubleLE(). +template <typename RawType> +AssertionResult FloatingPointLE(const char* expr1, + const char* expr2, + RawType val1, + RawType val2) { + // Returns success if val1 is less than val2, + if (val1 < val2) { + return AssertionSuccess(); + } + + // or if val1 is almost equal to val2. + const FloatingPoint<RawType> lhs(val1), rhs(val2); + if (lhs.AlmostEquals(rhs)) { + return AssertionSuccess(); + } + + // Note that the above two checks will both fail if either val1 or + // val2 is NaN, as the IEEE floating-point standard requires that + // any predicate involving a NaN must return false. + + ::std::stringstream val1_ss; + val1_ss << std::setprecision(std::numeric_limits<RawType>::digits10 + 2) + << val1; + + ::std::stringstream val2_ss; + val2_ss << std::setprecision(std::numeric_limits<RawType>::digits10 + 2) + << val2; + + return AssertionFailure() + << "Expected: (" << expr1 << ") <= (" << expr2 << ")\n" + << " Actual: " << StringStreamToString(&val1_ss) << " vs " + << StringStreamToString(&val2_ss); +} + +} // namespace internal + +// Asserts that val1 is less than, or almost equal to, val2. Fails +// otherwise. In particular, it fails if either val1 or val2 is NaN. +AssertionResult FloatLE(const char* expr1, const char* expr2, + float val1, float val2) { + return internal::FloatingPointLE(expr1, expr2, val1, val2); +} + +// Asserts that val1 is less than, or almost equal to, val2. Fails +// otherwise. In particular, it fails if either val1 or val2 is NaN. 
+AssertionResult DoubleLE(const char* expr1, const char* expr2, + double val1, double val2) { + return internal::FloatingPointLE(expr1, expr2, val1, val2); +} + +namespace internal { + +// The helper function for {ASSERT|EXPECT}_EQ with int or enum +// arguments. +AssertionResult CmpHelperEQ(const char* expected_expression, + const char* actual_expression, + BiggestInt expected, + BiggestInt actual) { + if (expected == actual) { + return AssertionSuccess(); + } + + return EqFailure(expected_expression, + actual_expression, + FormatForComparisonFailureMessage(expected, actual), + FormatForComparisonFailureMessage(actual, expected), + false); +} + +// A macro for implementing the helper functions needed to implement +// ASSERT_?? and EXPECT_?? with integer or enum arguments. It is here +// just to avoid copy-and-paste of similar code. +#define GTEST_IMPL_CMP_HELPER_(op_name, op)\ +AssertionResult CmpHelper##op_name(const char* expr1, const char* expr2, \ + BiggestInt val1, BiggestInt val2) {\ + if (val1 op val2) {\ + return AssertionSuccess();\ + } else {\ + return AssertionFailure() \ + << "Expected: (" << expr1 << ") " #op " (" << expr2\ + << "), actual: " << FormatForComparisonFailureMessage(val1, val2)\ + << " vs " << FormatForComparisonFailureMessage(val2, val1);\ + }\ +} + +// Implements the helper function for {ASSERT|EXPECT}_NE with int or +// enum arguments. +GTEST_IMPL_CMP_HELPER_(NE, !=) +// Implements the helper function for {ASSERT|EXPECT}_LE with int or +// enum arguments. +GTEST_IMPL_CMP_HELPER_(LE, <=) +// Implements the helper function for {ASSERT|EXPECT}_LT with int or +// enum arguments. +GTEST_IMPL_CMP_HELPER_(LT, < ) +// Implements the helper function for {ASSERT|EXPECT}_GE with int or +// enum arguments. +GTEST_IMPL_CMP_HELPER_(GE, >=) +// Implements the helper function for {ASSERT|EXPECT}_GT with int or +// enum arguments. +GTEST_IMPL_CMP_HELPER_(GT, > ) + +#undef GTEST_IMPL_CMP_HELPER_ + +// The helper function for {ASSERT|EXPECT}_STREQ. 
+AssertionResult CmpHelperSTREQ(const char* expected_expression, + const char* actual_expression, + const char* expected, + const char* actual) { + if (String::CStringEquals(expected, actual)) { + return AssertionSuccess(); + } + + return EqFailure(expected_expression, + actual_expression, + String::ShowCStringQuoted(expected), + String::ShowCStringQuoted(actual), + false); +} + +// The helper function for {ASSERT|EXPECT}_STRCASEEQ. +AssertionResult CmpHelperSTRCASEEQ(const char* expected_expression, + const char* actual_expression, + const char* expected, + const char* actual) { + if (String::CaseInsensitiveCStringEquals(expected, actual)) { + return AssertionSuccess(); + } + + return EqFailure(expected_expression, + actual_expression, + String::ShowCStringQuoted(expected), + String::ShowCStringQuoted(actual), + true); +} + +// The helper function for {ASSERT|EXPECT}_STRNE. +AssertionResult CmpHelperSTRNE(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2) { + if (!String::CStringEquals(s1, s2)) { + return AssertionSuccess(); + } else { + return AssertionFailure() << "Expected: (" << s1_expression << ") != (" + << s2_expression << "), actual: \"" + << s1 << "\" vs \"" << s2 << "\""; + } +} + +// The helper function for {ASSERT|EXPECT}_STRCASENE. +AssertionResult CmpHelperSTRCASENE(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2) { + if (!String::CaseInsensitiveCStringEquals(s1, s2)) { + return AssertionSuccess(); + } else { + return AssertionFailure() + << "Expected: (" << s1_expression << ") != (" + << s2_expression << ") (ignoring case), actual: \"" + << s1 << "\" vs \"" << s2 << "\""; + } +} + +} // namespace internal + +namespace { + +// Helper functions for implementing IsSubString() and IsNotSubstring(). + +// This group of overloaded functions return true iff needle is a +// substring of haystack. NULL is considered a substring of itself +// only. 
+ +bool IsSubstringPred(const char* needle, const char* haystack) { + if (needle == NULL || haystack == NULL) + return needle == haystack; + + return strstr(haystack, needle) != NULL; +} + +bool IsSubstringPred(const wchar_t* needle, const wchar_t* haystack) { + if (needle == NULL || haystack == NULL) + return needle == haystack; + + return wcsstr(haystack, needle) != NULL; +} + +// StringType here can be either ::std::string or ::std::wstring. +template <typename StringType> +bool IsSubstringPred(const StringType& needle, + const StringType& haystack) { + return haystack.find(needle) != StringType::npos; +} + +// This function implements either IsSubstring() or IsNotSubstring(), +// depending on the value of the expected_to_be_substring parameter. +// StringType here can be const char*, const wchar_t*, ::std::string, +// or ::std::wstring. +template <typename StringType> +AssertionResult IsSubstringImpl( + bool expected_to_be_substring, + const char* needle_expr, const char* haystack_expr, + const StringType& needle, const StringType& haystack) { + if (IsSubstringPred(needle, haystack) == expected_to_be_substring) + return AssertionSuccess(); + + const bool is_wide_string = sizeof(needle[0]) > 1; + const char* const begin_string_quote = is_wide_string ? "L\"" : "\""; + return AssertionFailure() + << "Value of: " << needle_expr << "\n" + << " Actual: " << begin_string_quote << needle << "\"\n" + << "Expected: " << (expected_to_be_substring ? "" : "not ") + << "a substring of " << haystack_expr << "\n" + << "Which is: " << begin_string_quote << haystack << "\""; +} + +} // namespace + +// IsSubstring() and IsNotSubstring() check whether needle is a +// substring of haystack (NULL is considered a substring of itself +// only), and return an appropriate error message when they fail. 
+ +AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const char* needle, const char* haystack) { + return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const wchar_t* needle, const wchar_t* haystack) { + return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const char* needle, const char* haystack) { + return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const wchar_t* needle, const wchar_t* haystack) { + return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::string& needle, const ::std::string& haystack) { + return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::string& needle, const ::std::string& haystack) { + return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); +} + +#if GTEST_HAS_STD_WSTRING +AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::wstring& needle, const ::std::wstring& haystack) { + return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::wstring& needle, const ::std::wstring& haystack) { + return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); +} +#endif // GTEST_HAS_STD_WSTRING + +namespace internal { + +#if GTEST_OS_WINDOWS + +namespace { + +// Helper function for IsHRESULT{SuccessFailure} predicates 
+AssertionResult HRESULTFailureHelper(const char* expr, + const char* expected, + long hr) { // NOLINT +# if GTEST_OS_WINDOWS_MOBILE + + // Windows CE doesn't support FormatMessage. + const char error_text[] = ""; + +# else + + // Looks up the human-readable system message for the HRESULT code + // and since we're not passing any params to FormatMessage, we don't + // want inserts expanded. + const DWORD kFlags = FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS; + const DWORD kBufSize = 4096; // String::Format can't exceed this length. + // Gets the system's human readable message string for this HRESULT. + char error_text[kBufSize] = { '\0' }; + DWORD message_length = ::FormatMessageA(kFlags, + 0, // no source, we're asking system + hr, // the error + 0, // no line width restrictions + error_text, // output buffer + kBufSize, // buf size + NULL); // no arguments for inserts + // Trims tailing white space (FormatMessage leaves a trailing cr-lf) + for (; message_length && IsSpace(error_text[message_length - 1]); + --message_length) { + error_text[message_length - 1] = '\0'; + } + +# endif // GTEST_OS_WINDOWS_MOBILE + + const String error_hex(String::Format("0x%08X ", hr)); + return ::testing::AssertionFailure() + << "Expected: " << expr << " " << expected << ".\n" + << " Actual: " << error_hex << error_text << "\n"; +} + +} // namespace + +AssertionResult IsHRESULTSuccess(const char* expr, long hr) { // NOLINT + if (SUCCEEDED(hr)) { + return AssertionSuccess(); + } + return HRESULTFailureHelper(expr, "succeeds", hr); +} + +AssertionResult IsHRESULTFailure(const char* expr, long hr) { // NOLINT + if (FAILED(hr)) { + return AssertionSuccess(); + } + return HRESULTFailureHelper(expr, "fails", hr); +} + +#endif // GTEST_OS_WINDOWS + +// Utility functions for encoding Unicode text (wide strings) in +// UTF-8. 
+ +// A Unicode code-point can have up to 21 bits, and is encoded in UTF-8 +// like this: +// +// Code-point length Encoding +// 0 - 7 bits 0xxxxxxx +// 8 - 11 bits 110xxxxx 10xxxxxx +// 12 - 16 bits 1110xxxx 10xxxxxx 10xxxxxx +// 17 - 21 bits 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + +// The maximum code-point a one-byte UTF-8 sequence can represent. +const UInt32 kMaxCodePoint1 = (static_cast<UInt32>(1) << 7) - 1; + +// The maximum code-point a two-byte UTF-8 sequence can represent. +const UInt32 kMaxCodePoint2 = (static_cast<UInt32>(1) << (5 + 6)) - 1; + +// The maximum code-point a three-byte UTF-8 sequence can represent. +const UInt32 kMaxCodePoint3 = (static_cast<UInt32>(1) << (4 + 2*6)) - 1; + +// The maximum code-point a four-byte UTF-8 sequence can represent. +const UInt32 kMaxCodePoint4 = (static_cast<UInt32>(1) << (3 + 3*6)) - 1; + +// Chops off the n lowest bits from a bit pattern. Returns the n +// lowest bits. As a side effect, the original bit pattern will be +// shifted to the right by n bits. +inline UInt32 ChopLowBits(UInt32* bits, int n) { + const UInt32 low_bits = *bits & ((static_cast<UInt32>(1) << n) - 1); + *bits >>= n; + return low_bits; +} + +// Converts a Unicode code point to a narrow string in UTF-8 encoding. +// code_point parameter is of type UInt32 because wchar_t may not be +// wide enough to contain a code point. +// The output buffer str must contain at least 32 characters. +// The function returns the address of the output buffer. +// If the code_point is not a valid Unicode code point +// (i.e. outside of Unicode range U+0 to U+10FFFF) it will be output +// as '(Invalid Unicode 0xXXXXXXXX)'. 
+char* CodePointToUtf8(UInt32 code_point, char* str) { + if (code_point <= kMaxCodePoint1) { + str[1] = '\0'; + str[0] = static_cast<char>(code_point); // 0xxxxxxx + } else if (code_point <= kMaxCodePoint2) { + str[2] = '\0'; + str[1] = static_cast<char>(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[0] = static_cast<char>(0xC0 | code_point); // 110xxxxx + } else if (code_point <= kMaxCodePoint3) { + str[3] = '\0'; + str[2] = static_cast<char>(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[1] = static_cast<char>(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[0] = static_cast<char>(0xE0 | code_point); // 1110xxxx + } else if (code_point <= kMaxCodePoint4) { + str[4] = '\0'; + str[3] = static_cast<char>(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[2] = static_cast<char>(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[1] = static_cast<char>(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[0] = static_cast<char>(0xF0 | code_point); // 11110xxx + } else { + // The longest string String::Format can produce when invoked + // with these parameters is 28 characters long (not including + // the terminating nul character). We are asking for a 32-character + // buffer just in case. This is also enough for strncpy to + // null-terminate the destination string. + posix::StrNCpy( + str, String::Format("(Invalid Unicode 0x%X)", code_point).c_str(), 32); + str[31] = '\0'; // Makes sure no change in the format to strncpy leaves + // the result unterminated. + } + return str; +} + +// The following two functions only make sense if the system +// uses UTF-16 for wide string encoding. All supported systems +// with 16-bit wchar_t (Windows, Cygwin, Symbian OS) do use UTF-16. + +// Determines if the arguments constitute a UTF-16 surrogate pair +// and thus should be combined into a single Unicode code point +// using CreateCodePointFromUtf16SurrogatePair. 
+inline bool IsUtf16SurrogatePair(wchar_t first, wchar_t second) { + return sizeof(wchar_t) == 2 && + (first & 0xFC00) == 0xD800 && (second & 0xFC00) == 0xDC00; +} + +// Creates a Unicode code point from a UTF-16 surrogate pair. +inline UInt32 CreateCodePointFromUtf16SurrogatePair(wchar_t first, + wchar_t second) { + const UInt32 mask = (1 << 10) - 1; + return (sizeof(wchar_t) == 2) ? + (((first & mask) << 10) | (second & mask)) + 0x10000 : + // This function should not be called when the condition is + // false, but we provide a sensible default in case it is. + static_cast<UInt32>(first); +} + +// Converts a wide string to a narrow string in UTF-8 encoding. +// The wide string is assumed to have the following encoding: +// UTF-16 if sizeof(wchar_t) == 2 (on Windows, Cygwin, Symbian OS) +// UTF-32 if sizeof(wchar_t) == 4 (on Linux) +// Parameter str points to a null-terminated wide string. +// Parameter num_chars may additionally limit the number +// of wchar_t characters processed. -1 is used when the entire string +// should be processed. +// If the string contains code points that are not valid Unicode code points +// (i.e. outside of Unicode range U+0 to U+10FFFF) they will be output +// as '(Invalid Unicode 0xXXXXXXXX)'. If the string is in UTF-16 encoding +// and contains invalid UTF-16 surrogate pairs, values in those pairs +// will be encoded as individual Unicode characters from the Basic +// Multilingual Plane. +String WideStringToUtf8(const wchar_t* str, int num_chars) { + if (num_chars == -1) + num_chars = static_cast<int>(wcslen(str)); + + ::std::stringstream stream; + for (int i = 0; i < num_chars; ++i) { + UInt32 unicode_code_point; + + if (str[i] == L'\0') { + break; + } else if (i + 1 < num_chars && IsUtf16SurrogatePair(str[i], str[i + 1])) { + unicode_code_point = CreateCodePointFromUtf16SurrogatePair(str[i], + str[i + 1]); + i++; + } else { + unicode_code_point = static_cast<UInt32>(str[i]); + } + + char buffer[32]; // CodePointToUtf8 requires a buffer this big. 
+ stream << CodePointToUtf8(unicode_code_point, buffer); + } + return StringStreamToString(&stream); +} + +// Converts a wide C string to a String using the UTF-8 encoding. +// NULL will be converted to "(null)". +String String::ShowWideCString(const wchar_t * wide_c_str) { + if (wide_c_str == NULL) return String("(null)"); + + return String(internal::WideStringToUtf8(wide_c_str, -1).c_str()); +} + +// Similar to ShowWideCString(), except that this function encloses +// the converted string in double quotes. +String String::ShowWideCStringQuoted(const wchar_t* wide_c_str) { + if (wide_c_str == NULL) return String("(null)"); + + return String::Format("L\"%s\"", + String::ShowWideCString(wide_c_str).c_str()); +} + +// Compares two wide C strings. Returns true iff they have the same +// content. +// +// Unlike wcscmp(), this function can handle NULL argument(s). A NULL +// C string is considered different to any non-NULL C string, +// including the empty string. +bool String::WideCStringEquals(const wchar_t * lhs, const wchar_t * rhs) { + if (lhs == NULL) return rhs == NULL; + + if (rhs == NULL) return false; + + return wcscmp(lhs, rhs) == 0; +} + +// Helper function for *_STREQ on wide strings. +AssertionResult CmpHelperSTREQ(const char* expected_expression, + const char* actual_expression, + const wchar_t* expected, + const wchar_t* actual) { + if (String::WideCStringEquals(expected, actual)) { + return AssertionSuccess(); + } + + return EqFailure(expected_expression, + actual_expression, + String::ShowWideCStringQuoted(expected), + String::ShowWideCStringQuoted(actual), + false); +} + +// Helper function for *_STRNE on wide strings. 
+AssertionResult CmpHelperSTRNE(const char* s1_expression, + const char* s2_expression, + const wchar_t* s1, + const wchar_t* s2) { + if (!String::WideCStringEquals(s1, s2)) { + return AssertionSuccess(); + } + + return AssertionFailure() << "Expected: (" << s1_expression << ") != (" + << s2_expression << "), actual: " + << String::ShowWideCStringQuoted(s1) + << " vs " << String::ShowWideCStringQuoted(s2); +} + +// Compares two C strings, ignoring case. Returns true iff they have +// the same content. +// +// Unlike strcasecmp(), this function can handle NULL argument(s). A +// NULL C string is considered different to any non-NULL C string, +// including the empty string. +bool String::CaseInsensitiveCStringEquals(const char * lhs, const char * rhs) { + if (lhs == NULL) + return rhs == NULL; + if (rhs == NULL) + return false; + return posix::StrCaseCmp(lhs, rhs) == 0; +} + + // Compares two wide C strings, ignoring case. Returns true iff they + // have the same content. + // + // Unlike wcscasecmp(), this function can handle NULL argument(s). + // A NULL C string is considered different to any non-NULL wide C string, + // including the empty string. + // NB: The implementations on different platforms slightly differ. + // On windows, this method uses _wcsicmp which compares according to LC_CTYPE + // environment variable. On GNU platform this method uses wcscasecmp + // which compares according to LC_CTYPE category of the current locale. + // On MacOS X, it uses towlower, which also uses LC_CTYPE category of the + // current locale. +bool String::CaseInsensitiveWideCStringEquals(const wchar_t* lhs, + const wchar_t* rhs) { + if (lhs == NULL) return rhs == NULL; + + if (rhs == NULL) return false; + +#if GTEST_OS_WINDOWS + return _wcsicmp(lhs, rhs) == 0; +#elif GTEST_OS_LINUX && !GTEST_OS_LINUX_ANDROID + return wcscasecmp(lhs, rhs) == 0; +#else + // Android, Mac OS X and Cygwin don't define wcscasecmp. + // Other unknown OSes may not define it either. 
+ wint_t left, right; + do { + left = towlower(*lhs++); + right = towlower(*rhs++); + } while (left && left == right); + return left == right; +#endif // OS selector +} + +// Compares this with another String. +// Returns < 0 if this is less than rhs, 0 if this is equal to rhs, or > 0 +// if this is greater than rhs. +int String::Compare(const String & rhs) const { + const char* const lhs_c_str = c_str(); + const char* const rhs_c_str = rhs.c_str(); + + if (lhs_c_str == NULL) { + return rhs_c_str == NULL ? 0 : -1; // NULL < anything except NULL + } else if (rhs_c_str == NULL) { + return 1; + } + + const size_t shorter_str_len = + length() <= rhs.length() ? length() : rhs.length(); + for (size_t i = 0; i != shorter_str_len; i++) { + if (lhs_c_str[i] < rhs_c_str[i]) { + return -1; + } else if (lhs_c_str[i] > rhs_c_str[i]) { + return 1; + } + } + return (length() < rhs.length()) ? -1 : + (length() > rhs.length()) ? 1 : 0; +} + +// Returns true iff this String ends with the given suffix. *Any* +// String is considered to end with a NULL or empty suffix. +bool String::EndsWith(const char* suffix) const { + if (suffix == NULL || CStringEquals(suffix, "")) return true; + + if (c_str() == NULL) return false; + + const size_t this_len = strlen(c_str()); + const size_t suffix_len = strlen(suffix); + return (this_len >= suffix_len) && + CStringEquals(c_str() + this_len - suffix_len, suffix); +} + +// Returns true iff this String ends with the given suffix, ignoring case. +// Any String is considered to end with a NULL or empty suffix. 
+bool String::EndsWithCaseInsensitive(const char* suffix) const { + if (suffix == NULL || CStringEquals(suffix, "")) return true; + + if (c_str() == NULL) return false; + + const size_t this_len = strlen(c_str()); + const size_t suffix_len = strlen(suffix); + return (this_len >= suffix_len) && + CaseInsensitiveCStringEquals(c_str() + this_len - suffix_len, suffix); +} + +// Formats a list of arguments to a String, using the same format +// spec string as for printf. +// +// We do not use the StringPrintf class as it is not universally +// available. +// +// The result is limited to 4096 characters (including the tailing 0). +// If 4096 characters are not enough to format the input, or if +// there's an error, "" is +// returned. +String String::Format(const char * format, ...) { + va_list args; + va_start(args, format); + + char buffer[4096]; + const int kBufferSize = sizeof(buffer)/sizeof(buffer[0]); + + // MSVC 8 deprecates vsnprintf(), so we want to suppress warning + // 4996 (deprecated function) there. +#ifdef _MSC_VER // We are using MSVC. +# pragma warning(push) // Saves the current warning state. +# pragma warning(disable:4996) // Temporarily disables warning 4996. + + const int size = vsnprintf(buffer, kBufferSize, format, args); + +# pragma warning(pop) // Restores the warning state. +#else // We are not using MSVC. + const int size = vsnprintf(buffer, kBufferSize, format, args); +#endif // _MSC_VER + va_end(args); + + // vsnprintf()'s behavior is not portable. When the buffer is not + // big enough, it returns a negative value in MSVC, and returns the + // needed buffer size on Linux. When there is an output error, it + // always returns a negative value. For simplicity, we lump the two + // error cases together. + if (size < 0 || size >= kBufferSize) { + return String(""); + } else { + return String(buffer, size); + } +} + +// Converts the buffer in a stringstream to a String, converting NUL +// bytes to "\\0" along the way. 
+String StringStreamToString(::std::stringstream* ss) { + const ::std::string& str = ss->str(); + const char* const start = str.c_str(); + const char* const end = start + str.length(); + + // We need to use a helper stringstream to do this transformation + // because String doesn't support push_back(). + ::std::stringstream helper; + for (const char* ch = start; ch != end; ++ch) { + if (*ch == '\0') { + helper << "\\0"; // Replaces NUL with "\\0"; + } else { + helper.put(*ch); + } + } + + return String(helper.str().c_str()); +} + +// Appends the user-supplied message to the Google-Test-generated message. +String AppendUserMessage(const String& gtest_msg, + const Message& user_msg) { + // Appends the user message if it's non-empty. + const String user_msg_string = user_msg.GetString(); + if (user_msg_string.empty()) { + return gtest_msg; + } + + Message msg; + msg << gtest_msg << "\n" << user_msg_string; + + return msg.GetString(); +} + +} // namespace internal + +// class TestResult + +// Creates an empty TestResult. +TestResult::TestResult() + : death_test_count_(0), + elapsed_time_(0) { +} + +// D'tor. +TestResult::~TestResult() { +} + +// Returns the i-th test part result among all the results. i can +// range from 0 to total_part_count() - 1. If i is not in that range, +// aborts the program. +const TestPartResult& TestResult::GetTestPartResult(int i) const { + if (i < 0 || i >= total_part_count()) + internal::posix::Abort(); + return test_part_results_.at(i); +} + +// Returns the i-th test property. i can range from 0 to +// test_property_count() - 1. If i is not in that range, aborts the +// program. +const TestProperty& TestResult::GetTestProperty(int i) const { + if (i < 0 || i >= test_property_count()) + internal::posix::Abort(); + return test_properties_.at(i); +} + +// Clears the test part results. +void TestResult::ClearTestPartResults() { + test_part_results_.clear(); +} + +// Adds a test part result to the list. 
+void TestResult::AddTestPartResult(const TestPartResult& test_part_result) { + test_part_results_.push_back(test_part_result); +} + +// Adds a test property to the list. If a property with the same key as the +// supplied property is already represented, the value of this test_property +// replaces the old value for that key. +void TestResult::RecordProperty(const TestProperty& test_property) { + if (!ValidateTestProperty(test_property)) { + return; + } + internal::MutexLock lock(&test_properites_mutex_); + const std::vector<TestProperty>::iterator property_with_matching_key = + std::find_if(test_properties_.begin(), test_properties_.end(), + internal::TestPropertyKeyIs(test_property.key())); + if (property_with_matching_key == test_properties_.end()) { + test_properties_.push_back(test_property); + return; + } + property_with_matching_key->SetValue(test_property.value()); +} + +// Adds a failure if the key is a reserved attribute of Google Test +// testcase tags. Returns true if the property is valid. +bool TestResult::ValidateTestProperty(const TestProperty& test_property) { + internal::String key(test_property.key()); + if (key == "name" || key == "status" || key == "time" || key == "classname") { + ADD_FAILURE() + << "Reserved key used in RecordProperty(): " + << key + << " ('name', 'status', 'time', and 'classname' are reserved by " + << GTEST_NAME_ << ")"; + return false; + } + return true; +} + +// Clears the object. +void TestResult::Clear() { + test_part_results_.clear(); + test_properties_.clear(); + death_test_count_ = 0; + elapsed_time_ = 0; +} + +// Returns true iff the test failed. +bool TestResult::Failed() const { + for (int i = 0; i < total_part_count(); ++i) { + if (GetTestPartResult(i).failed()) + return true; + } + return false; +} + +// Returns true iff the test part fatally failed. +static bool TestPartFatallyFailed(const TestPartResult& result) { + return result.fatally_failed(); +} + +// Returns true iff the test fatally failed. 
+bool TestResult::HasFatalFailure() const { + return CountIf(test_part_results_, TestPartFatallyFailed) > 0; +} + +// Returns true iff the test part non-fatally failed. +static bool TestPartNonfatallyFailed(const TestPartResult& result) { + return result.nonfatally_failed(); +} + +// Returns true iff the test has a non-fatal failure. +bool TestResult::HasNonfatalFailure() const { + return CountIf(test_part_results_, TestPartNonfatallyFailed) > 0; +} + +// Gets the number of all test parts. This is the sum of the number +// of successful test parts and the number of failed test parts. +int TestResult::total_part_count() const { + return static_cast<int>(test_part_results_.size()); +} + +// Returns the number of the test properties. +int TestResult::test_property_count() const { + return static_cast<int>(test_properties_.size()); +} + +// class Test + +// Creates a Test object. + +// The c'tor saves the values of all Google Test flags. +Test::Test() + : gtest_flag_saver_(new internal::GTestFlagSaver) { +} + +// The d'tor restores the values of all Google Test flags. +Test::~Test() { + delete gtest_flag_saver_; +} + +// Sets up the test fixture. +// +// A sub-class may override this. +void Test::SetUp() { +} + +// Tears down the test fixture. +// +// A sub-class may override this. +void Test::TearDown() { +} + +// Allows user supplied key value pairs to be recorded for later output. +void Test::RecordProperty(const char* key, const char* value) { + UnitTest::GetInstance()->RecordPropertyForCurrentTest(key, value); +} + +// Allows user supplied key value pairs to be recorded for later output. +void Test::RecordProperty(const char* key, int value) { + Message value_message; + value_message << value; + RecordProperty(key, value_message.GetString().c_str()); +} + +namespace internal { + +void ReportFailureInUnknownLocation(TestPartResult::Type result_type, + const String& message) { + // This function is a friend of UnitTest and as such has access to + // AddTestPartResult. 
+ UnitTest::GetInstance()->AddTestPartResult( + result_type, + NULL, // No info about the source file where the exception occurred. + -1, // We have no info on which line caused the exception. + message, + String()); // No stack trace, either. +} + +} // namespace internal + +// Google Test requires all tests in the same test case to use the same test +// fixture class. This function checks if the current test has the +// same fixture class as the first test in the current test case. If +// yes, it returns true; otherwise it generates a Google Test failure and +// returns false. +bool Test::HasSameFixtureClass() { + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + const TestCase* const test_case = impl->current_test_case(); + + // Info about the first test in the current test case. + const TestInfo* const first_test_info = test_case->test_info_list()[0]; + const internal::TypeId first_fixture_id = first_test_info->fixture_class_id_; + const char* const first_test_name = first_test_info->name(); + + // Info about the current test. + const TestInfo* const this_test_info = impl->current_test_info(); + const internal::TypeId this_fixture_id = this_test_info->fixture_class_id_; + const char* const this_test_name = this_test_info->name(); + + if (this_fixture_id != first_fixture_id) { + // Is the first test defined using TEST? + const bool first_is_TEST = first_fixture_id == internal::GetTestTypeId(); + // Is this test defined using TEST? + const bool this_is_TEST = this_fixture_id == internal::GetTestTypeId(); + + if (first_is_TEST || this_is_TEST) { + // The user mixed TEST and TEST_F in this test case - we'll tell + // him/her how to fix it. + + // Gets the name of the TEST and the name of the TEST_F. Note + // that first_is_TEST and this_is_TEST cannot both be true, as + // the fixture IDs are different for the two tests. + const char* const TEST_name = + first_is_TEST ? 
first_test_name : this_test_name; + const char* const TEST_F_name = + first_is_TEST ? this_test_name : first_test_name; + + ADD_FAILURE() + << "All tests in the same test case must use the same test fixture\n" + << "class, so mixing TEST_F and TEST in the same test case is\n" + << "illegal. In test case " << this_test_info->test_case_name() + << ",\n" + << "test " << TEST_F_name << " is defined using TEST_F but\n" + << "test " << TEST_name << " is defined using TEST. You probably\n" + << "want to change the TEST to TEST_F or move it to another test\n" + << "case."; + } else { + // The user defined two fixture classes with the same name in + // two namespaces - we'll tell him/her how to fix it. + ADD_FAILURE() + << "All tests in the same test case must use the same test fixture\n" + << "class. However, in test case " + << this_test_info->test_case_name() << ",\n" + << "you defined test " << first_test_name + << " and test " << this_test_name << "\n" + << "using two different test fixture classes. This can happen if\n" + << "the two classes are from different namespaces or translation\n" + << "units and have the same name. You should probably rename one\n" + << "of the classes to put the tests into different test cases."; + } + return false; + } + + return true; +} + +#if GTEST_HAS_SEH + +// Adds an "exception thrown" fatal failure to the current test. This +// function returns its result via an output parameter pointer because VC++ +// prohibits creation of objects with destructors on stack in functions +// using __try (see error C2712). 
+static internal::String* FormatSehExceptionMessage(DWORD exception_code, + const char* location) { + Message message; + message << "SEH exception with code 0x" << std::setbase(16) << + exception_code << std::setbase(10) << " thrown in " << location << "."; + + return new internal::String(message.GetString()); +} + +#endif // GTEST_HAS_SEH + +#if GTEST_HAS_EXCEPTIONS + +// Adds an "exception thrown" fatal failure to the current test. +static internal::String FormatCxxExceptionMessage(const char* description, + const char* location) { + Message message; + if (description != NULL) { + message << "C++ exception with description \"" << description << "\""; + } else { + message << "Unknown C++ exception"; + } + message << " thrown in " << location << "."; + + return message.GetString(); +} + +static internal::String PrintTestPartResultToString( + const TestPartResult& test_part_result); + +// A failed Google Test assertion will throw an exception of this type when +// GTEST_FLAG(throw_on_failure) is true (if exceptions are enabled). We +// derive it from std::runtime_error, which is for errors presumably +// detectable only at run time. Since std::runtime_error inherits from +// std::exception, many testing frameworks know how to extract and print the +// message inside it. +class GoogleTestFailureException : public ::std::runtime_error { + public: + explicit GoogleTestFailureException(const TestPartResult& failure) + : ::std::runtime_error(PrintTestPartResultToString(failure).c_str()) {} +}; +#endif // GTEST_HAS_EXCEPTIONS + +namespace internal { +// We put these helper functions in the internal namespace as IBM's xlC +// compiler rejects the code if they were declared static. + +// Runs the given method and handles SEH exceptions it throws, when +// SEH is supported; returns the 0-value for type Result in case of an +// SEH exception. (Microsoft compilers cannot handle SEH and C++ +// exceptions in the same function. 
Therefore, we provide a separate +// wrapper function for handling SEH exceptions.) +template <class T, typename Result> +Result HandleSehExceptionsInMethodIfSupported( + T* object, Result (T::*method)(), const char* location) { +#if GTEST_HAS_SEH + __try { + return (object->*method)(); + } __except (internal::UnitTestOptions::GTestShouldProcessSEH( // NOLINT + GetExceptionCode())) { + // We create the exception message on the heap because VC++ prohibits + // creation of objects with destructors on stack in functions using __try + // (see error C2712). + internal::String* exception_message = FormatSehExceptionMessage( + GetExceptionCode(), location); + internal::ReportFailureInUnknownLocation(TestPartResult::kFatalFailure, + *exception_message); + delete exception_message; + return static_cast<Result>(0); + } +#else + (void)location; + return (object->*method)(); +#endif // GTEST_HAS_SEH +} + +// Runs the given method and catches and reports C++ and/or SEH-style +// exceptions, if they are supported; returns the 0-value for type +// Result in case of an SEH exception. +template <class T, typename Result> +Result HandleExceptionsInMethodIfSupported( + T* object, Result (T::*method)(), const char* location) { + // NOTE: The user code can affect the way in which Google Test handles + // exceptions by setting GTEST_FLAG(catch_exceptions), but only before + // RUN_ALL_TESTS() starts. It is technically possible to check the flag + // after the exception is caught and either report or re-throw the + // exception based on the flag's value: + // + // try { + // // Perform the test method. + // } catch (...) { + // if (GTEST_FLAG(catch_exceptions)) + // // Report the exception as failure. + // else + // throw; // Re-throws the original exception. + // } + // + // However, the purpose of this flag is to allow the program to drop into + // the debugger when the exception is thrown. 
On most platforms, once the + // control enters the catch block, the exception origin information is + // lost and the debugger will stop the program at the point of the + // re-throw in this function -- instead of at the point of the original + // throw statement in the code under test. For this reason, we perform + // the check early, sacrificing the ability to affect Google Test's + // exception handling in the method where the exception is thrown. + if (internal::GetUnitTestImpl()->catch_exceptions()) { +#if GTEST_HAS_EXCEPTIONS + try { + return HandleSehExceptionsInMethodIfSupported(object, method, location); + } catch (const GoogleTestFailureException&) { // NOLINT + // This exception doesn't originate in code under test. It makes no + // sense to report it as a test failure. + throw; + } catch (const std::exception& e) { // NOLINT + internal::ReportFailureInUnknownLocation( + TestPartResult::kFatalFailure, + FormatCxxExceptionMessage(e.what(), location)); + } catch (...) { // NOLINT + internal::ReportFailureInUnknownLocation( + TestPartResult::kFatalFailure, + FormatCxxExceptionMessage(NULL, location)); + } + return static_cast<Result>(0); +#else + return HandleSehExceptionsInMethodIfSupported(object, method, location); +#endif // GTEST_HAS_EXCEPTIONS + } else { + return (object->*method)(); + } +} + +} // namespace internal + +// Runs the test and updates the test result. +void Test::Run() { + if (!HasSameFixtureClass()) return; + + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported(this, &Test::SetUp, "SetUp()"); + // We will run the test only if SetUp() was successful. + if (!HasFatalFailure()) { + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + this, &Test::TestBody, "the test body"); + } + + // However, we want to clean up as much as possible. 
Hence we will + // always call TearDown(), even if SetUp() or the test body has + // failed. + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + this, &Test::TearDown, "TearDown()"); +} + +// Returns true iff the current test has a fatal failure. +bool Test::HasFatalFailure() { + return internal::GetUnitTestImpl()->current_test_result()->HasFatalFailure(); +} + +// Returns true iff the current test has a non-fatal failure. +bool Test::HasNonfatalFailure() { + return internal::GetUnitTestImpl()->current_test_result()-> + HasNonfatalFailure(); +} + +// class TestInfo + +// Constructs a TestInfo object. It assumes ownership of the test factory +// object. +// TODO(vladl@google.com): Make a_test_case_name and a_name const string&'s +// to signify they cannot be NULLs. +TestInfo::TestInfo(const char* a_test_case_name, + const char* a_name, + const char* a_type_param, + const char* a_value_param, + internal::TypeId fixture_class_id, + internal::TestFactoryBase* factory) + : test_case_name_(a_test_case_name), + name_(a_name), + type_param_(a_type_param ? new std::string(a_type_param) : NULL), + value_param_(a_value_param ? new std::string(a_value_param) : NULL), + fixture_class_id_(fixture_class_id), + should_run_(false), + is_disabled_(false), + matches_filter_(false), + factory_(factory), + result_() {} + +// Destructs a TestInfo object. +TestInfo::~TestInfo() { delete factory_; } + +namespace internal { + +// Creates a new TestInfo object and registers it with Google Test; +// returns the created object. +// +// Arguments: +// +// test_case_name: name of the test case +// name: name of the test +// type_param: the name of the test's type parameter, or NULL if +// this is not a typed or a type-parameterized test. +// value_param: text representation of the test's value parameter, +// or NULL if this is not a value-parameterized test. 
+// fixture_class_id: ID of the test fixture class +// set_up_tc: pointer to the function that sets up the test case +// tear_down_tc: pointer to the function that tears down the test case +// factory: pointer to the factory that creates a test object. +// The newly created TestInfo instance will assume +// ownership of the factory object. +TestInfo* MakeAndRegisterTestInfo( + const char* test_case_name, const char* name, + const char* type_param, + const char* value_param, + TypeId fixture_class_id, + SetUpTestCaseFunc set_up_tc, + TearDownTestCaseFunc tear_down_tc, + TestFactoryBase* factory) { + TestInfo* const test_info = + new TestInfo(test_case_name, name, type_param, value_param, + fixture_class_id, factory); + GetUnitTestImpl()->AddTestInfo(set_up_tc, tear_down_tc, test_info); + return test_info; +} + +#if GTEST_HAS_PARAM_TEST +void ReportInvalidTestCaseType(const char* test_case_name, + const char* file, int line) { + Message errors; + errors + << "Attempted redefinition of test case " << test_case_name << ".\n" + << "All tests in the same test case must use the same test fixture\n" + << "class. However, in test case " << test_case_name << ", you tried\n" + << "to define a test using a fixture class different from the one\n" + << "used earlier. This can happen if the two fixture classes are\n" + << "from different namespaces and have the same name. You should\n" + << "probably rename one of the classes to put the tests into different\n" + << "test cases."; + + fprintf(stderr, "%s %s", FormatFileLocation(file, line).c_str(), + errors.GetString().c_str()); +} +#endif // GTEST_HAS_PARAM_TEST + +} // namespace internal + +namespace { + +// A predicate that checks the test name of a TestInfo against a known +// value. +// +// This is used for implementation of the TestCase class only. We put +// it in the anonymous namespace to prevent polluting the outer +// namespace. +// +// TestNameIs is copyable. +class TestNameIs { + public: + // Constructor. 
+ // + // TestNameIs has NO default constructor. + explicit TestNameIs(const char* name) + : name_(name) {} + + // Returns true iff the test name of test_info matches name_. + bool operator()(const TestInfo * test_info) const { + return test_info && internal::String(test_info->name()).Compare(name_) == 0; + } + + private: + internal::String name_; +}; + +} // namespace + +namespace internal { + +// This method expands all parameterized tests registered with macros TEST_P +// and INSTANTIATE_TEST_CASE_P into regular tests and registers those. +// This will be done just once during the program runtime. +void UnitTestImpl::RegisterParameterizedTests() { +#if GTEST_HAS_PARAM_TEST + if (!parameterized_tests_registered_) { + parameterized_test_registry_.RegisterTests(); + parameterized_tests_registered_ = true; + } +#endif +} + +} // namespace internal + +// Creates the test object, runs it, records its result, and then +// deletes it. +void TestInfo::Run() { + if (!should_run_) return; + + // Tells UnitTest where to store test result. + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + impl->set_current_test_info(this); + + TestEventListener* repeater = UnitTest::GetInstance()->listeners().repeater(); + + // Notifies the unit test event listeners that a test is about to start. + repeater->OnTestStart(*this); + + const TimeInMillis start = internal::GetTimeInMillis(); + + impl->os_stack_trace_getter()->UponLeavingGTest(); + + // Creates the test object. + Test* const test = internal::HandleExceptionsInMethodIfSupported( + factory_, &internal::TestFactoryBase::CreateTest, + "the test fixture's constructor"); + + // Runs the test only if the test object was created and its + // constructor didn't generate a fatal failure. + if ((test != NULL) && !Test::HasFatalFailure()) { + // This doesn't throw as all user code that can throw are wrapped into + // exception handling code. + test->Run(); + } + + // Deletes the test object. 
+ impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + test, &Test::DeleteSelf_, "the test fixture's destructor"); + + result_.set_elapsed_time(internal::GetTimeInMillis() - start); + + // Notifies the unit test event listener that a test has just finished. + repeater->OnTestEnd(*this); + + // Tells UnitTest to stop associating assertion results to this + // test. + impl->set_current_test_info(NULL); +} + +// class TestCase + +// Gets the number of successful tests in this test case. +int TestCase::successful_test_count() const { + return CountIf(test_info_list_, TestPassed); +} + +// Gets the number of failed tests in this test case. +int TestCase::failed_test_count() const { + return CountIf(test_info_list_, TestFailed); +} + +int TestCase::disabled_test_count() const { + return CountIf(test_info_list_, TestDisabled); +} + +// Gets the number of tests in this test case that should run. +int TestCase::test_to_run_count() const { + return CountIf(test_info_list_, ShouldRunTest); +} + +// Gets the number of all tests. +int TestCase::total_test_count() const { + return static_cast<int>(test_info_list_.size()); +} + +// Creates a TestCase with the given name. +// +// Arguments: +// +// name: name of the test case +// a_type_param: the name of the test case's type parameter, or NULL if +// this is not a typed or a type-parameterized test case. +// set_up_tc: pointer to the function that sets up the test case +// tear_down_tc: pointer to the function that tears down the test case +TestCase::TestCase(const char* a_name, const char* a_type_param, + Test::SetUpTestCaseFunc set_up_tc, + Test::TearDownTestCaseFunc tear_down_tc) + : name_(a_name), + type_param_(a_type_param ? new std::string(a_type_param) : NULL), + set_up_tc_(set_up_tc), + tear_down_tc_(tear_down_tc), + should_run_(false), + elapsed_time_(0) { +} + +// Destructor of TestCase. +TestCase::~TestCase() { + // Deletes every Test in the collection. 
+ ForEach(test_info_list_, internal::Delete<TestInfo>); +} + +// Returns the i-th test among all the tests. i can range from 0 to +// total_test_count() - 1. If i is not in that range, returns NULL. +const TestInfo* TestCase::GetTestInfo(int i) const { + const int index = GetElementOr(test_indices_, i, -1); + return index < 0 ? NULL : test_info_list_[index]; +} + +// Returns the i-th test among all the tests. i can range from 0 to +// total_test_count() - 1. If i is not in that range, returns NULL. +TestInfo* TestCase::GetMutableTestInfo(int i) { + const int index = GetElementOr(test_indices_, i, -1); + return index < 0 ? NULL : test_info_list_[index]; +} + +// Adds a test to this test case. Will delete the test upon +// destruction of the TestCase object. +void TestCase::AddTestInfo(TestInfo * test_info) { + test_info_list_.push_back(test_info); + test_indices_.push_back(static_cast<int>(test_indices_.size())); +} + +// Runs every test in this TestCase. +void TestCase::Run() { + if (!should_run_) return; + + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + impl->set_current_test_case(this); + + TestEventListener* repeater = UnitTest::GetInstance()->listeners().repeater(); + + repeater->OnTestCaseStart(*this); + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + this, &TestCase::RunSetUpTestCase, "SetUpTestCase()"); + + const internal::TimeInMillis start = internal::GetTimeInMillis(); + for (int i = 0; i < total_test_count(); i++) { + GetMutableTestInfo(i)->Run(); + } + elapsed_time_ = internal::GetTimeInMillis() - start; + + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + this, &TestCase::RunTearDownTestCase, "TearDownTestCase()"); + + repeater->OnTestCaseEnd(*this); + impl->set_current_test_case(NULL); +} + +// Clears the results of all tests in this test case. 
+void TestCase::ClearResult() { + ForEach(test_info_list_, TestInfo::ClearTestResult); +} + +// Shuffles the tests in this test case. +void TestCase::ShuffleTests(internal::Random* random) { + Shuffle(random, &test_indices_); +} + +// Restores the test order to before the first shuffle. +void TestCase::UnshuffleTests() { + for (size_t i = 0; i < test_indices_.size(); i++) { + test_indices_[i] = static_cast<int>(i); + } +} + +// Formats a countable noun. Depending on its quantity, either the +// singular form or the plural form is used. e.g. +// +// FormatCountableNoun(1, "formula", "formuli") returns "1 formula". +// FormatCountableNoun(5, "book", "books") returns "5 books". +static internal::String FormatCountableNoun(int count, + const char * singular_form, + const char * plural_form) { + return internal::String::Format("%d %s", count, + count == 1 ? singular_form : plural_form); +} + +// Formats the count of tests. +static internal::String FormatTestCount(int test_count) { + return FormatCountableNoun(test_count, "test", "tests"); +} + +// Formats the count of test cases. +static internal::String FormatTestCaseCount(int test_case_count) { + return FormatCountableNoun(test_case_count, "test case", "test cases"); +} + +// Converts a TestPartResult::Type enum to human-friendly string +// representation. Both kNonFatalFailure and kFatalFailure are translated +// to "Failure", as the user usually doesn't care about the difference +// between the two when viewing the test result. +static const char * TestPartResultTypeToString(TestPartResult::Type type) { + switch (type) { + case TestPartResult::kSuccess: + return "Success"; + + case TestPartResult::kNonFatalFailure: + case TestPartResult::kFatalFailure: +#ifdef _MSC_VER + return "error: "; +#else + return "Failure\n"; +#endif + default: + return "Unknown result type"; + } +} + +// Prints a TestPartResult to a String. 
+static internal::String PrintTestPartResultToString( + const TestPartResult& test_part_result) { + return (Message() + << internal::FormatFileLocation(test_part_result.file_name(), + test_part_result.line_number()) + << " " << TestPartResultTypeToString(test_part_result.type()) + << test_part_result.message()).GetString(); +} + +// Prints a TestPartResult. +static void PrintTestPartResult(const TestPartResult& test_part_result) { + const internal::String& result = + PrintTestPartResultToString(test_part_result); + printf("%s\n", result.c_str()); + fflush(stdout); + // If the test program runs in Visual Studio or a debugger, the + // following statements add the test part result message to the Output + // window such that the user can double-click on it to jump to the + // corresponding source code location; otherwise they do nothing. +#if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE + // We don't call OutputDebugString*() on Windows Mobile, as printing + // to stdout is done by OutputDebugString() there already - we don't + // want the same message printed twice. + ::OutputDebugStringA(result.c_str()); + ::OutputDebugStringA("\n"); +#endif +} + +// class PrettyUnitTestResultPrinter + +namespace internal { + +enum GTestColor { + COLOR_DEFAULT, + COLOR_RED, + COLOR_GREEN, + COLOR_YELLOW +}; + +#if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE + +// Returns the character attribute for the given color. +WORD GetColorAttribute(GTestColor color) { + switch (color) { + case COLOR_RED: return FOREGROUND_RED; + case COLOR_GREEN: return FOREGROUND_GREEN; + case COLOR_YELLOW: return FOREGROUND_RED | FOREGROUND_GREEN; + default: return 0; + } +} + +#else + +// Returns the ANSI color code for the given color. COLOR_DEFAULT is +// an invalid input. 
+const char* GetAnsiColorCode(GTestColor color) { + switch (color) { + case COLOR_RED: return "1"; + case COLOR_GREEN: return "2"; + case COLOR_YELLOW: return "3"; + default: return NULL; + }; +} + +#endif // GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE + +// Returns true iff Google Test should use colors in the output. +bool ShouldUseColor(bool stdout_is_tty) { + const char* const gtest_color = GTEST_FLAG(color).c_str(); + + if (String::CaseInsensitiveCStringEquals(gtest_color, "auto")) { +#if GTEST_OS_WINDOWS + // On Windows the TERM variable is usually not set, but the + // console there does support colors. + return stdout_is_tty; +#else + // On non-Windows platforms, we rely on the TERM variable. + const char* const term = posix::GetEnv("TERM"); + const bool term_supports_color = + String::CStringEquals(term, "xterm") || + String::CStringEquals(term, "xterm-color") || + String::CStringEquals(term, "xterm-256color") || + String::CStringEquals(term, "screen") || + String::CStringEquals(term, "linux") || + String::CStringEquals(term, "cygwin"); + return stdout_is_tty && term_supports_color; +#endif // GTEST_OS_WINDOWS + } + + return String::CaseInsensitiveCStringEquals(gtest_color, "yes") || + String::CaseInsensitiveCStringEquals(gtest_color, "true") || + String::CaseInsensitiveCStringEquals(gtest_color, "t") || + String::CStringEquals(gtest_color, "1"); + // We take "yes", "true", "t", and "1" as meaning "yes". If the + // value is neither one of these nor "auto", we treat it as "no" to + // be conservative. +} + +// Helpers for printing colored strings to stdout. Note that on Windows, we +// cannot simply emit special characters and have the terminal change colors. +// This routine must actually emit the characters rather than return a string +// that would be colored when printed, as can be done on Linux. +void ColoredPrintf(GTestColor color, const char* fmt, ...) 
{ + va_list args; + va_start(args, fmt); + +#if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_SYMBIAN || GTEST_OS_ZOS + const bool use_color = false; +#else + static const bool in_color_mode = + ShouldUseColor(posix::IsATTY(posix::FileNo(stdout)) != 0); + const bool use_color = in_color_mode && (color != COLOR_DEFAULT); +#endif // GTEST_OS_WINDOWS_MOBILE || GTEST_OS_SYMBIAN || GTEST_OS_ZOS + // The '!= 0' comparison is necessary to satisfy MSVC 7.1. + + if (!use_color) { + vprintf(fmt, args); + va_end(args); + return; + } + +#if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE + const HANDLE stdout_handle = GetStdHandle(STD_OUTPUT_HANDLE); + + // Gets the current text color. + CONSOLE_SCREEN_BUFFER_INFO buffer_info; + GetConsoleScreenBufferInfo(stdout_handle, &buffer_info); + const WORD old_color_attrs = buffer_info.wAttributes; + + // We need to flush the stream buffers into the console before each + // SetConsoleTextAttribute call lest it affect the text that is already + // printed but has not yet reached the console. + fflush(stdout); + SetConsoleTextAttribute(stdout_handle, + GetColorAttribute(color) | FOREGROUND_INTENSITY); + vprintf(fmt, args); + + fflush(stdout); + // Restores the text color. + SetConsoleTextAttribute(stdout_handle, old_color_attrs); +#else + printf("\033[0;3%sm", GetAnsiColorCode(color)); + vprintf(fmt, args); + printf("\033[m"); // Resets the terminal to default. +#endif // GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE + va_end(args); +} + +void PrintFullTestCommentIfPresent(const TestInfo& test_info) { + const char* const type_param = test_info.type_param(); + const char* const value_param = test_info.value_param(); + + if (type_param != NULL || value_param != NULL) { + printf(", where "); + if (type_param != NULL) { + printf("TypeParam = %s", type_param); + if (value_param != NULL) + printf(" and "); + } + if (value_param != NULL) { + printf("GetParam() = %s", value_param); + } + } +} + +// This class implements the TestEventListener interface. 
+// +// Class PrettyUnitTestResultPrinter is copyable. +class PrettyUnitTestResultPrinter : public TestEventListener { + public: + PrettyUnitTestResultPrinter() {} + static void PrintTestName(const char * test_case, const char * test) { + printf("%s.%s", test_case, test); + } + + // The following methods override what's in the TestEventListener class. + virtual void OnTestProgramStart(const UnitTest& /*unit_test*/) {} + virtual void OnTestIterationStart(const UnitTest& unit_test, int iteration); + virtual void OnEnvironmentsSetUpStart(const UnitTest& unit_test); + virtual void OnEnvironmentsSetUpEnd(const UnitTest& /*unit_test*/) {} + virtual void OnTestCaseStart(const TestCase& test_case); + virtual void OnTestStart(const TestInfo& test_info); + virtual void OnTestPartResult(const TestPartResult& result); + virtual void OnTestEnd(const TestInfo& test_info); + virtual void OnTestCaseEnd(const TestCase& test_case); + virtual void OnEnvironmentsTearDownStart(const UnitTest& unit_test); + virtual void OnEnvironmentsTearDownEnd(const UnitTest& /*unit_test*/) {} + virtual void OnTestIterationEnd(const UnitTest& unit_test, int iteration); + virtual void OnTestProgramEnd(const UnitTest& /*unit_test*/) {} + + private: + static void PrintFailedTests(const UnitTest& unit_test); + + internal::String test_case_name_; +}; + + // Fired before each iteration of tests starts. +void PrettyUnitTestResultPrinter::OnTestIterationStart( + const UnitTest& unit_test, int iteration) { + if (GTEST_FLAG(repeat) != 1) + printf("\nRepeating all tests (iteration %d) . . .\n\n", iteration + 1); + + const char* const filter = GTEST_FLAG(filter).c_str(); + + // Prints the filter if it's not *. This reminds the user that some + // tests may be skipped. 
+ if (!internal::String::CStringEquals(filter, kUniversalFilter)) { + ColoredPrintf(COLOR_YELLOW, + "Note: %s filter = %s\n", GTEST_NAME_, filter); + } + + if (internal::ShouldShard(kTestTotalShards, kTestShardIndex, false)) { + const Int32 shard_index = Int32FromEnvOrDie(kTestShardIndex, -1); + ColoredPrintf(COLOR_YELLOW, + "Note: This is test shard %d of %s.\n", + static_cast<int>(shard_index) + 1, + internal::posix::GetEnv(kTestTotalShards)); + } + + if (GTEST_FLAG(shuffle)) { + ColoredPrintf(COLOR_YELLOW, + "Note: Randomizing tests' orders with a seed of %d .\n", + unit_test.random_seed()); + } + + ColoredPrintf(COLOR_GREEN, "[==========] "); + printf("Running %s from %s.\n", + FormatTestCount(unit_test.test_to_run_count()).c_str(), + FormatTestCaseCount(unit_test.test_case_to_run_count()).c_str()); + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnEnvironmentsSetUpStart( + const UnitTest& /*unit_test*/) { + ColoredPrintf(COLOR_GREEN, "[----------] "); + printf("Global test environment set-up.\n"); + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnTestCaseStart(const TestCase& test_case) { + test_case_name_ = test_case.name(); + const internal::String counts = + FormatCountableNoun(test_case.test_to_run_count(), "test", "tests"); + ColoredPrintf(COLOR_GREEN, "[----------] "); + printf("%s from %s", counts.c_str(), test_case_name_.c_str()); + if (test_case.type_param() == NULL) { + printf("\n"); + } else { + printf(", where TypeParam = %s\n", test_case.type_param()); + } + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnTestStart(const TestInfo& test_info) { + ColoredPrintf(COLOR_GREEN, "[ RUN ] "); + PrintTestName(test_case_name_.c_str(), test_info.name()); + printf("\n"); + fflush(stdout); +} + +// Called after an assertion failure. +void PrettyUnitTestResultPrinter::OnTestPartResult( + const TestPartResult& result) { + // If the test part succeeded, we don't need to do anything. 
+ if (result.type() == TestPartResult::kSuccess) + return; + + // Print failure message from the assertion (e.g. expected this and got that). + PrintTestPartResult(result); + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnTestEnd(const TestInfo& test_info) { + if (test_info.result()->Passed()) { + ColoredPrintf(COLOR_GREEN, "[ OK ] "); + } else { + ColoredPrintf(COLOR_RED, "[ FAILED ] "); + } + PrintTestName(test_case_name_.c_str(), test_info.name()); + if (test_info.result()->Failed()) + PrintFullTestCommentIfPresent(test_info); + + if (GTEST_FLAG(print_time)) { + printf(" (%s ms)\n", internal::StreamableToString( + test_info.result()->elapsed_time()).c_str()); + } else { + printf("\n"); + } + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnTestCaseEnd(const TestCase& test_case) { + if (!GTEST_FLAG(print_time)) return; + + test_case_name_ = test_case.name(); + const internal::String counts = + FormatCountableNoun(test_case.test_to_run_count(), "test", "tests"); + ColoredPrintf(COLOR_GREEN, "[----------] "); + printf("%s from %s (%s ms total)\n\n", + counts.c_str(), test_case_name_.c_str(), + internal::StreamableToString(test_case.elapsed_time()).c_str()); + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnEnvironmentsTearDownStart( + const UnitTest& /*unit_test*/) { + ColoredPrintf(COLOR_GREEN, "[----------] "); + printf("Global test environment tear-down\n"); + fflush(stdout); +} + +// Internal helper for printing the list of failed tests. 
+void PrettyUnitTestResultPrinter::PrintFailedTests(const UnitTest& unit_test) { + const int failed_test_count = unit_test.failed_test_count(); + if (failed_test_count == 0) { + return; + } + + for (int i = 0; i < unit_test.total_test_case_count(); ++i) { + const TestCase& test_case = *unit_test.GetTestCase(i); + if (!test_case.should_run() || (test_case.failed_test_count() == 0)) { + continue; + } + for (int j = 0; j < test_case.total_test_count(); ++j) { + const TestInfo& test_info = *test_case.GetTestInfo(j); + if (!test_info.should_run() || test_info.result()->Passed()) { + continue; + } + ColoredPrintf(COLOR_RED, "[ FAILED ] "); + printf("%s.%s", test_case.name(), test_info.name()); + PrintFullTestCommentIfPresent(test_info); + printf("\n"); + } + } +} + +void PrettyUnitTestResultPrinter::OnTestIterationEnd(const UnitTest& unit_test, + int /*iteration*/) { + ColoredPrintf(COLOR_GREEN, "[==========] "); + printf("%s from %s ran.", + FormatTestCount(unit_test.test_to_run_count()).c_str(), + FormatTestCaseCount(unit_test.test_case_to_run_count()).c_str()); + if (GTEST_FLAG(print_time)) { + printf(" (%s ms total)", + internal::StreamableToString(unit_test.elapsed_time()).c_str()); + } + printf("\n"); + ColoredPrintf(COLOR_GREEN, "[ PASSED ] "); + printf("%s.\n", FormatTestCount(unit_test.successful_test_count()).c_str()); + + int num_failures = unit_test.failed_test_count(); + if (!unit_test.Passed()) { + const int failed_test_count = unit_test.failed_test_count(); + ColoredPrintf(COLOR_RED, "[ FAILED ] "); + printf("%s, listed below:\n", FormatTestCount(failed_test_count).c_str()); + PrintFailedTests(unit_test); + printf("\n%2d FAILED %s\n", num_failures, + num_failures == 1 ? "TEST" : "TESTS"); + } + + int num_disabled = unit_test.disabled_test_count(); + if (num_disabled && !GTEST_FLAG(also_run_disabled_tests)) { + if (!num_failures) { + printf("\n"); // Add a spacer if no FAILURE banner is displayed. 
+ } + ColoredPrintf(COLOR_YELLOW, + " YOU HAVE %d DISABLED %s\n\n", + num_disabled, + num_disabled == 1 ? "TEST" : "TESTS"); + } + // Ensure that Google Test output is printed before, e.g., heapchecker output. + fflush(stdout); +} + +// End PrettyUnitTestResultPrinter + +// class TestEventRepeater +// +// This class forwards events to other event listeners. +class TestEventRepeater : public TestEventListener { + public: + TestEventRepeater() : forwarding_enabled_(true) {} + virtual ~TestEventRepeater(); + void Append(TestEventListener *listener); + TestEventListener* Release(TestEventListener* listener); + + // Controls whether events will be forwarded to listeners_. Set to false + // in death test child processes. + bool forwarding_enabled() const { return forwarding_enabled_; } + void set_forwarding_enabled(bool enable) { forwarding_enabled_ = enable; } + + virtual void OnTestProgramStart(const UnitTest& unit_test); + virtual void OnTestIterationStart(const UnitTest& unit_test, int iteration); + virtual void OnEnvironmentsSetUpStart(const UnitTest& unit_test); + virtual void OnEnvironmentsSetUpEnd(const UnitTest& unit_test); + virtual void OnTestCaseStart(const TestCase& test_case); + virtual void OnTestStart(const TestInfo& test_info); + virtual void OnTestPartResult(const TestPartResult& result); + virtual void OnTestEnd(const TestInfo& test_info); + virtual void OnTestCaseEnd(const TestCase& test_case); + virtual void OnEnvironmentsTearDownStart(const UnitTest& unit_test); + virtual void OnEnvironmentsTearDownEnd(const UnitTest& unit_test); + virtual void OnTestIterationEnd(const UnitTest& unit_test, int iteration); + virtual void OnTestProgramEnd(const UnitTest& unit_test); + + private: + // Controls whether events will be forwarded to listeners_. Set to false + // in death test child processes. + bool forwarding_enabled_; + // The list of listeners that receive events. 
+ std::vector<TestEventListener*> listeners_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestEventRepeater); +}; + +TestEventRepeater::~TestEventRepeater() { + ForEach(listeners_, Delete<TestEventListener>); +} + +void TestEventRepeater::Append(TestEventListener *listener) { + listeners_.push_back(listener); +} + +// TODO(vladl@google.com): Factor the search functionality into Vector::Find. +TestEventListener* TestEventRepeater::Release(TestEventListener *listener) { + for (size_t i = 0; i < listeners_.size(); ++i) { + if (listeners_[i] == listener) { + listeners_.erase(listeners_.begin() + i); + return listener; + } + } + + return NULL; +} + +// Since most methods are very similar, use macros to reduce boilerplate. +// This defines a member that forwards the call to all listeners. +#define GTEST_REPEATER_METHOD_(Name, Type) \ +void TestEventRepeater::Name(const Type& parameter) { \ + if (forwarding_enabled_) { \ + for (size_t i = 0; i < listeners_.size(); i++) { \ + listeners_[i]->Name(parameter); \ + } \ + } \ +} +// This defines a member that forwards the call to all listeners in reverse +// order. 
+#define GTEST_REVERSE_REPEATER_METHOD_(Name, Type) \ +void TestEventRepeater::Name(const Type& parameter) { \ + if (forwarding_enabled_) { \ + for (int i = static_cast<int>(listeners_.size()) - 1; i >= 0; i--) { \ + listeners_[i]->Name(parameter); \ + } \ + } \ +} + +GTEST_REPEATER_METHOD_(OnTestProgramStart, UnitTest) +GTEST_REPEATER_METHOD_(OnEnvironmentsSetUpStart, UnitTest) +GTEST_REPEATER_METHOD_(OnTestCaseStart, TestCase) +GTEST_REPEATER_METHOD_(OnTestStart, TestInfo) +GTEST_REPEATER_METHOD_(OnTestPartResult, TestPartResult) +GTEST_REPEATER_METHOD_(OnEnvironmentsTearDownStart, UnitTest) +GTEST_REVERSE_REPEATER_METHOD_(OnEnvironmentsSetUpEnd, UnitTest) +GTEST_REVERSE_REPEATER_METHOD_(OnEnvironmentsTearDownEnd, UnitTest) +GTEST_REVERSE_REPEATER_METHOD_(OnTestEnd, TestInfo) +GTEST_REVERSE_REPEATER_METHOD_(OnTestCaseEnd, TestCase) +GTEST_REVERSE_REPEATER_METHOD_(OnTestProgramEnd, UnitTest) + +#undef GTEST_REPEATER_METHOD_ +#undef GTEST_REVERSE_REPEATER_METHOD_ + +void TestEventRepeater::OnTestIterationStart(const UnitTest& unit_test, + int iteration) { + if (forwarding_enabled_) { + for (size_t i = 0; i < listeners_.size(); i++) { + listeners_[i]->OnTestIterationStart(unit_test, iteration); + } + } +} + +void TestEventRepeater::OnTestIterationEnd(const UnitTest& unit_test, + int iteration) { + if (forwarding_enabled_) { + for (int i = static_cast<int>(listeners_.size()) - 1; i >= 0; i--) { + listeners_[i]->OnTestIterationEnd(unit_test, iteration); + } + } +} + +// End TestEventRepeater + +// This class generates an XML output file. +class XmlUnitTestResultPrinter : public EmptyTestEventListener { + public: + explicit XmlUnitTestResultPrinter(const char* output_file); + + virtual void OnTestIterationEnd(const UnitTest& unit_test, int iteration); + + private: + // Is c a whitespace character that is normalized to a space character + // when it appears in an XML attribute value? 
+ static bool IsNormalizableWhitespace(char c) { + return c == 0x9 || c == 0xA || c == 0xD; + } + + // May c appear in a well-formed XML document? + static bool IsValidXmlCharacter(char c) { + return IsNormalizableWhitespace(c) || c >= 0x20; + } + + // Returns an XML-escaped copy of the input string str. If + // is_attribute is true, the text is meant to appear as an attribute + // value, and normalizable whitespace is preserved by replacing it + // with character references. + static String EscapeXml(const char* str, bool is_attribute); + + // Returns the given string with all characters invalid in XML removed. + static string RemoveInvalidXmlCharacters(const string& str); + + // Convenience wrapper around EscapeXml when str is an attribute value. + static String EscapeXmlAttribute(const char* str) { + return EscapeXml(str, true); + } + + // Convenience wrapper around EscapeXml when str is not an attribute value. + static String EscapeXmlText(const char* str) { return EscapeXml(str, false); } + + // Streams an XML CDATA section, escaping invalid CDATA sequences as needed. + static void OutputXmlCDataSection(::std::ostream* stream, const char* data); + + // Streams an XML representation of a TestInfo object. + static void OutputXmlTestInfo(::std::ostream* stream, + const char* test_case_name, + const TestInfo& test_info); + + // Prints an XML representation of a TestCase object + static void PrintXmlTestCase(FILE* out, const TestCase& test_case); + + // Prints an XML summary of unit_test to output stream out. + static void PrintXmlUnitTest(FILE* out, const UnitTest& unit_test); + + // Produces a string representing the test properties in a result as space + // delimited XML attributes based on the property key="value" pairs. + // When the String is not empty, it includes a space at the beginning, + // to delimit this attribute from prior attributes. + static String TestPropertiesAsXmlAttributes(const TestResult& result); + + // The output file. 
+ const String output_file_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(XmlUnitTestResultPrinter); +}; + +// Creates a new XmlUnitTestResultPrinter. +XmlUnitTestResultPrinter::XmlUnitTestResultPrinter(const char* output_file) + : output_file_(output_file) { + if (output_file_.c_str() == NULL || output_file_.empty()) { + fprintf(stderr, "XML output file may not be null\n"); + fflush(stderr); + exit(EXIT_FAILURE); + } +} + +// Called after the unit test ends. +void XmlUnitTestResultPrinter::OnTestIterationEnd(const UnitTest& unit_test, + int /*iteration*/) { + FILE* xmlout = NULL; + FilePath output_file(output_file_); + FilePath output_dir(output_file.RemoveFileName()); + + if (output_dir.CreateDirectoriesRecursively()) { + xmlout = posix::FOpen(output_file_.c_str(), "w"); + } + if (xmlout == NULL) { + // TODO(wan): report the reason of the failure. + // + // We don't do it for now as: + // + // 1. There is no urgent need for it. + // 2. It's a bit involved to make the errno variable thread-safe on + // all three operating systems (Linux, Windows, and Mac OS). + // 3. To interpret the meaning of errno in a thread-safe way, + // we need the strerror_r() function, which is not available on + // Windows. + fprintf(stderr, + "Unable to open file \"%s\"\n", + output_file_.c_str()); + fflush(stderr); + exit(EXIT_FAILURE); + } + PrintXmlUnitTest(xmlout, unit_test); + fclose(xmlout); +} + +// Returns an XML-escaped copy of the input string str. If is_attribute +// is true, the text is meant to appear as an attribute value, and +// normalizable whitespace is preserved by replacing it with character +// references. +// +// Invalid XML characters in str, if any, are stripped from the output. +// It is expected that most, if not all, of the text processed by this +// module will consist of ordinary English text. +// If this module is ever modified to produce version 1.1 XML output, +// most invalid characters can be retained using character references. 
+// TODO(wan): It might be nice to have a minimally invasive, human-readable +// escaping scheme for invalid characters, rather than dropping them. +String XmlUnitTestResultPrinter::EscapeXml(const char* str, bool is_attribute) { + Message m; + + if (str != NULL) { + for (const char* src = str; *src; ++src) { + switch (*src) { + case '<': + m << "&lt;"; + break; + case '>': + m << "&gt;"; + break; + case '&': + m << "&amp;"; + break; + case '\'': + if (is_attribute) + m << "&apos;"; + else + m << '\''; + break; + case '"': + if (is_attribute) + m << "&quot;"; + else + m << '"'; + break; + default: + if (IsValidXmlCharacter(*src)) { + if (is_attribute && IsNormalizableWhitespace(*src)) + m << String::Format("&#x%02X;", unsigned(*src)); + else + m << *src; + } + break; + } + } + } + + return m.GetString(); +} + +// Returns the given string with all characters invalid in XML removed. +// Currently invalid characters are dropped from the string. An +// alternative is to replace them with certain characters such as . or ?. +string XmlUnitTestResultPrinter::RemoveInvalidXmlCharacters(const string& str) { + string output; + output.reserve(str.size()); + for (string::const_iterator it = str.begin(); it != str.end(); ++it) + if (IsValidXmlCharacter(*it)) + output.push_back(*it); + + return output; +} + +// The following routines generate an XML representation of a UnitTest +// object. +// +// This is how Google Test concepts map to the DTD: +// +// <testsuites name="AllTests"> <-- corresponds to a UnitTest object +// <testsuite name="testcase-name"> <-- corresponds to a TestCase object +// <testcase name="test-name"> <-- corresponds to a TestInfo object +// <failure message="...">...</failure> +// <failure message="...">...</failure> +// <failure message="...">...</failure> +// <-- individual assertion failures +// </testcase> +// </testsuite> +// </testsuites> + +// Formats the given time in milliseconds as seconds. +std::string FormatTimeInMillisAsSeconds(TimeInMillis ms) { + ::std::stringstream ss; + ss << ms/1000.0; + return ss.str(); +} + +// Streams an XML CDATA section, escaping invalid CDATA sequences as needed. 
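+void XmlUnitTestResultPrinter::OutputXmlCDataSection(::std::ostream* stream, + const char* data) { + const char* segment = data; + *stream << "<![CDATA["; + for (;;) { + const char* const next_segment = strstr(segment, "]]>"); + if (next_segment != NULL) { + stream->write( + segment, static_cast<std::streamsize>(next_segment - segment)); + *stream << "]]>]]&gt;<![CDATA["; + segment = next_segment + strlen("]]>"); + } else { + *stream << segment; + break; + } + } + *stream << "]]>"; +} + +// Prints an XML representation of a TestInfo object. +// TODO(wan): There is also value in printing properties with the plain printer. +void XmlUnitTestResultPrinter::OutputXmlTestInfo(::std::ostream* stream, + const char* test_case_name, + const TestInfo& test_info) { + const TestResult& result = *test_info.result(); + *stream << " <testcase name=\"" + << EscapeXmlAttribute(test_info.name()).c_str() << "\""; + + if (test_info.value_param() != NULL) { + *stream << " value_param=\"" << EscapeXmlAttribute(test_info.value_param()) + << "\""; + } + if (test_info.type_param() != NULL) { + *stream << " type_param=\"" << EscapeXmlAttribute(test_info.type_param()) + << "\""; + } + + *stream << " status=\"" + << (test_info.should_run() ? "run" : "notrun") + << "\" time=\"" + << FormatTimeInMillisAsSeconds(result.elapsed_time()) + << "\" classname=\"" << EscapeXmlAttribute(test_case_name).c_str() + << "\"" << TestPropertiesAsXmlAttributes(result).c_str(); + + int failures = 0; + for (int i = 0; i < result.total_part_count(); ++i) { + const TestPartResult& part = result.GetTestPartResult(i); + if (part.failed()) { + if (++failures == 1) + *stream << ">\n"; + *stream << " <failure message=\"" + << EscapeXmlAttribute(part.summary()).c_str() + << "\" type=\"\">"; + const string location = internal::FormatCompilerIndependentFileLocation( + part.file_name(), part.line_number()); + const string message = location + "\n" + part.message(); + OutputXmlCDataSection(stream, + RemoveInvalidXmlCharacters(message).c_str()); + *stream << "</failure>\n"; + } + } + + if (failures == 0) + *stream << " />\n"; + else + *stream << " </testcase>\n"; +} + +// Prints an XML representation of a TestCase object +void XmlUnitTestResultPrinter::PrintXmlTestCase(FILE* out, + const TestCase& test_case) { + fprintf(out, + " <testsuite name=\"%s\" tests=\"%d\" failures=\"%d\" " + "disabled=\"%d\" ", + EscapeXmlAttribute(test_case.name()).c_str(), + test_case.total_test_count(), + test_case.failed_test_count(), + test_case.disabled_test_count()); + fprintf(out, + "errors=\"0\" time=\"%s\">\n", + FormatTimeInMillisAsSeconds(test_case.elapsed_time()).c_str()); + for (int i = 0; i < test_case.total_test_count(); ++i) { + ::std::stringstream stream; + OutputXmlTestInfo(&stream, test_case.name(), *test_case.GetTestInfo(i)); + fprintf(out, "%s", StringStreamToString(&stream).c_str()); + } + fprintf(out, " </testsuite>\n"); +} + +// Prints an XML summary of unit_test to output stream out. 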
+void XmlUnitTestResultPrinter::PrintXmlUnitTest(FILE* out, + const UnitTest& unit_test) { + fprintf(out, "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"); + fprintf(out, + "<testsuites tests=\"%d\" failures=\"%d\" disabled=\"%d\" " + "errors=\"0\" time=\"%s\" ", + unit_test.total_test_count(), + unit_test.failed_test_count(), + unit_test.disabled_test_count(), + FormatTimeInMillisAsSeconds(unit_test.elapsed_time()).c_str()); + if (GTEST_FLAG(shuffle)) { + fprintf(out, "random_seed=\"%d\" ", unit_test.random_seed()); + } + fprintf(out, "name=\"AllTests\">\n"); + for (int i = 0; i < unit_test.total_test_case_count(); ++i) + PrintXmlTestCase(out, *unit_test.GetTestCase(i)); + fprintf(out, "</testsuites>\n"); +} + +// Produces a string representing the test properties in a result as space +// delimited XML attributes based on the property key="value" pairs. +String XmlUnitTestResultPrinter::TestPropertiesAsXmlAttributes( + const TestResult& result) { + Message attributes; + for (int i = 0; i < result.test_property_count(); ++i) { + const TestProperty& property = result.GetTestProperty(i); + attributes << " " << property.key() << "=" + << "\"" << EscapeXmlAttribute(property.value()) << "\""; + } + return attributes.GetString(); +} + +// End XmlUnitTestResultPrinter + +#if GTEST_CAN_STREAM_RESULTS_ + +// Streams test results to the given port on the given host machine. +class StreamingListener : public EmptyTestEventListener { + public: + // Escapes '=', '&', '%', and '\n' characters in str as "%xx". + static string UrlEncode(const char* str); + + StreamingListener(const string& host, const string& port) + : sockfd_(-1), host_name_(host), port_num_(port) { + MakeConnection(); + Send("gtest_streaming_protocol_version=1.0\n"); + } + + virtual ~StreamingListener() { + if (sockfd_ != -1) + CloseConnection(); + } + + void OnTestProgramStart(const UnitTest& /* unit_test */) { + Send("event=TestProgramStart\n"); + } + + void OnTestProgramEnd(const UnitTest& unit_test) { + // Note that Google Test currently only reports elapsed time for each + // test iteration, not for the entire test program. + Send(String::Format("event=TestProgramEnd&passed=%d\n", + unit_test.Passed())); + + // Notify the streaming server to stop. 
+ CloseConnection(); + } + + void OnTestIterationStart(const UnitTest& /* unit_test */, int iteration) { + Send(String::Format("event=TestIterationStart&iteration=%d\n", + iteration)); + } + + void OnTestIterationEnd(const UnitTest& unit_test, int /* iteration */) { + Send(String::Format("event=TestIterationEnd&passed=%d&elapsed_time=%sms\n", + unit_test.Passed(), + StreamableToString(unit_test.elapsed_time()).c_str())); + } + + void OnTestCaseStart(const TestCase& test_case) { + Send(String::Format("event=TestCaseStart&name=%s\n", test_case.name())); + } + + void OnTestCaseEnd(const TestCase& test_case) { + Send(String::Format("event=TestCaseEnd&passed=%d&elapsed_time=%sms\n", + test_case.Passed(), + StreamableToString(test_case.elapsed_time()).c_str())); + } + + void OnTestStart(const TestInfo& test_info) { + Send(String::Format("event=TestStart&name=%s\n", test_info.name())); + } + + void OnTestEnd(const TestInfo& test_info) { + Send(String::Format( + "event=TestEnd&passed=%d&elapsed_time=%sms\n", + (test_info.result())->Passed(), + StreamableToString((test_info.result())->elapsed_time()).c_str())); + } + + void OnTestPartResult(const TestPartResult& test_part_result) { + const char* file_name = test_part_result.file_name(); + if (file_name == NULL) + file_name = ""; + Send(String::Format("event=TestPartResult&file=%s&line=%d&message=", + UrlEncode(file_name).c_str(), + test_part_result.line_number())); + Send(UrlEncode(test_part_result.message()) + "\n"); + } + + private: + // Creates a client socket and connects to the server. + void MakeConnection(); + + // Closes the socket. + void CloseConnection() { + GTEST_CHECK_(sockfd_ != -1) + << "CloseConnection() can be called only when there is a connection."; + + close(sockfd_); + sockfd_ = -1; + } + + // Sends a string to the socket. 
+ void Send(const string& message) { + GTEST_CHECK_(sockfd_ != -1) + << "Send() can be called only when there is a connection."; + + const int len = static_cast<int>(message.length()); + if (write(sockfd_, message.c_str(), len) != len) { + GTEST_LOG_(WARNING) + << "stream_result_to: failed to stream to " + << host_name_ << ":" << port_num_; + } + } + + int sockfd_; // socket file descriptor + const string host_name_; + const string port_num_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(StreamingListener); +}; // class StreamingListener + +// Checks if str contains '=', '&', '%' or '\n' characters. If yes, +// replaces them by "%xx" where xx is their hexadecimal value. For +// example, replaces "=" with "%3D". This algorithm is O(strlen(str)) +// in both time and space -- important as the input str may contain an +// arbitrarily long test failure message and stack trace. +string StreamingListener::UrlEncode(const char* str) { + string result; + result.reserve(strlen(str) + 1); + for (char ch = *str; ch != '\0'; ch = *++str) { + switch (ch) { + case '%': + case '=': + case '&': + case '\n': + result.append(String::Format("%%%02x", static_cast<unsigned char>(ch))); + break; + default: + result.push_back(ch); + break; + } + } + return result; +} + +void StreamingListener::MakeConnection() { + GTEST_CHECK_(sockfd_ == -1) + << "MakeConnection() can't be called when there is already a connection."; + + addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; // To allow both IPv4 and IPv6 addresses. + hints.ai_socktype = SOCK_STREAM; + addrinfo* servinfo = NULL; + + // Use the getaddrinfo() to get a linked list of IP addresses for + // the given host name. + const int error_num = getaddrinfo( + host_name_.c_str(), port_num_.c_str(), &hints, &servinfo); + if (error_num != 0) { + GTEST_LOG_(WARNING) << "stream_result_to: getaddrinfo() failed: " + << gai_strerror(error_num); + } + + // Loop through all the results and connect to the first we can. 
+ for (addrinfo* cur_addr = servinfo; sockfd_ == -1 && cur_addr != NULL; + cur_addr = cur_addr->ai_next) { + sockfd_ = socket( + cur_addr->ai_family, cur_addr->ai_socktype, cur_addr->ai_protocol); + if (sockfd_ != -1) { + // Connect the client socket to the server socket. + if (connect(sockfd_, cur_addr->ai_addr, cur_addr->ai_addrlen) == -1) { + close(sockfd_); + sockfd_ = -1; + } + } + } + + freeaddrinfo(servinfo); // all done with this structure + + if (sockfd_ == -1) { + GTEST_LOG_(WARNING) << "stream_result_to: failed to connect to " + << host_name_ << ":" << port_num_; + } +} + +// End of class Streaming Listener +#endif // GTEST_CAN_STREAM_RESULTS__ + +// Class ScopedTrace + +// Pushes the given source file location and message onto a per-thread +// trace stack maintained by Google Test. +// L < UnitTest::mutex_ +ScopedTrace::ScopedTrace(const char* file, int line, const Message& message) { + TraceInfo trace; + trace.file = file; + trace.line = line; + trace.message = message.GetString(); + + UnitTest::GetInstance()->PushGTestTrace(trace); +} + +// Pops the info pushed by the c'tor. +// L < UnitTest::mutex_ +ScopedTrace::~ScopedTrace() { + UnitTest::GetInstance()->PopGTestTrace(); +} + + +// class OsStackTraceGetter + +// Returns the current OS stack trace as a String. Parameters: +// +// max_depth - the maximum number of stack frames to be included +// in the trace. +// skip_count - the number of top frames to be skipped; doesn't count +// against max_depth. +// +// L < mutex_ +// We use "L < mutex_" to denote that the function may acquire mutex_. +String OsStackTraceGetter::CurrentStackTrace(int, int) { + return String(""); +} + +// L < mutex_ +void OsStackTraceGetter::UponLeavingGTest() { +} + +const char* const +OsStackTraceGetter::kElidedFramesMarker = + "... 
" GTEST_NAME_ " internal frames ..."; + +} // namespace internal + +// class TestEventListeners + +TestEventListeners::TestEventListeners() + : repeater_(new internal::TestEventRepeater()), + default_result_printer_(NULL), + default_xml_generator_(NULL) { +} + +TestEventListeners::~TestEventListeners() { delete repeater_; } + +// Returns the standard listener responsible for the default console +// output. Can be removed from the listeners list to shut down default +// console output. Note that removing this object from the listener list +// with Release transfers its ownership to the user. +void TestEventListeners::Append(TestEventListener* listener) { + repeater_->Append(listener); +} + +// Removes the given event listener from the list and returns it. It then +// becomes the caller's responsibility to delete the listener. Returns +// NULL if the listener is not found in the list. +TestEventListener* TestEventListeners::Release(TestEventListener* listener) { + if (listener == default_result_printer_) + default_result_printer_ = NULL; + else if (listener == default_xml_generator_) + default_xml_generator_ = NULL; + return repeater_->Release(listener); +} + +// Returns repeater that broadcasts the TestEventListener events to all +// subscribers. +TestEventListener* TestEventListeners::repeater() { return repeater_; } + +// Sets the default_result_printer attribute to the provided listener. +// The listener is also added to the listener list and previous +// default_result_printer is removed from it and deleted. The listener can +// also be NULL in which case it will not be added to the list. Does +// nothing if the previous and the current listener objects are the same. +void TestEventListeners::SetDefaultResultPrinter(TestEventListener* listener) { + if (default_result_printer_ != listener) { + // It is an error to pass this method a listener that is already in the + // list. 
+ delete Release(default_result_printer_); + default_result_printer_ = listener; + if (listener != NULL) + Append(listener); + } +} + +// Sets the default_xml_generator attribute to the provided listener. The +// listener is also added to the listener list and previous +// default_xml_generator is removed from it and deleted. The listener can +// also be NULL in which case it will not be added to the list. Does +// nothing if the previous and the current listener objects are the same. +void TestEventListeners::SetDefaultXmlGenerator(TestEventListener* listener) { + if (default_xml_generator_ != listener) { + // It is an error to pass this method a listener that is already in the + // list. + delete Release(default_xml_generator_); + default_xml_generator_ = listener; + if (listener != NULL) + Append(listener); + } +} + +// Controls whether events will be forwarded by the repeater to the +// listeners in the list. +bool TestEventListeners::EventForwardingEnabled() const { + return repeater_->forwarding_enabled(); +} + +void TestEventListeners::SuppressEventForwarding() { + repeater_->set_forwarding_enabled(false); +} + +// class UnitTest + +// Gets the singleton UnitTest object. The first time this method is +// called, a UnitTest object is constructed and returned. Consecutive +// calls will return the same object. +// +// We don't protect this under mutex_ as a user is not supposed to +// call this before main() starts, from which point on the return +// value will never change. +UnitTest * UnitTest::GetInstance() { + // When compiled with MSVC 7.1 in optimized mode, destroying the + // UnitTest object upon exiting the program messes up the exit code, + // causing successful tests to appear failed. We have to use a + // different implementation in this case to bypass the compiler bug. + // This implementation makes the compiler happy, at the cost of + // leaking the UnitTest object. 
+ + // CodeGear C++Builder insists on a public destructor for the + // default implementation. Use this implementation to keep good OO + // design with private destructor. + +#if (_MSC_VER == 1310 && !defined(_DEBUG)) || defined(__BORLANDC__) + static UnitTest* const instance = new UnitTest; + return instance; +#else + static UnitTest instance; + return &instance; +#endif // (_MSC_VER == 1310 && !defined(_DEBUG)) || defined(__BORLANDC__) +} + +// Gets the number of successful test cases. +int UnitTest::successful_test_case_count() const { + return impl()->successful_test_case_count(); +} + +// Gets the number of failed test cases. +int UnitTest::failed_test_case_count() const { + return impl()->failed_test_case_count(); +} + +// Gets the number of all test cases. +int UnitTest::total_test_case_count() const { + return impl()->total_test_case_count(); +} + +// Gets the number of all test cases that contain at least one test +// that should run. +int UnitTest::test_case_to_run_count() const { + return impl()->test_case_to_run_count(); +} + +// Gets the number of successful tests. +int UnitTest::successful_test_count() const { + return impl()->successful_test_count(); +} + +// Gets the number of failed tests. +int UnitTest::failed_test_count() const { return impl()->failed_test_count(); } + +// Gets the number of disabled tests. +int UnitTest::disabled_test_count() const { + return impl()->disabled_test_count(); +} + +// Gets the number of all tests. +int UnitTest::total_test_count() const { return impl()->total_test_count(); } + +// Gets the number of tests that should run. +int UnitTest::test_to_run_count() const { return impl()->test_to_run_count(); } + +// Gets the elapsed time, in milliseconds. +internal::TimeInMillis UnitTest::elapsed_time() const { + return impl()->elapsed_time(); +} + +// Returns true iff the unit test passed (i.e. all test cases passed). 
+bool UnitTest::Passed() const { return impl()->Passed(); } + +// Returns true iff the unit test failed (i.e. some test case failed +// or something outside of all tests failed). +bool UnitTest::Failed() const { return impl()->Failed(); } + +// Gets the i-th test case among all the test cases. i can range from 0 to +// total_test_case_count() - 1. If i is not in that range, returns NULL. +const TestCase* UnitTest::GetTestCase(int i) const { + return impl()->GetTestCase(i); +} + +// Gets the i-th test case among all the test cases. i can range from 0 to +// total_test_case_count() - 1. If i is not in that range, returns NULL. +TestCase* UnitTest::GetMutableTestCase(int i) { + return impl()->GetMutableTestCase(i); +} + +// Returns the list of event listeners that can be used to track events +// inside Google Test. +TestEventListeners& UnitTest::listeners() { + return *impl()->listeners(); +} + +// Registers and returns a global test environment. When a test +// program is run, all global test environments will be set-up in the +// order they were registered. After all tests in the program have +// finished, all global test environments will be torn-down in the +// *reverse* order they were registered. +// +// The UnitTest object takes ownership of the given environment. +// +// We don't protect this under mutex_, as we only support calling it +// from the main thread. +Environment* UnitTest::AddEnvironment(Environment* env) { + if (env == NULL) { + return NULL; + } + + impl_->environments().push_back(env); + return env; +} + +// Adds a TestPartResult to the current TestResult object. All Google Test +// assertion macros (e.g. ASSERT_TRUE, EXPECT_EQ, etc) eventually call +// this to report their results. The user code should use the +// assertion macros instead of calling this directly. 
+// L < mutex_ +void UnitTest::AddTestPartResult(TestPartResult::Type result_type, + const char* file_name, + int line_number, + const internal::String& message, + const internal::String& os_stack_trace) { + Message msg; + msg << message; + + internal::MutexLock lock(&mutex_); + if (impl_->gtest_trace_stack().size() > 0) { + msg << "\n" << GTEST_NAME_ << " trace:"; + + for (int i = static_cast<int>(impl_->gtest_trace_stack().size()); + i > 0; --i) { + const internal::TraceInfo& trace = impl_->gtest_trace_stack()[i - 1]; + msg << "\n" << internal::FormatFileLocation(trace.file, trace.line) + << " " << trace.message; + } + } + + if (os_stack_trace.c_str() != NULL && !os_stack_trace.empty()) { + msg << internal::kStackTraceMarker << os_stack_trace; + } + + const TestPartResult result = + TestPartResult(result_type, file_name, line_number, + msg.GetString().c_str()); + impl_->GetTestPartResultReporterForCurrentThread()-> + ReportTestPartResult(result); + + if (result_type != TestPartResult::kSuccess) { + // gtest_break_on_failure takes precedence over + // gtest_throw_on_failure. This allows a user to set the latter + // in the code (perhaps in order to use Google Test assertions + // with another testing framework) and specify the former on the + // command line for debugging. + if (GTEST_FLAG(break_on_failure)) { +#if GTEST_OS_WINDOWS + // Using DebugBreak on Windows allows gtest to still break into a debugger + // when a failure happens and both the --gtest_break_on_failure and + // the --gtest_catch_exceptions flags are specified. + DebugBreak(); +#else + // Dereference NULL through a volatile pointer to prevent the compiler + // from removing. We use this rather than abort() or __builtin_trap() for + // portability: Symbian doesn't implement abort() well, and some debuggers + // don't correctly trap abort(). 
+ *static_cast<volatile int*>(NULL) = 1; +#endif // GTEST_OS_WINDOWS + } else if (GTEST_FLAG(throw_on_failure)) { +#if GTEST_HAS_EXCEPTIONS + throw GoogleTestFailureException(result); +#else + // We cannot call abort() as it generates a pop-up in debug mode + // that cannot be suppressed in VC 7.1 or below. + exit(1); +#endif + } + } +} + +// Creates and adds a property to the current TestResult. If a property matching +// the supplied value already exists, updates its value instead. +void UnitTest::RecordPropertyForCurrentTest(const char* key, + const char* value) { + const TestProperty test_property(key, value); + impl_->current_test_result()->RecordProperty(test_property); +} + +// Runs all tests in this UnitTest object and prints the result. +// Returns 0 if successful, or 1 otherwise. +// +// We don't protect this under mutex_, as we only support calling it +// from the main thread. +int UnitTest::Run() { + // Captures the value of GTEST_FLAG(catch_exceptions). This value will be + // used for the duration of the program. + impl()->set_catch_exceptions(GTEST_FLAG(catch_exceptions)); + +#if GTEST_HAS_SEH + const bool in_death_test_child_process = + internal::GTEST_FLAG(internal_run_death_test).length() > 0; + + // Either the user wants Google Test to catch exceptions thrown by the + // tests or this is executing in the context of death test child + // process. In either case the user does not want to see pop-up dialogs + // about crashes - they are expected. + if (impl()->catch_exceptions() || in_death_test_child_process) { + +# if !GTEST_OS_WINDOWS_MOBILE + // SetErrorMode doesn't exist on CE. + SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOALIGNMENTFAULTEXCEPT | + SEM_NOGPFAULTERRORBOX | SEM_NOOPENFILEERRORBOX); +# endif // !GTEST_OS_WINDOWS_MOBILE + +# if (defined(_MSC_VER) || GTEST_OS_WINDOWS_MINGW) && !GTEST_OS_WINDOWS_MOBILE + // Death test children can be terminated with _abort(). On Windows, + // _abort() can show a dialog with a warning message. 
This forces the + // abort message to go to stderr instead. + _set_error_mode(_OUT_TO_STDERR); +# endif + +# if _MSC_VER >= 1400 && !GTEST_OS_WINDOWS_MOBILE + // In the debug version, Visual Studio pops up a separate dialog + // offering a choice to debug the aborted program. We need to suppress + // this dialog or it will pop up for every EXPECT/ASSERT_DEATH statement + // executed. Google Test will notify the user of any unexpected + // failure via stderr. + // + // VC++ doesn't define _set_abort_behavior() prior to the version 8.0. + // Users of prior VC versions shall suffer the agony and pain of + // clicking through the countless debug dialogs. + // TODO(vladl@google.com): find a way to suppress the abort dialog() in the + // debug mode when compiled with VC 7.1 or lower. + if (!GTEST_FLAG(break_on_failure)) + _set_abort_behavior( + 0x0, // Clear the following flags: + _WRITE_ABORT_MSG | _CALL_REPORTFAULT); // pop-up window, core dump. +# endif + + } +#endif // GTEST_HAS_SEH + + return internal::HandleExceptionsInMethodIfSupported( + impl(), + &internal::UnitTestImpl::RunAllTests, + "auxiliary test code (environments or event listeners)") ? 0 : 1; +} + +// Returns the working directory when the first TEST() or TEST_F() was +// executed. +const char* UnitTest::original_working_dir() const { + return impl_->original_working_dir_.c_str(); +} + +// Returns the TestCase object for the test that's currently running, +// or NULL if no test is running. +// L < mutex_ +const TestCase* UnitTest::current_test_case() const { + internal::MutexLock lock(&mutex_); + return impl_->current_test_case(); +} + +// Returns the TestInfo object for the test that's currently running, +// or NULL if no test is running. +// L < mutex_ +const TestInfo* UnitTest::current_test_info() const { + internal::MutexLock lock(&mutex_); + return impl_->current_test_info(); +} + +// Returns the random seed used at the start of the current test run. 
+int UnitTest::random_seed() const { return impl_->random_seed(); } + +#if GTEST_HAS_PARAM_TEST +// Returns ParameterizedTestCaseRegistry object used to keep track of +// value-parameterized tests and instantiate and register them. +// L < mutex_ +internal::ParameterizedTestCaseRegistry& + UnitTest::parameterized_test_registry() { + return impl_->parameterized_test_registry(); +} +#endif // GTEST_HAS_PARAM_TEST + +// Creates an empty UnitTest. +UnitTest::UnitTest() { + impl_ = new internal::UnitTestImpl(this); +} + +// Destructor of UnitTest. +UnitTest::~UnitTest() { + delete impl_; +} + +// Pushes a trace defined by SCOPED_TRACE() on to the per-thread +// Google Test trace stack. +// L < mutex_ +void UnitTest::PushGTestTrace(const internal::TraceInfo& trace) { + internal::MutexLock lock(&mutex_); + impl_->gtest_trace_stack().push_back(trace); +} + +// Pops a trace from the per-thread Google Test trace stack. +// L < mutex_ +void UnitTest::PopGTestTrace() { + internal::MutexLock lock(&mutex_); + impl_->gtest_trace_stack().pop_back(); +} + +namespace internal { + +UnitTestImpl::UnitTestImpl(UnitTest* parent) + : parent_(parent), +#ifdef _MSC_VER +# pragma warning(push) // Saves the current warning state. +# pragma warning(disable:4355) // Temporarily disables warning 4355 + // (using this in initializer). + default_global_test_part_result_reporter_(this), + default_per_thread_test_part_result_reporter_(this), +# pragma warning(pop) // Restores the warning state again. 
+#else + default_global_test_part_result_reporter_(this), + default_per_thread_test_part_result_reporter_(this), +#endif // _MSC_VER + global_test_part_result_repoter_( + &default_global_test_part_result_reporter_), + per_thread_test_part_result_reporter_( + &default_per_thread_test_part_result_reporter_), +#if GTEST_HAS_PARAM_TEST + parameterized_test_registry_(), + parameterized_tests_registered_(false), +#endif // GTEST_HAS_PARAM_TEST + last_death_test_case_(-1), + current_test_case_(NULL), + current_test_info_(NULL), + ad_hoc_test_result_(), + os_stack_trace_getter_(NULL), + post_flag_parse_init_performed_(false), + random_seed_(0), // Will be overridden by the flag before first use. + random_(0), // Will be reseeded before first use. + elapsed_time_(0), +#if GTEST_HAS_DEATH_TEST + internal_run_death_test_flag_(NULL), + death_test_factory_(new DefaultDeathTestFactory), +#endif + // Will be overridden by the flag before first use. + catch_exceptions_(false) { + listeners()->SetDefaultResultPrinter(new PrettyUnitTestResultPrinter); +} + +UnitTestImpl::~UnitTestImpl() { + // Deletes every TestCase. + ForEach(test_cases_, internal::Delete<TestCase>); + + // Deletes every Environment. + ForEach(environments_, internal::Delete<Environment>); + + delete os_stack_trace_getter_; +} + +#if GTEST_HAS_DEATH_TEST +// Disables event forwarding if the control is currently in a death test +// subprocess. Must not be called before InitGoogleTest. +void UnitTestImpl::SuppressTestEventsIfInSubprocess() { + if (internal_run_death_test_flag_.get() != NULL) + listeners()->SuppressEventForwarding(); +} +#endif // GTEST_HAS_DEATH_TEST + +// Initializes event listeners performing XML output as specified by +// UnitTestOptions. Must not be called before InitGoogleTest. 
+void UnitTestImpl::ConfigureXmlOutput() { + const String& output_format = UnitTestOptions::GetOutputFormat(); + if (output_format == "xml") { + listeners()->SetDefaultXmlGenerator(new XmlUnitTestResultPrinter( + UnitTestOptions::GetAbsolutePathToOutputFile().c_str())); + } else if (output_format != "") { + printf("WARNING: unrecognized output format \"%s\" ignored.\n", + output_format.c_str()); + fflush(stdout); + } +} + +#if GTEST_CAN_STREAM_RESULTS_ +// Initializes event listeners for streaming test results in String form. +// Must not be called before InitGoogleTest. +void UnitTestImpl::ConfigureStreamingOutput() { + const string& target = GTEST_FLAG(stream_result_to); + if (!target.empty()) { + const size_t pos = target.find(':'); + if (pos != string::npos) { + listeners()->Append(new StreamingListener(target.substr(0, pos), + target.substr(pos+1))); + } else { + printf("WARNING: unrecognized streaming target \"%s\" ignored.\n", + target.c_str()); + fflush(stdout); + } + } +} +#endif // GTEST_CAN_STREAM_RESULTS_ + +// Performs initialization dependent upon flag values obtained in +// ParseGoogleTestFlagsOnly. Is called from InitGoogleTest after the call to +// ParseGoogleTestFlagsOnly. In case a user neglects to call InitGoogleTest +// this function is also called from RunAllTests. Since this function can be +// called more than once, it has to be idempotent. +void UnitTestImpl::PostFlagParsingInit() { + // Ensures that this function does not execute more than once. + if (!post_flag_parse_init_performed_) { + post_flag_parse_init_performed_ = true; + +#if GTEST_HAS_DEATH_TEST + InitDeathTestSubprocessControlInfo(); + SuppressTestEventsIfInSubprocess(); +#endif // GTEST_HAS_DEATH_TEST + + // Registers parameterized tests. This makes parameterized tests + // available to the UnitTest reflection API without running + // RUN_ALL_TESTS. + RegisterParameterizedTests(); + + // Configures listeners for XML output. 
This makes it possible for users + // to shut down the default XML output before invoking RUN_ALL_TESTS. + ConfigureXmlOutput(); + +#if GTEST_CAN_STREAM_RESULTS_ + // Configures listeners for streaming test results to the specified server. + ConfigureStreamingOutput(); +#endif // GTEST_CAN_STREAM_RESULTS_ + } +} + +// A predicate that checks the name of a TestCase against a known +// value. +// +// This is used for implementation of the UnitTest class only. We put +// it in the anonymous namespace to prevent polluting the outer +// namespace. +// +// TestCaseNameIs is copyable. +class TestCaseNameIs { + public: + // Constructor. + explicit TestCaseNameIs(const String& name) + : name_(name) {} + + // Returns true iff the name of test_case matches name_. + bool operator()(const TestCase* test_case) const { + return test_case != NULL && strcmp(test_case->name(), name_.c_str()) == 0; + } + + private: + String name_; +}; + +// Finds and returns a TestCase with the given name. If one doesn't +// exist, creates one and returns it. It's the CALLER'S +// RESPONSIBILITY to ensure that this function is only called WHEN THE +// TESTS ARE NOT SHUFFLED. +// +// Arguments: +// +// test_case_name: name of the test case +// type_param: the name of the test case's type parameter, or NULL if +// this is not a typed or a type-parameterized test case. +// set_up_tc: pointer to the function that sets up the test case +// tear_down_tc: pointer to the function that tears down the test case +TestCase* UnitTestImpl::GetTestCase(const char* test_case_name, + const char* type_param, + Test::SetUpTestCaseFunc set_up_tc, + Test::TearDownTestCaseFunc tear_down_tc) { + // Can we find a TestCase with the given name? + const std::vector<TestCase*>::const_iterator test_case = + std::find_if(test_cases_.begin(), test_cases_.end(), + TestCaseNameIs(test_case_name)); + + if (test_case != test_cases_.end()) + return *test_case; + + // No. Let's create one. 
+ TestCase* const new_test_case = + new TestCase(test_case_name, type_param, set_up_tc, tear_down_tc); + + // Is this a death test case? + if (internal::UnitTestOptions::MatchesFilter(String(test_case_name), + kDeathTestCaseFilter)) { + // Yes. Inserts the test case after the last death test case + // defined so far. This only works when the test cases haven't + // been shuffled. Otherwise we may end up running a death test + // after a non-death test. + ++last_death_test_case_; + test_cases_.insert(test_cases_.begin() + last_death_test_case_, + new_test_case); + } else { + // No. Appends to the end of the list. + test_cases_.push_back(new_test_case); + } + + test_case_indices_.push_back(static_cast<int>(test_case_indices_.size())); + return new_test_case; +} + +// Helpers for setting up / tearing down the given environment. They +// are for use in the ForEach() function. +static void SetUpEnvironment(Environment* env) { env->SetUp(); } +static void TearDownEnvironment(Environment* env) { env->TearDown(); } + +// Runs all tests in this UnitTest object, prints the result, and +// returns true if all tests are successful. If any exception is +// thrown during a test, the test is considered to be failed, but the +// rest of the tests will still be run. +// +// When parameterized tests are enabled, it expands and registers +// parameterized tests first in RegisterParameterizedTests(). +// All other functions called from RunAllTests() may safely assume that +// parameterized tests are ready to be counted and run. +bool UnitTestImpl::RunAllTests() { + // Makes sure InitGoogleTest() was called. + if (!GTestIsInitialized()) { + printf("%s", + "\nThis test program did NOT call ::testing::InitGoogleTest " + "before calling RUN_ALL_TESTS(). Please fix it.\n"); + return false; + } + + // Do not run any test if the --help flag was specified. 
+ if (g_help_flag) + return true; + + // Repeats the call to the post-flag parsing initialization in case the + // user didn't call InitGoogleTest. + PostFlagParsingInit(); + + // Even if sharding is not on, test runners may want to use the + // GTEST_SHARD_STATUS_FILE to query whether the test supports the sharding + // protocol. + internal::WriteToShardStatusFileIfNeeded(); + + // True iff we are in a subprocess for running a thread-safe-style + // death test. + bool in_subprocess_for_death_test = false; + +#if GTEST_HAS_DEATH_TEST + in_subprocess_for_death_test = (internal_run_death_test_flag_.get() != NULL); +#endif // GTEST_HAS_DEATH_TEST + + const bool should_shard = ShouldShard(kTestTotalShards, kTestShardIndex, + in_subprocess_for_death_test); + + // Compares the full test names with the filter to decide which + // tests to run. + const bool has_tests_to_run = FilterTests(should_shard + ? HONOR_SHARDING_PROTOCOL + : IGNORE_SHARDING_PROTOCOL) > 0; + + // Lists the tests and exits if the --gtest_list_tests flag was specified. + if (GTEST_FLAG(list_tests)) { + // This must be called *after* FilterTests() has been called. + ListTestsMatchingFilter(); + return true; + } + + random_seed_ = GTEST_FLAG(shuffle) ? + GetRandomSeedFromFlag(GTEST_FLAG(random_seed)) : 0; + + // True iff at least one test has failed. + bool failed = false; + + TestEventListener* repeater = listeners()->repeater(); + + repeater->OnTestProgramStart(*parent_); + + // How many times to repeat the tests? We don't want to repeat them + // when we are inside the subprocess of a death test. + const int repeat = in_subprocess_for_death_test ? 1 : GTEST_FLAG(repeat); + // Repeats forever if the repeat count is negative. + const bool forever = repeat < 0; + for (int i = 0; forever || i != repeat; i++) { + // We want to preserve failures generated by ad-hoc test + // assertions executed before RUN_ALL_TESTS(). 
+ ClearNonAdHocTestResult(); + + const TimeInMillis start = GetTimeInMillis(); + + // Shuffles test cases and tests if requested. + if (has_tests_to_run && GTEST_FLAG(shuffle)) { + random()->Reseed(random_seed_); + // This should be done before calling OnTestIterationStart(), + // such that a test event listener can see the actual test order + // in the event. + ShuffleTests(); + } + + // Tells the unit test event listeners that the tests are about to start. + repeater->OnTestIterationStart(*parent_, i); + + // Runs each test case if there is at least one test to run. + if (has_tests_to_run) { + // Sets up all environments beforehand. + repeater->OnEnvironmentsSetUpStart(*parent_); + ForEach(environments_, SetUpEnvironment); + repeater->OnEnvironmentsSetUpEnd(*parent_); + + // Runs the tests only if there was no fatal failure during global + // set-up. + if (!Test::HasFatalFailure()) { + for (int test_index = 0; test_index < total_test_case_count(); + test_index++) { + GetMutableTestCase(test_index)->Run(); + } + } + + // Tears down all environments in reverse order afterwards. + repeater->OnEnvironmentsTearDownStart(*parent_); + std::for_each(environments_.rbegin(), environments_.rend(), + TearDownEnvironment); + repeater->OnEnvironmentsTearDownEnd(*parent_); + } + + elapsed_time_ = GetTimeInMillis() - start; + + // Tells the unit test event listener that the tests have just finished. + repeater->OnTestIterationEnd(*parent_, i); + + // Gets the result and clears it. + if (!Passed()) { + failed = true; + } + + // Restores the original test order after the iteration. This + // allows the user to quickly repro a failure that happens in the + // N-th iteration without repeating the first (N - 1) iterations. + // This is not enclosed in "if (GTEST_FLAG(shuffle)) { ... }", in + // case the user somehow changes the value of the flag somewhere + // (it's always safe to unshuffle the tests). 
+ UnshuffleTests(); + + if (GTEST_FLAG(shuffle)) { + // Picks a new random seed for each iteration. + random_seed_ = GetNextRandomSeed(random_seed_); + } + } + + repeater->OnTestProgramEnd(*parent_); + + return !failed; +} + +// Reads the GTEST_SHARD_STATUS_FILE environment variable, and creates the file +// if the variable is present. If a file already exists at this location, this +// function will write over it. If the variable is present, but the file cannot +// be created, prints an error and exits. +void WriteToShardStatusFileIfNeeded() { + const char* const test_shard_file = posix::GetEnv(kTestShardStatusFile); + if (test_shard_file != NULL) { + FILE* const file = posix::FOpen(test_shard_file, "w"); + if (file == NULL) { + ColoredPrintf(COLOR_RED, + "Could not write to the test shard status file \"%s\" " + "specified by the %s environment variable.\n", + test_shard_file, kTestShardStatusFile); + fflush(stdout); + exit(EXIT_FAILURE); + } + fclose(file); + } +} + +// Checks whether sharding is enabled by examining the relevant +// environment variable values. If the variables are present, +// but inconsistent (i.e., shard_index >= total_shards), prints +// an error and exits. If in_subprocess_for_death_test, sharding is +// disabled because it must only be applied to the original test +// process. Otherwise, we could filter out death tests we intended to execute. 
+bool ShouldShard(const char* total_shards_env, + const char* shard_index_env, + bool in_subprocess_for_death_test) { + if (in_subprocess_for_death_test) { + return false; + } + + const Int32 total_shards = Int32FromEnvOrDie(total_shards_env, -1); + const Int32 shard_index = Int32FromEnvOrDie(shard_index_env, -1); + + if (total_shards == -1 && shard_index == -1) { + return false; + } else if (total_shards == -1 && shard_index != -1) { + const Message msg = Message() + << "Invalid environment variables: you have " + << kTestShardIndex << " = " << shard_index + << ", but have left " << kTestTotalShards << " unset.\n"; + ColoredPrintf(COLOR_RED, msg.GetString().c_str()); + fflush(stdout); + exit(EXIT_FAILURE); + } else if (total_shards != -1 && shard_index == -1) { + const Message msg = Message() + << "Invalid environment variables: you have " + << kTestTotalShards << " = " << total_shards + << ", but have left " << kTestShardIndex << " unset.\n"; + ColoredPrintf(COLOR_RED, msg.GetString().c_str()); + fflush(stdout); + exit(EXIT_FAILURE); + } else if (shard_index < 0 || shard_index >= total_shards) { + const Message msg = Message() + << "Invalid environment variables: we require 0 <= " + << kTestShardIndex << " < " << kTestTotalShards + << ", but you have " << kTestShardIndex << "=" << shard_index + << ", " << kTestTotalShards << "=" << total_shards << ".\n"; + ColoredPrintf(COLOR_RED, msg.GetString().c_str()); + fflush(stdout); + exit(EXIT_FAILURE); + } + + return total_shards > 1; +} + +// Parses the environment variable var as an Int32. If it is unset, +// returns default_val. If it is not an Int32, prints an error +// and aborts. 
+Int32 Int32FromEnvOrDie(const char* var, Int32 default_val) { + const char* str_val = posix::GetEnv(var); + if (str_val == NULL) { + return default_val; + } + + Int32 result; + if (!ParseInt32(Message() << "The value of environment variable " << var, + str_val, &result)) { + exit(EXIT_FAILURE); + } + return result; +} + +// Given the total number of shards, the shard index, and the test id, +// returns true iff the test should be run on this shard. The test id is +// some arbitrary but unique non-negative integer assigned to each test +// method. Assumes that 0 <= shard_index < total_shards. +bool ShouldRunTestOnShard(int total_shards, int shard_index, int test_id) { + return (test_id % total_shards) == shard_index; +} + +// Compares the name of each test with the user-specified filter to +// decide whether the test should be run, then records the result in +// each TestCase and TestInfo object. +// If shard_tests == true, further filters tests based on sharding +// variables in the environment - see +// http://code.google.com/p/googletest/wiki/GoogleTestAdvancedGuide. +// Returns the number of tests that should run. +int UnitTestImpl::FilterTests(ReactionToSharding shard_tests) { + const Int32 total_shards = shard_tests == HONOR_SHARDING_PROTOCOL ? + Int32FromEnvOrDie(kTestTotalShards, -1) : -1; + const Int32 shard_index = shard_tests == HONOR_SHARDING_PROTOCOL ? + Int32FromEnvOrDie(kTestShardIndex, -1) : -1; + + // num_runnable_tests are the number of tests that will + // run across all shards (i.e., match filter and are not disabled). + // num_selected_tests are the number of tests to be run on + // this shard. 
+ int num_runnable_tests = 0; + int num_selected_tests = 0; + for (size_t i = 0; i < test_cases_.size(); i++) { + TestCase* const test_case = test_cases_[i]; + const String &test_case_name = test_case->name(); + test_case->set_should_run(false); + + for (size_t j = 0; j < test_case->test_info_list().size(); j++) { + TestInfo* const test_info = test_case->test_info_list()[j]; + const String test_name(test_info->name()); + // A test is disabled if test case name or test name matches + // kDisableTestFilter. + const bool is_disabled = + internal::UnitTestOptions::MatchesFilter(test_case_name, + kDisableTestFilter) || + internal::UnitTestOptions::MatchesFilter(test_name, + kDisableTestFilter); + test_info->is_disabled_ = is_disabled; + + const bool matches_filter = + internal::UnitTestOptions::FilterMatchesTest(test_case_name, + test_name); + test_info->matches_filter_ = matches_filter; + + const bool is_runnable = + (GTEST_FLAG(also_run_disabled_tests) || !is_disabled) && + matches_filter; + + const bool is_selected = is_runnable && + (shard_tests == IGNORE_SHARDING_PROTOCOL || + ShouldRunTestOnShard(total_shards, shard_index, + num_runnable_tests)); + + num_runnable_tests += is_runnable; + num_selected_tests += is_selected; + + test_info->should_run_ = is_selected; + test_case->set_should_run(test_case->should_run() || is_selected); + } + } + return num_selected_tests; +} + +// Prints the names of the tests matching the user-specified filter flag. 
+void UnitTestImpl::ListTestsMatchingFilter() { + for (size_t i = 0; i < test_cases_.size(); i++) { + const TestCase* const test_case = test_cases_[i]; + bool printed_test_case_name = false; + + for (size_t j = 0; j < test_case->test_info_list().size(); j++) { + const TestInfo* const test_info = + test_case->test_info_list()[j]; + if (test_info->matches_filter_) { + if (!printed_test_case_name) { + printed_test_case_name = true; + printf("%s.\n", test_case->name()); + } + printf(" %s\n", test_info->name()); + } + } + } + fflush(stdout); +} + +// Sets the OS stack trace getter. +// +// Does nothing if the input and the current OS stack trace getter are +// the same; otherwise, deletes the old getter and makes the input the +// current getter. +void UnitTestImpl::set_os_stack_trace_getter( + OsStackTraceGetterInterface* getter) { + if (os_stack_trace_getter_ != getter) { + delete os_stack_trace_getter_; + os_stack_trace_getter_ = getter; + } +} + +// Returns the current OS stack trace getter if it is not NULL; +// otherwise, creates an OsStackTraceGetter, makes it the current +// getter, and returns it. +OsStackTraceGetterInterface* UnitTestImpl::os_stack_trace_getter() { + if (os_stack_trace_getter_ == NULL) { + os_stack_trace_getter_ = new OsStackTraceGetter; + } + + return os_stack_trace_getter_; +} + +// Returns the TestResult for the test that's currently running, or +// the TestResult for the ad hoc test if no test is running. +TestResult* UnitTestImpl::current_test_result() { + return current_test_info_ ? + &(current_test_info_->result_) : &ad_hoc_test_result_; +} + +// Shuffles all test cases, and the tests within each test case, +// making sure that death tests are still run first. +void UnitTestImpl::ShuffleTests() { + // Shuffles the death test cases. + ShuffleRange(random(), 0, last_death_test_case_ + 1, &test_case_indices_); + + // Shuffles the non-death test cases. 
+ ShuffleRange(random(), last_death_test_case_ + 1, + static_cast<int>(test_cases_.size()), &test_case_indices_); + + // Shuffles the tests inside each test case. + for (size_t i = 0; i < test_cases_.size(); i++) { + test_cases_[i]->ShuffleTests(random()); + } +} + +// Restores the test cases and tests to their order before the first shuffle. +void UnitTestImpl::UnshuffleTests() { + for (size_t i = 0; i < test_cases_.size(); i++) { + // Unshuffles the tests in each test case. + test_cases_[i]->UnshuffleTests(); + // Resets the index of each test case. + test_case_indices_[i] = static_cast<int>(i); + } +} + +// Returns the current OS stack trace as a String. +// +// The maximum number of stack frames to be included is specified by +// the gtest_stack_trace_depth flag. The skip_count parameter +// specifies the number of top frames to be skipped, which doesn't +// count against the number of frames to be included. +// +// For example, if Foo() calls Bar(), which in turn calls +// GetCurrentOsStackTraceExceptTop(..., 1), Foo() will be included in +// the trace but Bar() and GetCurrentOsStackTraceExceptTop() won't. +String GetCurrentOsStackTraceExceptTop(UnitTest* /*unit_test*/, + int skip_count) { + // We pass skip_count + 1 to skip this wrapper function in addition + // to what the user really wants to skip. + return GetUnitTestImpl()->CurrentOsStackTraceExceptTop(skip_count + 1); +} + +// Used by the GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_ macro to +// suppress unreachable code warnings. +namespace { +class ClassUniqueToAlwaysTrue {}; +} + +bool IsTrue(bool condition) { return condition; } + +bool AlwaysTrue() { +#if GTEST_HAS_EXCEPTIONS + // This condition is always false so AlwaysTrue() never actually throws, + // but it makes the compiler think that it may throw. 
+ if (IsTrue(false)) + throw ClassUniqueToAlwaysTrue(); +#endif // GTEST_HAS_EXCEPTIONS + return true; +} + +// If *pstr starts with the given prefix, modifies *pstr to be right +// past the prefix and returns true; otherwise leaves *pstr unchanged +// and returns false. None of pstr, *pstr, and prefix can be NULL. +bool SkipPrefix(const char* prefix, const char** pstr) { + const size_t prefix_len = strlen(prefix); + if (strncmp(*pstr, prefix, prefix_len) == 0) { + *pstr += prefix_len; + return true; + } + return false; +} + +// Parses a string as a command line flag. The string should have +// the format "--flag=value". When def_optional is true, the "=value" +// part can be omitted. +// +// Returns the value of the flag, or NULL if the parsing failed. +const char* ParseFlagValue(const char* str, + const char* flag, + bool def_optional) { + // str and flag must not be NULL. + if (str == NULL || flag == NULL) return NULL; + + // The flag must start with "--" followed by GTEST_FLAG_PREFIX_. + const String flag_str = String::Format("--%s%s", GTEST_FLAG_PREFIX_, flag); + const size_t flag_len = flag_str.length(); + if (strncmp(str, flag_str.c_str(), flag_len) != 0) return NULL; + + // Skips the flag name. + const char* flag_end = str + flag_len; + + // When def_optional is true, it's OK to not have a "=value" part. + if (def_optional && (flag_end[0] == '\0')) { + return flag_end; + } + + // If def_optional is true and there are more characters after the + // flag name, or if def_optional is false, there must be a '=' after + // the flag name. + if (flag_end[0] != '=') return NULL; + + // Returns the string after "=". + return flag_end + 1; +} + +// Parses a string for a bool flag, in the form of either +// "--flag=value" or "--flag". +// +// In the former case, the value is taken as true as long as it does +// not start with '0', 'f', or 'F'. +// +// In the latter case, the value is taken as true. 
+// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +bool ParseBoolFlag(const char* str, const char* flag, bool* value) { + // Gets the value of the flag as a string. + const char* const value_str = ParseFlagValue(str, flag, true); + + // Aborts if the parsing failed. + if (value_str == NULL) return false; + + // Converts the string value to a bool. + *value = !(*value_str == '0' || *value_str == 'f' || *value_str == 'F'); + return true; +} + +// Parses a string for an Int32 flag, in the form of +// "--flag=value". +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +bool ParseInt32Flag(const char* str, const char* flag, Int32* value) { + // Gets the value of the flag as a string. + const char* const value_str = ParseFlagValue(str, flag, false); + + // Aborts if the parsing failed. + if (value_str == NULL) return false; + + // Sets *value to the value of the flag. + return ParseInt32(Message() << "The value of flag --" << flag, + value_str, value); +} + +// Parses a string for a string flag, in the form of +// "--flag=value". +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +bool ParseStringFlag(const char* str, const char* flag, String* value) { + // Gets the value of the flag as a string. + const char* const value_str = ParseFlagValue(str, flag, false); + + // Aborts if the parsing failed. + if (value_str == NULL) return false; + + // Sets *value to the value of the flag. + *value = value_str; + return true; +} + +// Determines whether a string has a prefix that Google Test uses for its +// flags, i.e., starts with GTEST_FLAG_PREFIX_ or GTEST_FLAG_PREFIX_DASH_. +// If Google Test detects that a command line flag has its prefix but is not +// recognized, it will print its help message. 
Flags starting with +// GTEST_INTERNAL_PREFIX_ followed by "internal_" are considered Google Test +// internal flags and do not trigger the help message. +static bool HasGoogleTestFlagPrefix(const char* str) { + return (SkipPrefix("--", &str) || + SkipPrefix("-", &str) || + SkipPrefix("/", &str)) && + !SkipPrefix(GTEST_FLAG_PREFIX_ "internal_", &str) && + (SkipPrefix(GTEST_FLAG_PREFIX_, &str) || + SkipPrefix(GTEST_FLAG_PREFIX_DASH_, &str)); +} + +// Prints a string containing code-encoded text. The following escape +// sequences can be used in the string to control the text color: +// +// @@ prints a single '@' character. +// @R changes the color to red. +// @G changes the color to green. +// @Y changes the color to yellow. +// @D changes to the default terminal text color. +// +// TODO(wan@google.com): Write tests for this once we add stdout +// capturing to Google Test. +static void PrintColorEncoded(const char* str) { + GTestColor color = COLOR_DEFAULT; // The current color. + + // Conceptually, we split the string into segments divided by escape + // sequences. Then we print one segment at a time. At the end of + // each iteration, the str pointer advances to the beginning of the + // next segment. + for (;;) { + const char* p = strchr(str, '@'); + if (p == NULL) { + ColoredPrintf(color, "%s", str); + return; + } + + ColoredPrintf(color, "%s", String(str, p - str).c_str()); + + const char ch = p[1]; + str = p + 2; + if (ch == '@') { + ColoredPrintf(color, "@"); + } else if (ch == 'D') { + color = COLOR_DEFAULT; + } else if (ch == 'R') { + color = COLOR_RED; + } else if (ch == 'G') { + color = COLOR_GREEN; + } else if (ch == 'Y') { + color = COLOR_YELLOW; + } else { + --str; + } + } +} + +static const char kColorEncodedHelpMessage[] = +"This program contains tests written using " GTEST_NAME_ ". 
You can use the\n" +"following command line flags to control its behavior:\n" +"\n" +"Test Selection:\n" +" @G--" GTEST_FLAG_PREFIX_ "list_tests@D\n" +" List the names of all tests instead of running them. The name of\n" +" TEST(Foo, Bar) is \"Foo.Bar\".\n" +" @G--" GTEST_FLAG_PREFIX_ "filter=@YPOSTIVE_PATTERNS" + "[@G-@YNEGATIVE_PATTERNS]@D\n" +" Run only the tests whose name matches one of the positive patterns but\n" +" none of the negative patterns. '?' matches any single character; '*'\n" +" matches any substring; ':' separates two patterns.\n" +" @G--" GTEST_FLAG_PREFIX_ "also_run_disabled_tests@D\n" +" Run all disabled tests too.\n" +"\n" +"Test Execution:\n" +" @G--" GTEST_FLAG_PREFIX_ "repeat=@Y[COUNT]@D\n" +" Run the tests repeatedly; use a negative count to repeat forever.\n" +" @G--" GTEST_FLAG_PREFIX_ "shuffle@D\n" +" Randomize tests' orders on every iteration.\n" +" @G--" GTEST_FLAG_PREFIX_ "random_seed=@Y[NUMBER]@D\n" +" Random number seed to use for shuffling test orders (between 1 and\n" +" 99999, or 0 to use a seed based on the current time).\n" +"\n" +"Test Output:\n" +" @G--" GTEST_FLAG_PREFIX_ "color=@Y(@Gyes@Y|@Gno@Y|@Gauto@Y)@D\n" +" Enable/disable colored output. The default is @Gauto@D.\n" +" -@G-" GTEST_FLAG_PREFIX_ "print_time=0@D\n" +" Don't print the elapsed time of each test.\n" +" @G--" GTEST_FLAG_PREFIX_ "output=xml@Y[@G:@YDIRECTORY_PATH@G" + GTEST_PATH_SEP_ "@Y|@G:@YFILE_PATH]@D\n" +" Generate an XML report in the given directory or with the given file\n" +" name. 
@YFILE_PATH@D defaults to @Gtest_details.xml@D.\n" +#if GTEST_CAN_STREAM_RESULTS_ +" @G--" GTEST_FLAG_PREFIX_ "stream_result_to=@YHOST@G:@YPORT@D\n" +" Stream test results to the given server.\n" +#endif // GTEST_CAN_STREAM_RESULTS_ +"\n" +"Assertion Behavior:\n" +#if GTEST_HAS_DEATH_TEST && !GTEST_OS_WINDOWS +" @G--" GTEST_FLAG_PREFIX_ "death_test_style=@Y(@Gfast@Y|@Gthreadsafe@Y)@D\n" +" Set the default death test style.\n" +#endif // GTEST_HAS_DEATH_TEST && !GTEST_OS_WINDOWS +" @G--" GTEST_FLAG_PREFIX_ "break_on_failure@D\n" +" Turn assertion failures into debugger break-points.\n" +" @G--" GTEST_FLAG_PREFIX_ "throw_on_failure@D\n" +" Turn assertion failures into C++ exceptions.\n" +" @G--" GTEST_FLAG_PREFIX_ "catch_exceptions=0@D\n" +" Do not report exceptions as test failures. Instead, allow them\n" +" to crash the program or throw a pop-up (on Windows).\n" +"\n" +"Except for @G--" GTEST_FLAG_PREFIX_ "list_tests@D, you can alternatively set " + "the corresponding\n" +"environment variable of a flag (all letters in upper-case). For example, to\n" +"disable colored text output, you can either specify @G--" GTEST_FLAG_PREFIX_ + "color=no@D or set\n" +"the @G" GTEST_FLAG_PREFIX_UPPER_ "COLOR@D environment variable to @Gno@D.\n" +"\n" +"For more information, please read the " GTEST_NAME_ " documentation at\n" +"@G" GTEST_PROJECT_URL_ "@D. If you find a bug in " GTEST_NAME_ "\n" +"(not one in your own code or tests), please report it to\n" +"@G<" GTEST_DEV_EMAIL_ ">@D.\n"; + +// Parses the command line for Google Test flags, without initializing +// other parts of Google Test. The type parameter CharType can be +// instantiated to either char or wchar_t. 
+template <typename CharType> +void ParseGoogleTestFlagsOnlyImpl(int* argc, CharType** argv) { + for (int i = 1; i < *argc; i++) { + const String arg_string = StreamableToString(argv[i]); + const char* const arg = arg_string.c_str(); + + using internal::ParseBoolFlag; + using internal::ParseInt32Flag; + using internal::ParseStringFlag; + + // Do we see a Google Test flag? + if (ParseBoolFlag(arg, kAlsoRunDisabledTestsFlag, + &GTEST_FLAG(also_run_disabled_tests)) || + ParseBoolFlag(arg, kBreakOnFailureFlag, + &GTEST_FLAG(break_on_failure)) || + ParseBoolFlag(arg, kCatchExceptionsFlag, + &GTEST_FLAG(catch_exceptions)) || + ParseStringFlag(arg, kColorFlag, &GTEST_FLAG(color)) || + ParseStringFlag(arg, kDeathTestStyleFlag, + &GTEST_FLAG(death_test_style)) || + ParseBoolFlag(arg, kDeathTestUseFork, + &GTEST_FLAG(death_test_use_fork)) || + ParseStringFlag(arg, kFilterFlag, &GTEST_FLAG(filter)) || + ParseStringFlag(arg, kInternalRunDeathTestFlag, + &GTEST_FLAG(internal_run_death_test)) || + ParseBoolFlag(arg, kListTestsFlag, &GTEST_FLAG(list_tests)) || + ParseStringFlag(arg, kOutputFlag, &GTEST_FLAG(output)) || + ParseBoolFlag(arg, kPrintTimeFlag, &GTEST_FLAG(print_time)) || + ParseInt32Flag(arg, kRandomSeedFlag, &GTEST_FLAG(random_seed)) || + ParseInt32Flag(arg, kRepeatFlag, &GTEST_FLAG(repeat)) || + ParseBoolFlag(arg, kShuffleFlag, &GTEST_FLAG(shuffle)) || + ParseInt32Flag(arg, kStackTraceDepthFlag, + &GTEST_FLAG(stack_trace_depth)) || + ParseStringFlag(arg, kStreamResultToFlag, + &GTEST_FLAG(stream_result_to)) || + ParseBoolFlag(arg, kThrowOnFailureFlag, + &GTEST_FLAG(throw_on_failure)) + ) { + // Yes. Shift the remainder of the argv list left by one. Note + // that argv has (*argc + 1) elements, the last one always being + // NULL. The following loop moves the trailing NULL element as + // well. + for (int j = i; j != *argc; j++) { + argv[j] = argv[j + 1]; + } + + // Decrements the argument count. + (*argc)--; + + // We also need to decrement the iterator as we just removed + // an element. 
+ i--; + } else if (arg_string == "--help" || arg_string == "-h" || + arg_string == "-?" || arg_string == "/?" || + HasGoogleTestFlagPrefix(arg)) { + // Both help flag and unrecognized Google Test flags (excluding + // internal ones) trigger help display. + g_help_flag = true; + } + } + + if (g_help_flag) { + // We print the help here instead of in RUN_ALL_TESTS(), as the + // latter may not be called at all if the user is using Google + // Test with another testing framework. + PrintColorEncoded(kColorEncodedHelpMessage); + } +} + +// Parses the command line for Google Test flags, without initializing +// other parts of Google Test. +void ParseGoogleTestFlagsOnly(int* argc, char** argv) { + ParseGoogleTestFlagsOnlyImpl(argc, argv); +} +void ParseGoogleTestFlagsOnly(int* argc, wchar_t** argv) { + ParseGoogleTestFlagsOnlyImpl(argc, argv); +} + +// The internal implementation of InitGoogleTest(). +// +// The type parameter CharType can be instantiated to either char or +// wchar_t. +template <typename CharType> +void InitGoogleTestImpl(int* argc, CharType** argv) { + g_init_gtest_count++; + + // We don't want to run the initialization code twice. + if (g_init_gtest_count != 1) return; + + if (*argc <= 0) return; + + internal::g_executable_path = internal::StreamableToString(argv[0]); + +#if GTEST_HAS_DEATH_TEST + + g_argvs.clear(); + for (int i = 0; i != *argc; i++) { + g_argvs.push_back(StreamableToString(argv[i])); + } + +#endif // GTEST_HAS_DEATH_TEST + + ParseGoogleTestFlagsOnly(argc, argv); + GetUnitTestImpl()->PostFlagParsingInit(); +} + +} // namespace internal + +// Initializes Google Test. This must be called before calling +// RUN_ALL_TESTS(). In particular, it parses a command line for the +// flags that Google Test recognizes. Whenever a Google Test flag is +// seen, it is removed from argv, and *argc is decremented. +// +// No value is returned. Instead, the Google Test flag variables are +// updated. 
+// +// Calling the function for the second time has no user-visible effect. +void InitGoogleTest(int* argc, char** argv) { + internal::InitGoogleTestImpl(argc, argv); +} + +// This overloaded version can be used in Windows programs compiled in +// UNICODE mode. +void InitGoogleTest(int* argc, wchar_t** argv) { + internal::InitGoogleTestImpl(argc, argv); +} + +} // namespace testing +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+// +// Author: wan@google.com (Zhanyong Wan), vladl@google.com (Vlad Losev) +// +// This file implements death tests. + + +#if GTEST_HAS_DEATH_TEST + +# if GTEST_OS_MAC +# include <crt_externs.h> +# endif // GTEST_OS_MAC + +# include <errno.h> +# include <fcntl.h> +# include <limits.h> +# include <stdarg.h> + +# if GTEST_OS_WINDOWS +# include <windows.h> +# else +# include <sys/mman.h> +# include <sys/wait.h> +# endif // GTEST_OS_WINDOWS + +#endif // GTEST_HAS_DEATH_TEST + + +// Indicates that this translation unit is part of Google Test's +// implementation. It must come before gtest-internal-inl.h is +// included, or there will be a compiler error. This trick is to +// prevent a user from accidentally including gtest-internal-inl.h in +// his code. +#define GTEST_IMPLEMENTATION_ 1 +#undef GTEST_IMPLEMENTATION_ + +namespace testing { + +// Constants. + +// The default death test style. +static const char kDefaultDeathTestStyle[] = "fast"; + +GTEST_DEFINE_string_( + death_test_style, + internal::StringFromGTestEnv("death_test_style", kDefaultDeathTestStyle), + "Indicates how to run a death test in a forked child process: " + "\"threadsafe\" (child process re-executes the test binary " + "from the beginning, running only the specific death test) or " + "\"fast\" (child process runs the death test immediately " + "after forking)."); + +GTEST_DEFINE_bool_( + death_test_use_fork, + internal::BoolFromGTestEnv("death_test_use_fork", false), + "Instructs to use fork()/_exit() instead of clone() in death tests. " + "Ignored and always uses fork() on POSIX systems where clone() is not " + "implemented. Useful when running under valgrind or similar tools if " + "those do not support clone(). Valgrind 3.3.1 will just fail if " + "it sees an unsupported combination of clone() flags. " + "It is not recommended to use this flag w/o valgrind though it will " + "work in 99% of the cases. 
Once valgrind is fixed, this flag will " + "most likely be removed."); + +namespace internal { +GTEST_DEFINE_string_( + internal_run_death_test, "", + "Indicates the file, line number, temporal index of " + "the single death test to run, and a file descriptor to " + "which a success code may be sent, all separated by " + "colons. This flag is specified if and only if the current " + "process is a sub-process launched for running a thread-safe " + "death test. FOR INTERNAL USE ONLY."); +} // namespace internal + +#if GTEST_HAS_DEATH_TEST + +// ExitedWithCode constructor. +ExitedWithCode::ExitedWithCode(int exit_code) : exit_code_(exit_code) { +} + +// ExitedWithCode function-call operator. +bool ExitedWithCode::operator()(int exit_status) const { +# if GTEST_OS_WINDOWS + + return exit_status == exit_code_; + +# else + + return WIFEXITED(exit_status) && WEXITSTATUS(exit_status) == exit_code_; + +# endif // GTEST_OS_WINDOWS +} + +# if !GTEST_OS_WINDOWS +// KilledBySignal constructor. +KilledBySignal::KilledBySignal(int signum) : signum_(signum) { +} + +// KilledBySignal function-call operator. +bool KilledBySignal::operator()(int exit_status) const { + return WIFSIGNALED(exit_status) && WTERMSIG(exit_status) == signum_; +} +# endif // !GTEST_OS_WINDOWS + +namespace internal { + +// Utilities needed for death tests. + +// Generates a textual description of a given exit code, in the format +// specified by wait(2). 
+static String ExitSummary(int exit_code) { + Message m; + +# if GTEST_OS_WINDOWS + + m << "Exited with exit status " << exit_code; + +# else + + if (WIFEXITED(exit_code)) { + m << "Exited with exit status " << WEXITSTATUS(exit_code); + } else if (WIFSIGNALED(exit_code)) { + m << "Terminated by signal " << WTERMSIG(exit_code); + } +# ifdef WCOREDUMP + if (WCOREDUMP(exit_code)) { + m << " (core dumped)"; + } +# endif +# endif // GTEST_OS_WINDOWS + + return m.GetString(); +} + +// Returns true if exit_status describes a process that was terminated +// by a signal, or exited normally with a nonzero exit code. +bool ExitedUnsuccessfully(int exit_status) { + return !ExitedWithCode(0)(exit_status); +} + +# if !GTEST_OS_WINDOWS +// Generates a textual failure message when a death test finds more than +// one thread running, or cannot determine the number of threads, prior +// to executing the given statement. It is the responsibility of the +// caller not to pass a thread_count of 1. +static String DeathTestThreadWarning(size_t thread_count) { + Message msg; + msg << "Death tests use fork(), which is unsafe particularly" + << " in a threaded context. For this test, " << GTEST_NAME_ << " "; + if (thread_count == 0) + msg << "couldn't detect the number of threads."; + else + msg << "detected " << thread_count << " threads."; + return msg.GetString(); +} +# endif // !GTEST_OS_WINDOWS + +// Flag characters for reporting a death test that did not die. +static const char kDeathTestLived = 'L'; +static const char kDeathTestReturned = 'R'; +static const char kDeathTestThrew = 'T'; +static const char kDeathTestInternalError = 'I'; + +// An enumeration describing all of the possible ways that a death test can +// conclude. 
DIED means that the process died while executing the test +// code; LIVED means that process lived beyond the end of the test code; +// RETURNED means that the test statement attempted to execute a return +// statement, which is not allowed; THREW means that the test statement +// returned control by throwing an exception. IN_PROGRESS means the test +// has not yet concluded. +// TODO(vladl@google.com): Unify names and possibly values for +// AbortReason, DeathTestOutcome, and flag characters above. +enum DeathTestOutcome { IN_PROGRESS, DIED, LIVED, RETURNED, THREW }; + +// Routine for aborting the program which is safe to call from an +// exec-style death test child process, in which case the error +// message is propagated back to the parent process. Otherwise, the +// message is simply printed to stderr. In either case, the program +// then exits with status 1. +void DeathTestAbort(const String& message) { + // On a POSIX system, this function may be called from a threadsafe-style + // death test child process, which operates on a very small stack. Use + // the heap for any additional non-minuscule memory requirements. + const InternalRunDeathTestFlag* const flag = + GetUnitTestImpl()->internal_run_death_test_flag(); + if (flag != NULL) { + FILE* parent = posix::FDOpen(flag->write_fd(), "w"); + fputc(kDeathTestInternalError, parent); + fprintf(parent, "%s", message.c_str()); + fflush(parent); + _exit(1); + } else { + fprintf(stderr, "%s", message.c_str()); + fflush(stderr); + posix::Abort(); + } +} + +// A replacement for CHECK that calls DeathTestAbort if the assertion +// fails. 
+# define GTEST_DEATH_TEST_CHECK_(expression) \ + do { \ + if (!::testing::internal::IsTrue(expression)) { \ + DeathTestAbort(::testing::internal::String::Format( \ + "CHECK failed: File %s, line %d: %s", \ + __FILE__, __LINE__, #expression)); \ + } \ + } while (::testing::internal::AlwaysFalse()) + +// This macro is similar to GTEST_DEATH_TEST_CHECK_, but it is meant for +// evaluating any system call that fulfills two conditions: it must return +// -1 on failure, and set errno to EINTR when it is interrupted and +// should be tried again. The macro expands to a loop that repeatedly +// evaluates the expression as long as it evaluates to -1 and sets +// errno to EINTR. If the expression evaluates to -1 but errno is +// something other than EINTR, DeathTestAbort is called. +# define GTEST_DEATH_TEST_CHECK_SYSCALL_(expression) \ + do { \ + int gtest_retval; \ + do { \ + gtest_retval = (expression); \ + } while (gtest_retval == -1 && errno == EINTR); \ + if (gtest_retval == -1) { \ + DeathTestAbort(::testing::internal::String::Format( \ + "CHECK failed: File %s, line %d: %s != -1", \ + __FILE__, __LINE__, #expression)); \ + } \ + } while (::testing::internal::AlwaysFalse()) + +// Returns the message describing the last system error in errno. +String GetLastErrnoDescription() { + return String(errno == 0 ? "" : posix::StrError(errno)); +} + +// This is called from a death test parent process to read a failure +// message from the death test child process and log it with the FATAL +// severity. On Windows, the message is read from a pipe handle. On other +// platforms, it is read from a file descriptor. 
+static void FailFromInternalError(int fd) { + Message error; + char buffer[256]; + int num_read; + + do { + while ((num_read = posix::Read(fd, buffer, 255)) > 0) { + buffer[num_read] = '\0'; + error << buffer; + } + } while (num_read == -1 && errno == EINTR); + + if (num_read == 0) { + GTEST_LOG_(FATAL) << error.GetString(); + } else { + const int last_error = errno; + GTEST_LOG_(FATAL) << "Error while reading death test internal: " + << GetLastErrnoDescription() << " [" << last_error << "]"; + } +} + +// Death test constructor. Increments the running death test count +// for the current test. +DeathTest::DeathTest() { + TestInfo* const info = GetUnitTestImpl()->current_test_info(); + if (info == NULL) { + DeathTestAbort("Cannot run a death test outside of a TEST or " + "TEST_F construct"); + } +} + +// Creates and returns a death test by dispatching to the current +// death test factory. +bool DeathTest::Create(const char* statement, const RE* regex, + const char* file, int line, DeathTest** test) { + return GetUnitTestImpl()->death_test_factory()->Create( + statement, regex, file, line, test); +} + +const char* DeathTest::LastMessage() { + return last_death_test_message_.c_str(); +} + +void DeathTest::set_last_death_test_message(const String& message) { + last_death_test_message_ = message; +} + +String DeathTest::last_death_test_message_; + +// Provides cross platform implementation for some death functionality. +class DeathTestImpl : public DeathTest { + protected: + DeathTestImpl(const char* a_statement, const RE* a_regex) + : statement_(a_statement), + regex_(a_regex), + spawned_(false), + status_(-1), + outcome_(IN_PROGRESS), + read_fd_(-1), + write_fd_(-1) {} + + // read_fd_ is expected to be closed and cleared by a derived class. 
+ ~DeathTestImpl() { GTEST_DEATH_TEST_CHECK_(read_fd_ == -1); } + + void Abort(AbortReason reason); + virtual bool Passed(bool status_ok); + + const char* statement() const { return statement_; } + const RE* regex() const { return regex_; } + bool spawned() const { return spawned_; } + void set_spawned(bool is_spawned) { spawned_ = is_spawned; } + int status() const { return status_; } + void set_status(int a_status) { status_ = a_status; } + DeathTestOutcome outcome() const { return outcome_; } + void set_outcome(DeathTestOutcome an_outcome) { outcome_ = an_outcome; } + int read_fd() const { return read_fd_; } + void set_read_fd(int fd) { read_fd_ = fd; } + int write_fd() const { return write_fd_; } + void set_write_fd(int fd) { write_fd_ = fd; } + + // Called in the parent process only. Reads the result code of the death + // test child process via a pipe, interprets it to set the outcome_ + // member, and closes read_fd_. Outputs diagnostics and terminates in + // case of unexpected codes. + void ReadAndInterpretStatusByte(); + + private: + // The textual content of the code this object is testing. This class + // doesn't own this string and should not attempt to delete it. + const char* const statement_; + // The regular expression which test output must match. DeathTestImpl + // doesn't own this object and should not attempt to delete it. + const RE* const regex_; + // True if the death test child process has been successfully spawned. + bool spawned_; + // The exit status of the child process. + int status_; + // How the death test concluded. + DeathTestOutcome outcome_; + // Descriptor to the read end of the pipe to the child process. It is + // always -1 in the child process. The child keeps its write end of the + // pipe in write_fd_. + int read_fd_; + // Descriptor to the child's write end of the pipe to the parent process. + // It is always -1 in the parent process. The parent keeps its end of the + // pipe in read_fd_. 
+ int write_fd_; +}; + +// Called in the parent process only. Reads the result code of the death +// test child process via a pipe, interprets it to set the outcome_ +// member, and closes read_fd_. Outputs diagnostics and terminates in +// case of unexpected codes. +void DeathTestImpl::ReadAndInterpretStatusByte() { + char flag; + int bytes_read; + + // The read() here blocks until data is available (signifying the + // failure of the death test) or until the pipe is closed (signifying + // its success), so it's okay to call this in the parent before + // the child process has exited. + do { + bytes_read = posix::Read(read_fd(), &flag, 1); + } while (bytes_read == -1 && errno == EINTR); + + if (bytes_read == 0) { + set_outcome(DIED); + } else if (bytes_read == 1) { + switch (flag) { + case kDeathTestReturned: + set_outcome(RETURNED); + break; + case kDeathTestThrew: + set_outcome(THREW); + break; + case kDeathTestLived: + set_outcome(LIVED); + break; + case kDeathTestInternalError: + FailFromInternalError(read_fd()); // Does not return. + break; + default: + GTEST_LOG_(FATAL) << "Death test child process reported " + << "unexpected status byte (" + << static_cast<unsigned int>(flag) << ")"; + } + } else { + GTEST_LOG_(FATAL) << "Read from death test child process failed: " + << GetLastErrnoDescription(); + } + GTEST_DEATH_TEST_CHECK_SYSCALL_(posix::Close(read_fd())); + set_read_fd(-1); +} + +// Signals that the death test code which should have exited, didn't. +// Should be called only in a death test child process. +// Writes a status byte to the child's status file descriptor, then +// calls _exit(1). +void DeathTestImpl::Abort(AbortReason reason) { + // The parent process considers the death test to be a failure if + // it finds any data in our pipe. So, here we write a single flag byte + // to the pipe, then exit. + const char status_ch = + reason == TEST_DID_NOT_DIE ? kDeathTestLived : + reason == TEST_THREW_EXCEPTION ?
kDeathTestThrew : kDeathTestReturned; + + GTEST_DEATH_TEST_CHECK_SYSCALL_(posix::Write(write_fd(), &status_ch, 1)); + // We are leaking the descriptor here because on some platforms (e.g., + // when built as Windows DLL), destructors of global objects will still + // run after calling _exit(). On such systems, write_fd_ will be + // indirectly closed from the destructor of UnitTestImpl, causing double + // close if it is also closed here. On debug configurations, double close + // may assert. As there are no in-process buffers to flush here, we are + // relying on the OS to close the descriptor after the process terminates + // when the destructors are not run. + _exit(1); // Exits w/o any normal exit hooks (we were supposed to crash) +} + +// Returns an indented copy of stderr output for a death test. +// This makes distinguishing death test output lines from regular log lines +// much easier. +static ::std::string FormatDeathTestOutput(const ::std::string& output) { + ::std::string ret; + for (size_t at = 0; ; ) { + const size_t line_end = output.find('\n', at); + ret += "[ DEATH ] "; + if (line_end == ::std::string::npos) { + ret += output.substr(at); + break; + } + ret += output.substr(at, line_end + 1 - at); + at = line_end + 1; + } + return ret; +} + +// Assesses the success or failure of a death test, using both private +// members which have previously been set, and one argument: +// +// Private data members: +// outcome: An enumeration describing how the death test +// concluded: DIED, LIVED, THREW, or RETURNED. The death test +// fails in the latter three cases. +// status: The exit status of the child process. On *nix, it is in the +// format specified by wait(2). On Windows, this is the +// value supplied to the ExitProcess() API or a numeric code +// of the exception that terminated the program. +// regex: A regular expression object to be applied to +// the test's captured standard error output; the death test +// fails if it does not match.
+// +// Argument: +// status_ok: true if exit_status is acceptable in the context of +// this particular death test, which fails if it is false +// +// Returns true iff all of the above conditions are met. Otherwise, the +// first failing condition, in the order given above, is the one that is +// reported. Also sets the last death test message string. +bool DeathTestImpl::Passed(bool status_ok) { + if (!spawned()) + return false; + + const String error_message = GetCapturedStderr(); + + bool success = false; + Message buffer; + + buffer << "Death test: " << statement() << "\n"; + switch (outcome()) { + case LIVED: + buffer << " Result: failed to die.\n" + << " Error msg:\n" << FormatDeathTestOutput(error_message); + break; + case THREW: + buffer << " Result: threw an exception.\n" + << " Error msg:\n" << FormatDeathTestOutput(error_message); + break; + case RETURNED: + buffer << " Result: illegal return in test statement.\n" + << " Error msg:\n" << FormatDeathTestOutput(error_message); + break; + case DIED: + if (status_ok) { + const bool matched = RE::PartialMatch(error_message.c_str(), *regex()); + if (matched) { + success = true; + } else { + buffer << " Result: died but not with expected error.\n" + << " Expected: " << regex()->pattern() << "\n" + << "Actual msg:\n" << FormatDeathTestOutput(error_message); + } + } else { + buffer << " Result: died but not with expected exit code:\n" + << " " << ExitSummary(status()) << "\n" + << "Actual msg:\n" << FormatDeathTestOutput(error_message); + } + break; + case IN_PROGRESS: + default: + GTEST_LOG_(FATAL) + << "DeathTest::Passed somehow called before conclusion of test"; + } + + DeathTest::set_last_death_test_message(buffer.GetString()); + return success; +} + +# if GTEST_OS_WINDOWS +// WindowsDeathTest implements death tests on Windows. 
Due to the +// specifics of starting new processes on Windows, death tests there are +// always threadsafe, and Google Test considers the +// --gtest_death_test_style=fast setting to be equivalent to +// --gtest_death_test_style=threadsafe there. +// +// A few implementation notes: Like the Linux version, the Windows +// implementation uses pipes for child-to-parent communication. But due to +// the specifics of pipes on Windows, some extra steps are required: +// +// 1. The parent creates a communication pipe and stores handles to both +// ends of it. +// 2. The parent starts the child and provides it with the information +// necessary to acquire the handle to the write end of the pipe. +// 3. The child acquires the write end of the pipe and signals the parent +// using a Windows event. +// 4. Now the parent can release the write end of the pipe on its side. If +// this is done before step 3, the object's reference count goes down to +// 0 and it is destroyed, preventing the child from acquiring it. The +// parent now has to release it, or read operations on the read end of +// the pipe will not return when the child terminates. +// 5. The parent reads the child's output (outcome code and any error +// messages) from the pipe, along with its stderr, and then determines +// whether to fail the test. +// +// Note: to distinguish Win32 API calls from the local method and function +// calls, the former are explicitly resolved in the global namespace. +// +class WindowsDeathTest : public DeathTestImpl { + public: + WindowsDeathTest(const char* a_statement, + const RE* a_regex, + const char* file, + int line) + : DeathTestImpl(a_statement, a_regex), file_(file), line_(line) {} + + // All of these virtual functions are inherited from DeathTest. + virtual int Wait(); + virtual TestRole AssumeRole(); + + private: + // The name of the file in which the death test is located. + const char* const file_; + // The line number on which the death test is located.
+ const int line_; + // Handle to the write end of the pipe to the child process. + AutoHandle write_handle_; + // Child process handle. + AutoHandle child_handle_; + // Event the child process uses to signal the parent that it has + // acquired the handle to the write end of the pipe. After seeing this + // event the parent can release its own handles to make sure its + // ReadFile() calls return when the child terminates. + AutoHandle event_handle_; +}; + +// Waits for the child in a death test to exit, returning its exit +// status, or 0 if no child process exists. As a side effect, sets the +// outcome data member. +int WindowsDeathTest::Wait() { + if (!spawned()) + return 0; + + // Wait until the child either signals that it has acquired the write end + // of the pipe or it dies. + const HANDLE wait_handles[2] = { child_handle_.Get(), event_handle_.Get() }; + switch (::WaitForMultipleObjects(2, + wait_handles, + FALSE, // Waits for any of the handles. + INFINITE)) { + case WAIT_OBJECT_0: + case WAIT_OBJECT_0 + 1: + break; + default: + GTEST_DEATH_TEST_CHECK_(false); // Should not get here. + } + + // The child has acquired the write end of the pipe or exited. + // We release the handle on our side and continue. + write_handle_.Reset(); + event_handle_.Reset(); + + ReadAndInterpretStatusByte(); + + // Waits for the child process to exit if it hasn't already. This + // returns immediately if the child has already exited, regardless of + // whether previous calls to WaitForMultipleObjects synchronized on this + // handle or not. + GTEST_DEATH_TEST_CHECK_( + WAIT_OBJECT_0 == ::WaitForSingleObject(child_handle_.Get(), + INFINITE)); + DWORD status_code; + GTEST_DEATH_TEST_CHECK_( + ::GetExitCodeProcess(child_handle_.Get(), &status_code) != FALSE); + child_handle_.Reset(); + set_status(static_cast<int>(status_code)); + return status(); +} + +// The AssumeRole process for a Windows death test.
It creates a child +// process with the same executable as the current process to run the +// death test. The child process is given the --gtest_filter and +// --gtest_internal_run_death_test flags such that it knows to run the +// current death test only. +DeathTest::TestRole WindowsDeathTest::AssumeRole() { + const UnitTestImpl* const impl = GetUnitTestImpl(); + const InternalRunDeathTestFlag* const flag = + impl->internal_run_death_test_flag(); + const TestInfo* const info = impl->current_test_info(); + const int death_test_index = info->result()->death_test_count(); + + if (flag != NULL) { + // ParseInternalRunDeathTestFlag() has performed all the necessary + // processing. + set_write_fd(flag->write_fd()); + return EXECUTE_TEST; + } + + // WindowsDeathTest uses an anonymous pipe to communicate results of + // a death test. + SECURITY_ATTRIBUTES handles_are_inheritable = { + sizeof(SECURITY_ATTRIBUTES), NULL, TRUE }; + HANDLE read_handle, write_handle; + GTEST_DEATH_TEST_CHECK_( + ::CreatePipe(&read_handle, &write_handle, &handles_are_inheritable, + 0) // Default buffer size. + != FALSE); + set_read_fd(::_open_osfhandle(reinterpret_cast<intptr_t>(read_handle), + O_RDONLY)); + write_handle_.Reset(write_handle); + event_handle_.Reset(::CreateEvent( + &handles_are_inheritable, + TRUE, // Manual-reset event: stays signaled until explicitly reset. + FALSE, // The initial state is non-signaled. + NULL)); // The event is unnamed. + GTEST_DEATH_TEST_CHECK_(event_handle_.Get() != NULL); + const String filter_flag = String::Format("--%s%s=%s.%s", + GTEST_FLAG_PREFIX_, kFilterFlag, + info->test_case_name(), + info->name()); + const String internal_flag = String::Format( + "--%s%s=%s|%d|%d|%u|%Iu|%Iu", + GTEST_FLAG_PREFIX_, + kInternalRunDeathTestFlag, + file_, line_, + death_test_index, + static_cast<unsigned int>(::GetCurrentProcessId()), + // size_t has the same width as pointers on both 32-bit and 64-bit + // Windows platforms. + // See http://msdn.microsoft.com/en-us/library/tcxf1dw6.aspx.
+ reinterpret_cast<size_t>(write_handle), + reinterpret_cast<size_t>(event_handle_.Get())); + + char executable_path[_MAX_PATH + 1]; // NOLINT + GTEST_DEATH_TEST_CHECK_( + _MAX_PATH + 1 != ::GetModuleFileNameA(NULL, + executable_path, + _MAX_PATH)); + + String command_line = String::Format("%s %s \"%s\"", + ::GetCommandLineA(), + filter_flag.c_str(), + internal_flag.c_str()); + + DeathTest::set_last_death_test_message(""); + + CaptureStderr(); + // Flush the log buffers since the log streams are shared with the child. + FlushInfoLog(); + + // The child process will share the standard handles with the parent. + STARTUPINFOA startup_info; + memset(&startup_info, 0, sizeof(STARTUPINFO)); + startup_info.dwFlags = STARTF_USESTDHANDLES; + startup_info.hStdInput = ::GetStdHandle(STD_INPUT_HANDLE); + startup_info.hStdOutput = ::GetStdHandle(STD_OUTPUT_HANDLE); + startup_info.hStdError = ::GetStdHandle(STD_ERROR_HANDLE); + + PROCESS_INFORMATION process_info; + GTEST_DEATH_TEST_CHECK_(::CreateProcessA( + executable_path, + const_cast<char*>(command_line.c_str()), + NULL, // Returned process handle is not inheritable. + NULL, // Returned thread handle is not inheritable. + TRUE, // Child inherits all inheritable handles (for write_handle_). + 0x0, // Default creation flags. + NULL, // Inherit the parent's environment. + UnitTest::GetInstance()->original_working_dir(), + &startup_info, + &process_info) != FALSE); + child_handle_.Reset(process_info.hProcess); + ::CloseHandle(process_info.hThread); + set_spawned(true); + return OVERSEE_TEST; +} +# else // We are not on Windows. + +// ForkingDeathTest provides implementations for most of the abstract +// methods of the DeathTest interface. Only the AssumeRole method is +// left undefined. +class ForkingDeathTest : public DeathTestImpl { + public: + ForkingDeathTest(const char* statement, const RE* regex); + + // All of these virtual functions are inherited from DeathTest.
+ virtual int Wait(); + + protected: + void set_child_pid(pid_t child_pid) { child_pid_ = child_pid; } + + private: + // PID of child process during death test; 0 in the child process itself. + pid_t child_pid_; +}; + +// Constructs a ForkingDeathTest. +ForkingDeathTest::ForkingDeathTest(const char* a_statement, const RE* a_regex) + : DeathTestImpl(a_statement, a_regex), + child_pid_(-1) {} + +// Waits for the child in a death test to exit, returning its exit +// status, or 0 if no child process exists. As a side effect, sets the +// outcome data member. +int ForkingDeathTest::Wait() { + if (!spawned()) + return 0; + + ReadAndInterpretStatusByte(); + + int status_value; + GTEST_DEATH_TEST_CHECK_SYSCALL_(waitpid(child_pid_, &status_value, 0)); + set_status(status_value); + return status_value; +} + +// A concrete death test class that forks, then immediately runs the test +// in the child process. +class NoExecDeathTest : public ForkingDeathTest { + public: + NoExecDeathTest(const char* a_statement, const RE* a_regex) : + ForkingDeathTest(a_statement, a_regex) { } + virtual TestRole AssumeRole(); +}; + +// The AssumeRole process for a fork-and-run death test. It implements a +// straightforward fork, with a simple pipe to transmit the status byte. +DeathTest::TestRole NoExecDeathTest::AssumeRole() { + const size_t thread_count = GetThreadCount(); + if (thread_count != 1) { + GTEST_LOG_(WARNING) << DeathTestThreadWarning(thread_count); + } + + int pipe_fd[2]; + GTEST_DEATH_TEST_CHECK_(pipe(pipe_fd) != -1); + + DeathTest::set_last_death_test_message(""); + CaptureStderr(); + // When we fork the process below, the log file buffers are copied, but the + // file descriptors are shared. We flush all log files here so that closing + // the file descriptors in the child process doesn't throw off the + // synchronization between descriptors and buffers in the parent process. 
+ // This is as close to the fork as possible to avoid a race condition in case + // there are multiple threads running before the death test, and another + // thread writes to the log file. + FlushInfoLog(); + + const pid_t child_pid = fork(); + GTEST_DEATH_TEST_CHECK_(child_pid != -1); + set_child_pid(child_pid); + if (child_pid == 0) { + GTEST_DEATH_TEST_CHECK_SYSCALL_(close(pipe_fd[0])); + set_write_fd(pipe_fd[1]); + // Redirects all logging to stderr in the child process to prevent + // concurrent writes to the log files. We capture stderr in the parent + // process and append the child process' output to a log. + LogToStderr(); + // Event forwarding to the listeners of event listener API must be shut + // down in death test subprocesses. + GetUnitTestImpl()->listeners()->SuppressEventForwarding(); + return EXECUTE_TEST; + } else { + GTEST_DEATH_TEST_CHECK_SYSCALL_(close(pipe_fd[1])); + set_read_fd(pipe_fd[0]); + set_spawned(true); + return OVERSEE_TEST; + } +} + +// A concrete death test class that forks and re-executes the main +// program from the beginning, with command-line flags set that cause +// only this specific death test to be run. +class ExecDeathTest : public ForkingDeathTest { + public: + ExecDeathTest(const char* a_statement, const RE* a_regex, + const char* file, int line) : + ForkingDeathTest(a_statement, a_regex), file_(file), line_(line) { } + virtual TestRole AssumeRole(); + private: + // The name of the file in which the death test is located. + const char* const file_; + // The line number on which the death test is located. + const int line_; +}; + +// Utility class for accumulating command-line arguments.
+class Arguments { + public: + Arguments() { + args_.push_back(NULL); + } + + ~Arguments() { + for (std::vector<char*>::iterator i = args_.begin(); i != args_.end(); + ++i) { + free(*i); + } + } + void AddArgument(const char* argument) { + args_.insert(args_.end() - 1, posix::StrDup(argument)); + } + + template <typename Str> + void AddArguments(const ::std::vector<Str>& arguments) { + for (typename ::std::vector<Str>::const_iterator i = arguments.begin(); + i != arguments.end(); + ++i) { + args_.insert(args_.end() - 1, posix::StrDup(i->c_str())); + } + } + char* const* Argv() { + return &args_[0]; + } + private: + std::vector<char*> args_; +}; + +// A struct that encompasses the arguments to the child process of a +// threadsafe-style death test process. +struct ExecDeathTestArgs { + char* const* argv; // Command-line arguments for the child's call to exec + int close_fd; // File descriptor to close; the read end of a pipe +}; + +# if GTEST_OS_MAC +inline char** GetEnviron() { + // When Google Test is built as a framework on MacOS X, the environ variable + // is unavailable. Apple's documentation (man environ) recommends using + // _NSGetEnviron() instead. + return *_NSGetEnviron(); +} +# else +// Some POSIX platforms expect you to declare environ. extern "C" makes +// it reside in the global namespace. +extern "C" char** environ; +inline char** GetEnviron() { return environ; } +# endif // GTEST_OS_MAC + +// The main function for a threadsafe-style death test child process. +// This function is called in a clone()-ed process and thus must avoid +// any potentially unsafe operations like malloc or libc functions. +static int ExecDeathTestChildMain(void* child_arg) { + ExecDeathTestArgs* const args = static_cast<ExecDeathTestArgs*>(child_arg); + GTEST_DEATH_TEST_CHECK_SYSCALL_(close(args->close_fd)); + + // We need to execute the test program in the same environment where + // it was originally invoked. Therefore we change to the original + // working directory first.
+ const char* const original_dir = + UnitTest::GetInstance()->original_working_dir(); + // We can safely call chdir() as it's a direct system call. + if (chdir(original_dir) != 0) { + DeathTestAbort(String::Format("chdir(\"%s\") failed: %s", + original_dir, + GetLastErrnoDescription().c_str())); + return EXIT_FAILURE; + } + + // We can safely call execve() as it's a direct system call. We + // cannot use execvp() as it's a libc function and thus potentially + // unsafe. Since execve() doesn't search the PATH, the user must + // invoke the test program via a valid path that contains at least + // one path separator. + execve(args->argv[0], args->argv, GetEnviron()); + DeathTestAbort(String::Format("execve(%s, ...) in %s failed: %s", + args->argv[0], + original_dir, + GetLastErrnoDescription().c_str())); + return EXIT_FAILURE; +} + +// Two utility routines that together determine the direction the stack +// grows. +// This could be accomplished more elegantly by a single recursive +// function, but we want to guard against the unlikely possibility of +// a smart compiler optimizing the recursion away. +// +// GTEST_NO_INLINE_ is required to prevent GCC 4.6 from inlining +// StackLowerThanAddress into StackGrowsDown, which then doesn't give +// correct answer. +bool StackLowerThanAddress(const void* ptr) GTEST_NO_INLINE_; +bool StackLowerThanAddress(const void* ptr) { + int dummy; + return &dummy < ptr; +} + +bool StackGrowsDown() { + int dummy; + return StackLowerThanAddress(&dummy); +} + +// A threadsafe implementation of fork(2) for threadsafe-style death tests +// that uses clone(2). It dies with an error message if anything goes +// wrong. 
+static pid_t ExecDeathTestFork(char* const* argv, int close_fd) { + ExecDeathTestArgs args = { argv, close_fd }; + pid_t child_pid = -1; + +# if GTEST_HAS_CLONE + const bool use_fork = GTEST_FLAG(death_test_use_fork); + + if (!use_fork) { + static const bool stack_grows_down = StackGrowsDown(); + const size_t stack_size = getpagesize(); + // MAP_ANONYMOUS is not defined on Mac, so we use MAP_ANON instead. + void* const stack = mmap(NULL, stack_size, PROT_READ | PROT_WRITE, + MAP_ANON | MAP_PRIVATE, -1, 0); + GTEST_DEATH_TEST_CHECK_(stack != MAP_FAILED); + void* const stack_top = + static_cast<char*>(stack) + (stack_grows_down ? stack_size : 0); + + child_pid = clone(&ExecDeathTestChildMain, stack_top, SIGCHLD, &args); + + GTEST_DEATH_TEST_CHECK_(munmap(stack, stack_size) != -1); + } +# else + const bool use_fork = true; +# endif // GTEST_HAS_CLONE + + if (use_fork && (child_pid = fork()) == 0) { + ExecDeathTestChildMain(&args); + _exit(0); + } + + GTEST_DEATH_TEST_CHECK_(child_pid != -1); + return child_pid; +} + +// The AssumeRole process for a fork-and-exec death test. It re-executes the +// main program from the beginning, setting the --gtest_filter +// and --gtest_internal_run_death_test flags to cause only the current +// death test to be re-run.
+DeathTest::TestRole ExecDeathTest::AssumeRole() { + const UnitTestImpl* const impl = GetUnitTestImpl(); + const InternalRunDeathTestFlag* const flag = + impl->internal_run_death_test_flag(); + const TestInfo* const info = impl->current_test_info(); + const int death_test_index = info->result()->death_test_count(); + + if (flag != NULL) { + set_write_fd(flag->write_fd()); + return EXECUTE_TEST; + } + + int pipe_fd[2]; + GTEST_DEATH_TEST_CHECK_(pipe(pipe_fd) != -1); + // Clear the close-on-exec flag on the write end of the pipe, lest + // it be closed when the child process does an exec: + GTEST_DEATH_TEST_CHECK_(fcntl(pipe_fd[1], F_SETFD, 0) != -1); + + const String filter_flag = + String::Format("--%s%s=%s.%s", + GTEST_FLAG_PREFIX_, kFilterFlag, + info->test_case_name(), info->name()); + const String internal_flag = + String::Format("--%s%s=%s|%d|%d|%d", + GTEST_FLAG_PREFIX_, kInternalRunDeathTestFlag, + file_, line_, death_test_index, pipe_fd[1]); + Arguments args; + args.AddArguments(GetArgvs()); + args.AddArgument(filter_flag.c_str()); + args.AddArgument(internal_flag.c_str()); + + DeathTest::set_last_death_test_message(""); + + CaptureStderr(); + // See the comment in NoExecDeathTest::AssumeRole for why the next line + // is necessary. + FlushInfoLog(); + + const pid_t child_pid = ExecDeathTestFork(args.Argv(), pipe_fd[0]); + GTEST_DEATH_TEST_CHECK_SYSCALL_(close(pipe_fd[1])); + set_child_pid(child_pid); + set_read_fd(pipe_fd[0]); + set_spawned(true); + return OVERSEE_TEST; +} + +# endif // !GTEST_OS_WINDOWS + +// Creates a concrete DeathTest-derived class that depends on the +// --gtest_death_test_style flag, and sets the pointer pointed to +// by the "test" argument to its address. If the test should be +// skipped, sets that pointer to NULL. Returns true, unless the +// flag is set to an invalid value. 
+bool DefaultDeathTestFactory::Create(const char* statement, const RE* regex, + const char* file, int line, + DeathTest** test) { + UnitTestImpl* const impl = GetUnitTestImpl(); + const InternalRunDeathTestFlag* const flag = + impl->internal_run_death_test_flag(); + const int death_test_index = impl->current_test_info() + ->increment_death_test_count(); + + if (flag != NULL) { + if (death_test_index > flag->index()) { + DeathTest::set_last_death_test_message(String::Format( + "Death test count (%d) somehow exceeded expected maximum (%d)", + death_test_index, flag->index())); + return false; + } + + if (!(flag->file() == file && flag->line() == line && + flag->index() == death_test_index)) { + *test = NULL; + return true; + } + } + +# if GTEST_OS_WINDOWS + + if (GTEST_FLAG(death_test_style) == "threadsafe" || + GTEST_FLAG(death_test_style) == "fast") { + *test = new WindowsDeathTest(statement, regex, file, line); + } + +# else + + if (GTEST_FLAG(death_test_style) == "threadsafe") { + *test = new ExecDeathTest(statement, regex, file, line); + } else if (GTEST_FLAG(death_test_style) == "fast") { + *test = new NoExecDeathTest(statement, regex); + } + +# endif // GTEST_OS_WINDOWS + + else { // NOLINT - this is more readable than unbalanced brackets inside #if. + DeathTest::set_last_death_test_message(String::Format( + "Unknown death test style \"%s\" encountered", + GTEST_FLAG(death_test_style).c_str())); + return false; + } + + return true; +} + +// Splits a given string on a given delimiter, populating a given +// vector with the fields. GTEST_HAS_DEATH_TEST implies that we have +// ::std::string, so we can use it here. 
+static void SplitString(const ::std::string& str, char delimiter, + ::std::vector< ::std::string>* dest) { + ::std::vector< ::std::string> parsed; + ::std::string::size_type pos = 0; + while (::testing::internal::AlwaysTrue()) { + const ::std::string::size_type colon = str.find(delimiter, pos); + if (colon == ::std::string::npos) { + parsed.push_back(str.substr(pos)); + break; + } else { + parsed.push_back(str.substr(pos, colon - pos)); + pos = colon + 1; + } + } + dest->swap(parsed); +} + +# if GTEST_OS_WINDOWS +// Recreates the pipe and event handles from the provided parameters, +// signals the event, and returns a file descriptor wrapped around the pipe +// handle. This function is called in the child process only. +int GetStatusFileDescriptor(unsigned int parent_process_id, + size_t write_handle_as_size_t, + size_t event_handle_as_size_t) { + AutoHandle parent_process_handle(::OpenProcess(PROCESS_DUP_HANDLE, + FALSE, // Non-inheritable. + parent_process_id)); + if (parent_process_handle.Get() == INVALID_HANDLE_VALUE) { + DeathTestAbort(String::Format("Unable to open parent process %u", + parent_process_id)); + } + + // TODO(vladl@google.com): Replace the following check with a + // compile-time assertion when available. + GTEST_CHECK_(sizeof(HANDLE) <= sizeof(size_t)); + + const HANDLE write_handle = + reinterpret_cast<HANDLE>(write_handle_as_size_t); + HANDLE dup_write_handle; + + // The newly initialized handle is accessible only in the parent + // process. To obtain one accessible within the child, we need to use + // DuplicateHandle. + if (!::DuplicateHandle(parent_process_handle.Get(), write_handle, + ::GetCurrentProcess(), &dup_write_handle, + 0x0, // Requested privileges ignored since + // DUPLICATE_SAME_ACCESS is used. + FALSE, // Request non-inheritable handle.
+ DUPLICATE_SAME_ACCESS)) { + DeathTestAbort(String::Format( + "Unable to duplicate the pipe handle %Iu from the parent process %u", + write_handle_as_size_t, parent_process_id)); + } + + const HANDLE event_handle = reinterpret_cast<HANDLE>(event_handle_as_size_t); + HANDLE dup_event_handle; + + if (!::DuplicateHandle(parent_process_handle.Get(), event_handle, + ::GetCurrentProcess(), &dup_event_handle, + 0x0, + FALSE, + DUPLICATE_SAME_ACCESS)) { + DeathTestAbort(String::Format( + "Unable to duplicate the event handle %Iu from the parent process %u", + event_handle_as_size_t, parent_process_id)); + } + + const int write_fd = + ::_open_osfhandle(reinterpret_cast<intptr_t>(dup_write_handle), O_APPEND); + if (write_fd == -1) { + DeathTestAbort(String::Format( + "Unable to convert pipe handle %Iu to a file descriptor", + write_handle_as_size_t)); + } + + // Signals the parent that the write end of the pipe has been acquired + // so the parent can release its own write end. + ::SetEvent(dup_event_handle); + + return write_fd; +} +# endif // GTEST_OS_WINDOWS + +// Returns a newly created InternalRunDeathTestFlag object with fields +// initialized from the GTEST_FLAG(internal_run_death_test) flag if +// the flag is specified; otherwise returns NULL. +InternalRunDeathTestFlag* ParseInternalRunDeathTestFlag() { + if (GTEST_FLAG(internal_run_death_test) == "") return NULL; + + // GTEST_HAS_DEATH_TEST implies that we have ::std::string, so we + // can use it here.
+ int line = -1; + int index = -1; + ::std::vector< ::std::string> fields; + SplitString(GTEST_FLAG(internal_run_death_test).c_str(), '|', &fields); + int write_fd = -1; + +# if GTEST_OS_WINDOWS + + unsigned int parent_process_id = 0; + size_t write_handle_as_size_t = 0; + size_t event_handle_as_size_t = 0; + + if (fields.size() != 6 + || !ParseNaturalNumber(fields[1], &line) + || !ParseNaturalNumber(fields[2], &index) + || !ParseNaturalNumber(fields[3], &parent_process_id) + || !ParseNaturalNumber(fields[4], &write_handle_as_size_t) + || !ParseNaturalNumber(fields[5], &event_handle_as_size_t)) { + DeathTestAbort(String::Format( + "Bad --gtest_internal_run_death_test flag: %s", + GTEST_FLAG(internal_run_death_test).c_str())); + } + write_fd = GetStatusFileDescriptor(parent_process_id, + write_handle_as_size_t, + event_handle_as_size_t); +# else + + if (fields.size() != 4 + || !ParseNaturalNumber(fields[1], &line) + || !ParseNaturalNumber(fields[2], &index) + || !ParseNaturalNumber(fields[3], &write_fd)) { + DeathTestAbort(String::Format( + "Bad --gtest_internal_run_death_test flag: %s", + GTEST_FLAG(internal_run_death_test).c_str())); + } + +# endif // GTEST_OS_WINDOWS + + return new InternalRunDeathTestFlag(fields[0], line, index, write_fd); +} + +} // namespace internal + +#endif // GTEST_HAS_DEATH_TEST + +} // namespace testing +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. 
nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Authors: keith.ray@gmail.com (Keith Ray) + + +#include + +#if GTEST_OS_WINDOWS_MOBILE +# include +#elif GTEST_OS_WINDOWS +# include +# include +#elif GTEST_OS_SYMBIAN || GTEST_OS_NACL +// Symbian OpenC and NaCl have PATH_MAX in sys/syslimits.h +# include +#else +# include +# include // Some Linux distributions define PATH_MAX here. +#endif // GTEST_OS_WINDOWS_MOBILE + +#if GTEST_OS_WINDOWS +# define GTEST_PATH_MAX_ _MAX_PATH +#elif defined(PATH_MAX) +# define GTEST_PATH_MAX_ PATH_MAX +#elif defined(_XOPEN_PATH_MAX) +# define GTEST_PATH_MAX_ _XOPEN_PATH_MAX +#else +# define GTEST_PATH_MAX_ _POSIX_PATH_MAX +#endif // GTEST_OS_WINDOWS + + +namespace testing { +namespace internal { + +#if GTEST_OS_WINDOWS +// On Windows, '\\' is the standard path separator, but many tools and the +// Windows API also accept '/' as an alternate path separator. Unless otherwise +// noted, a file path can contain either kind of path separators, or a mixture +// of them. 
+const char kPathSeparator = '\\'; +const char kAlternatePathSeparator = '/'; +const char kPathSeparatorString[] = "\\"; +const char kAlternatePathSeparatorString[] = "/"; +# if GTEST_OS_WINDOWS_MOBILE +// Windows CE doesn't have a current directory. You should not use +// the current directory in tests on Windows CE, but this at least +// provides a reasonable fallback. +const char kCurrentDirectoryString[] = "\\"; +// Windows CE doesn't define INVALID_FILE_ATTRIBUTES +const DWORD kInvalidFileAttributes = 0xffffffff; +# else +const char kCurrentDirectoryString[] = ".\\"; +# endif // GTEST_OS_WINDOWS_MOBILE +#else +const char kPathSeparator = '/'; +const char kPathSeparatorString[] = "/"; +const char kCurrentDirectoryString[] = "./"; +#endif // GTEST_OS_WINDOWS + +// Returns whether the given character is a valid path separator. +static bool IsPathSeparator(char c) { +#if GTEST_HAS_ALT_PATH_SEP_ + return (c == kPathSeparator) || (c == kAlternatePathSeparator); +#else + return c == kPathSeparator; +#endif +} + +// Returns the current working directory, or "" if unsuccessful. +FilePath FilePath::GetCurrentDir() { +#if GTEST_OS_WINDOWS_MOBILE + // Windows CE doesn't have a current directory, so we just return + // something reasonable. + return FilePath(kCurrentDirectoryString); +#elif GTEST_OS_WINDOWS + char cwd[GTEST_PATH_MAX_ + 1] = { '\0' }; + return FilePath(_getcwd(cwd, sizeof(cwd)) == NULL ? "" : cwd); +#else + char cwd[GTEST_PATH_MAX_ + 1] = { '\0' }; + return FilePath(getcwd(cwd, sizeof(cwd)) == NULL ? "" : cwd); +#endif // GTEST_OS_WINDOWS_MOBILE +} + +// Returns a copy of the FilePath with the case-insensitive extension removed. +// Example: FilePath("dir/file.exe").RemoveExtension("EXE") returns +// FilePath("dir/file"). If a case-insensitive extension is not +// found, returns a copy of the original FilePath. 
+FilePath FilePath::RemoveExtension(const char* extension) const { + String dot_extension(String::Format(".%s", extension)); + if (pathname_.EndsWithCaseInsensitive(dot_extension.c_str())) { + // Strip the whole ".extension" suffix, whatever its length. + return FilePath(String(pathname_.c_str(), + pathname_.length() - dot_extension.length())); + } + return *this; +} + +// Returns a pointer to the last occurrence of a valid path separator in +// the FilePath. On Windows, for example, both '/' and '\' are valid path +// separators. Returns NULL if no path separator was found. +const char* FilePath::FindLastPathSeparator() const { + const char* const last_sep = strrchr(c_str(), kPathSeparator); +#if GTEST_HAS_ALT_PATH_SEP_ + const char* const last_alt_sep = strrchr(c_str(), kAlternatePathSeparator); + // Comparing two pointers of which only one is NULL is undefined. + if (last_alt_sep != NULL && + (last_sep == NULL || last_alt_sep > last_sep)) { + return last_alt_sep; + } +#endif + return last_sep; +} + +// Returns a copy of the FilePath with the directory part removed. +// Example: FilePath("path/to/file").RemoveDirectoryName() returns +// FilePath("file"). If there is no directory part ("just_a_file"), it returns +// the FilePath unmodified. If there is no file part ("just_a_dir/") it +// returns an empty FilePath (""). +// On Windows platform, '\' is the path separator, otherwise it is '/'. +FilePath FilePath::RemoveDirectoryName() const { + const char* const last_sep = FindLastPathSeparator(); + return last_sep ? FilePath(String(last_sep + 1)) : *this; +} + +// RemoveFileName returns the directory path with the filename removed. +// Example: FilePath("path/to/file").RemoveFileName() returns "path/to/". +// If the FilePath is "a_file" or "/a_file", RemoveFileName returns +// FilePath("./") or, on Windows, FilePath(".\\"). If the filepath does +// not have a file, like "just/a/dir/", it returns the FilePath unmodified. +// On Windows platform, '\' is the path separator, otherwise it is '/'. 
+FilePath FilePath::RemoveFileName() const { + const char* const last_sep = FindLastPathSeparator(); + String dir; + if (last_sep) { + dir = String(c_str(), last_sep + 1 - c_str()); + } else { + dir = kCurrentDirectoryString; + } + return FilePath(dir); +} + +// Helper functions for naming files in a directory for xml output. + +// Given directory = "dir", base_name = "test", number = 0, +// extension = "xml", returns "dir/test.xml". If number is greater +// than zero (e.g., 12), returns "dir/test_12.xml". +// On Windows platform, uses \ as the separator rather than /. +FilePath FilePath::MakeFileName(const FilePath& directory, + const FilePath& base_name, + int number, + const char* extension) { + String file; + if (number == 0) { + file = String::Format("%s.%s", base_name.c_str(), extension); + } else { + file = String::Format("%s_%d.%s", base_name.c_str(), number, extension); + } + return ConcatPaths(directory, FilePath(file)); +} + +// Given directory = "dir", relative_path = "test.xml", returns "dir/test.xml". +// On Windows, uses \ as the separator rather than /. +FilePath FilePath::ConcatPaths(const FilePath& directory, + const FilePath& relative_path) { + if (directory.IsEmpty()) + return relative_path; + const FilePath dir(directory.RemoveTrailingPathSeparator()); + return FilePath(String::Format("%s%c%s", dir.c_str(), kPathSeparator, + relative_path.c_str())); +} + +// Returns true if pathname describes something findable in the file-system, +// either a file, directory, or whatever. 
+bool FilePath::FileOrDirectoryExists() const { +#if GTEST_OS_WINDOWS_MOBILE + LPCWSTR unicode = String::AnsiToUtf16(pathname_.c_str()); + const DWORD attributes = GetFileAttributes(unicode); + delete [] unicode; + return attributes != kInvalidFileAttributes; +#else + posix::StatStruct file_stat; + return posix::Stat(pathname_.c_str(), &file_stat) == 0; +#endif // GTEST_OS_WINDOWS_MOBILE +} + +// Returns true if pathname describes a directory in the file-system +// that exists. +bool FilePath::DirectoryExists() const { + bool result = false; +#if GTEST_OS_WINDOWS + // Don't strip off trailing separator if path is a root directory on + // Windows (like "C:\\"). + const FilePath& path(IsRootDirectory() ? *this : + RemoveTrailingPathSeparator()); +#else + const FilePath& path(*this); +#endif + +#if GTEST_OS_WINDOWS_MOBILE + LPCWSTR unicode = String::AnsiToUtf16(path.c_str()); + const DWORD attributes = GetFileAttributes(unicode); + delete [] unicode; + if ((attributes != kInvalidFileAttributes) && + (attributes & FILE_ATTRIBUTE_DIRECTORY)) { + result = true; + } +#else + posix::StatStruct file_stat; + result = posix::Stat(path.c_str(), &file_stat) == 0 && + posix::IsDir(file_stat); +#endif // GTEST_OS_WINDOWS_MOBILE + + return result; +} + +// Returns true if pathname describes a root directory. (Windows has one +// root directory per disk drive.) +bool FilePath::IsRootDirectory() const { +#if GTEST_OS_WINDOWS + // TODO(wan@google.com): on Windows a network share like + // \\server\share can be a root directory, although it cannot be the + // current directory. Handle this properly. + return pathname_.length() == 3 && IsAbsolutePath(); +#else + return pathname_.length() == 1 && IsPathSeparator(pathname_.c_str()[0]); +#endif +} + +// Returns true if pathname describes an absolute path. 
+bool FilePath::IsAbsolutePath() const { + const char* const name = pathname_.c_str(); +#if GTEST_OS_WINDOWS + return pathname_.length() >= 3 && + ((name[0] >= 'a' && name[0] <= 'z') || + (name[0] >= 'A' && name[0] <= 'Z')) && + name[1] == ':' && + IsPathSeparator(name[2]); +#else + return IsPathSeparator(name[0]); +#endif +} + +// Returns a pathname for a file that does not currently exist. The pathname +// will be directory/base_name.extension or +// directory/base_name_<number>.extension if directory/base_name.extension +// already exists. The number will be incremented until a pathname is found +// that does not already exist. +// Examples: 'dir/foo_test.xml' or 'dir/foo_test_1.xml'. +// There could be a race condition if two or more processes are calling this +// function at the same time -- they could both pick the same filename. +FilePath FilePath::GenerateUniqueFileName(const FilePath& directory, + const FilePath& base_name, + const char* extension) { + FilePath full_pathname; + int number = 0; + do { + full_pathname.Set(MakeFileName(directory, base_name, number++, extension)); + } while (full_pathname.FileOrDirectoryExists()); + return full_pathname; +} + +// Returns true if FilePath ends with a path separator, which indicates that +// it is intended to represent a directory. Returns false otherwise. +// This does NOT check that a directory (or file) actually exists. +bool FilePath::IsDirectory() const { + return !pathname_.empty() && + IsPathSeparator(pathname_.c_str()[pathname_.length() - 1]); +} + +// Create directories so that path exists. Returns true if successful or if +// the directories already exist; returns false if unable to create directories +// for any reason. 
+bool FilePath::CreateDirectoriesRecursively() const { + if (!this->IsDirectory()) { + return false; + } + + if (pathname_.length() == 0 || this->DirectoryExists()) { + return true; + } + + const FilePath parent(this->RemoveTrailingPathSeparator().RemoveFileName()); + return parent.CreateDirectoriesRecursively() && this->CreateFolder(); +} + +// Create the directory so that path exists. Returns true if successful or +// if the directory already exists; returns false if unable to create the +// directory for any reason, including if the parent directory does not +// exist. Not named "CreateDirectory" because that's a macro on Windows. +bool FilePath::CreateFolder() const { +#if GTEST_OS_WINDOWS_MOBILE + FilePath removed_sep(this->RemoveTrailingPathSeparator()); + LPCWSTR unicode = String::AnsiToUtf16(removed_sep.c_str()); + int result = CreateDirectory(unicode, NULL) ? 0 : -1; + delete [] unicode; +#elif GTEST_OS_WINDOWS + int result = _mkdir(pathname_.c_str()); +#else + int result = mkdir(pathname_.c_str(), 0777); +#endif // GTEST_OS_WINDOWS_MOBILE + + if (result == -1) { + return this->DirectoryExists(); // An error is OK if the directory exists. + } + return true; // No error. +} + +// If input name has a trailing separator character, remove it and return the +// name, otherwise return the name string unmodified. +// On Windows platform, uses \ as the separator, other platforms use /. +FilePath FilePath::RemoveTrailingPathSeparator() const { + return IsDirectory() + ? FilePath(String(pathname_.c_str(), pathname_.length() - 1)) + : *this; +} + +// Removes any redundant separators that might be in the pathname. +// For example, "bar///foo" becomes "bar/foo". Does not eliminate other +// redundancies that might be in a pathname involving "." or "..". +// TODO(wan@google.com): handle Windows network shares (e.g. \\server\share). 
+void FilePath::Normalize() { + if (pathname_.c_str() == NULL) { + pathname_ = ""; + return; + } + const char* src = pathname_.c_str(); + char* const dest = new char[pathname_.length() + 1]; + char* dest_ptr = dest; + memset(dest_ptr, 0, pathname_.length() + 1); + + while (*src != '\0') { + *dest_ptr = *src; + if (!IsPathSeparator(*src)) { + src++; + } else { +#if GTEST_HAS_ALT_PATH_SEP_ + if (*dest_ptr == kAlternatePathSeparator) { + *dest_ptr = kPathSeparator; + } +#endif + while (IsPathSeparator(*src)) + src++; + } + dest_ptr++; + } + *dest_ptr = '\0'; + pathname_ = dest; + delete[] dest; +} + +} // namespace internal +} // namespace testing +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) + + +#include +#include +#include +#include + +#if GTEST_OS_WINDOWS_MOBILE +# include // For TerminateProcess() +#elif GTEST_OS_WINDOWS +# include +# include +#else +# include +#endif // GTEST_OS_WINDOWS_MOBILE + +#if GTEST_OS_MAC +# include +# include +# include +#endif // GTEST_OS_MAC + + +// Indicates that this translation unit is part of Google Test's +// implementation. It must come before gtest-internal-inl.h is +// included, or there will be a compiler error. This trick is to +// prevent a user from accidentally including gtest-internal-inl.h in +// his code. +#define GTEST_IMPLEMENTATION_ 1 +#undef GTEST_IMPLEMENTATION_ + +namespace testing { +namespace internal { + +#if defined(_MSC_VER) || defined(__BORLANDC__) +// MSVC and C++Builder do not provide a definition of STDERR_FILENO. +const int kStdOutFileno = 1; +const int kStdErrFileno = 2; +#else +const int kStdOutFileno = STDOUT_FILENO; +const int kStdErrFileno = STDERR_FILENO; +#endif // _MSC_VER + +#if GTEST_OS_MAC + +// Returns the number of threads running in the process, or 0 to indicate that +// we cannot detect it. 
+size_t GetThreadCount() { + const task_t task = mach_task_self(); + mach_msg_type_number_t thread_count; + thread_act_array_t thread_list; + const kern_return_t status = task_threads(task, &thread_list, &thread_count); + if (status == KERN_SUCCESS) { + // task_threads allocates resources in thread_list and we need to free them + // to avoid leaks. + vm_deallocate(task, + reinterpret_cast<vm_address_t>(thread_list), + sizeof(thread_t) * thread_count); + return static_cast<size_t>(thread_count); + } else { + return 0; + } +} + +#else + +size_t GetThreadCount() { + // There's no portable way to detect the number of threads, so we just + // return 0 to indicate that we cannot detect it. + return 0; +} + +#endif // GTEST_OS_MAC + +#if GTEST_USES_POSIX_RE + +// Implements RE. Currently only needed for death tests. + +RE::~RE() { + if (is_valid_) { + // regfree'ing an invalid regex might crash because the content + // of the regex is undefined. Since the regexes are essentially + // the same, one cannot be valid (or invalid) without the other + // being so too. + regfree(&partial_regex_); + regfree(&full_regex_); + } + free(const_cast<char*>(pattern_)); +} + +// Returns true iff regular expression re matches the entire str. +bool RE::FullMatch(const char* str, const RE& re) { + if (!re.is_valid_) return false; + + regmatch_t match; + return regexec(&re.full_regex_, str, 1, &match, 0) == 0; +} + +// Returns true iff regular expression re matches a substring of str +// (including str itself). +bool RE::PartialMatch(const char* str, const RE& re) { + if (!re.is_valid_) return false; + + regmatch_t match; + return regexec(&re.partial_regex_, str, 1, &match, 0) == 0; +} + +// Initializes an RE from its string representation. +void RE::Init(const char* regex) { + pattern_ = posix::StrDup(regex); + + // Reserves enough bytes to hold the regular expression used for a + // full match. 
+ const size_t full_regex_len = strlen(regex) + 10; + char* const full_pattern = new char[full_regex_len]; + + snprintf(full_pattern, full_regex_len, "^(%s)$", regex); + is_valid_ = regcomp(&full_regex_, full_pattern, REG_EXTENDED) == 0; + // We want to call regcomp(&partial_regex_, ...) even if the + // previous expression returns false. Otherwise partial_regex_ may + // not be properly initialized and may cause trouble when it's + // freed. + // + // Some implementations of POSIX regex (e.g. on at least some + // versions of Cygwin) don't accept the empty string as a valid + // regex. We change it to an equivalent form "()" to be safe. + if (is_valid_) { + const char* const partial_regex = (*regex == '\0') ? "()" : regex; + is_valid_ = regcomp(&partial_regex_, partial_regex, REG_EXTENDED) == 0; + } + EXPECT_TRUE(is_valid_) + << "Regular expression \"" << regex + << "\" is not a valid POSIX Extended regular expression."; + + delete[] full_pattern; +} + +#elif GTEST_USES_SIMPLE_RE + +// Returns true iff ch appears anywhere in str (excluding the +// terminating '\0' character). +bool IsInSet(char ch, const char* str) { + return ch != '\0' && strchr(str, ch) != NULL; +} + +// Returns true iff ch belongs to the given classification. Unlike +// similar functions in <ctype.h>, these aren't affected by the +// current locale. +bool IsAsciiDigit(char ch) { return '0' <= ch && ch <= '9'; } +bool IsAsciiPunct(char ch) { + return IsInSet(ch, "^-!\"#$%&'()*+,./:;<=>?@[\\]_`{|}~"); +} +bool IsRepeat(char ch) { return IsInSet(ch, "?*+"); } +bool IsAsciiWhiteSpace(char ch) { return IsInSet(ch, " \f\n\r\t\v"); } +bool IsAsciiWordChar(char ch) { + return ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') || + ('0' <= ch && ch <= '9') || ch == '_'; +} + +// Returns true iff "\\c" is a supported escape sequence. +bool IsValidEscape(char c) { + return (IsAsciiPunct(c) || IsInSet(c, "dDfnrsStvwW")); +} + +// Returns true iff the given atom (specified by escaped and pattern) +// matches ch. 
The result is undefined if the atom is invalid. +bool AtomMatchesChar(bool escaped, char pattern_char, char ch) { + if (escaped) { // "\\p" where p is pattern_char. + switch (pattern_char) { + case 'd': return IsAsciiDigit(ch); + case 'D': return !IsAsciiDigit(ch); + case 'f': return ch == '\f'; + case 'n': return ch == '\n'; + case 'r': return ch == '\r'; + case 's': return IsAsciiWhiteSpace(ch); + case 'S': return !IsAsciiWhiteSpace(ch); + case 't': return ch == '\t'; + case 'v': return ch == '\v'; + case 'w': return IsAsciiWordChar(ch); + case 'W': return !IsAsciiWordChar(ch); + } + return IsAsciiPunct(pattern_char) && pattern_char == ch; + } + + return (pattern_char == '.' && ch != '\n') || pattern_char == ch; +} + +// Helper function used by ValidateRegex() to format error messages. +String FormatRegexSyntaxError(const char* regex, int index) { + return (Message() << "Syntax error at index " << index + << " in simple regular expression \"" << regex << "\": ").GetString(); +} + +// Generates non-fatal failures and returns false if regex is invalid; +// otherwise returns true. +bool ValidateRegex(const char* regex) { + if (regex == NULL) { + // TODO(wan@google.com): fix the source file location in the + // assertion failures to match where the regex is used in user + // code. + ADD_FAILURE() << "NULL is not a valid simple regular expression."; + return false; + } + + bool is_valid = true; + + // True iff ?, *, or + can follow the previous atom. + bool prev_repeatable = false; + for (int i = 0; regex[i]; i++) { + if (regex[i] == '\\') { // An escape sequence + i++; + if (regex[i] == '\0') { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i - 1) + << "'\\' cannot appear at the end."; + return false; + } + + if (!IsValidEscape(regex[i])) { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i - 1) + << "invalid escape sequence \"\\" << regex[i] << "\"."; + is_valid = false; + } + prev_repeatable = true; + } else { // Not an escape sequence. 
+ const char ch = regex[i]; + + if (ch == '^' && i > 0) { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i) + << "'^' can only appear at the beginning."; + is_valid = false; + } else if (ch == '$' && regex[i + 1] != '\0') { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i) + << "'$' can only appear at the end."; + is_valid = false; + } else if (IsInSet(ch, "()[]{}|")) { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i) + << "'" << ch << "' is unsupported."; + is_valid = false; + } else if (IsRepeat(ch) && !prev_repeatable) { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i) + << "'" << ch << "' can only follow a repeatable token."; + is_valid = false; + } + + prev_repeatable = !IsInSet(ch, "^$?*+"); + } + } + + return is_valid; +} + +// Matches a repeated regex atom followed by a valid simple regular +// expression. The regex atom is defined as c if escaped is false, +// or \c otherwise. repeat is the repetition meta character (?, *, +// or +). The behavior is undefined if str contains too many +// characters to be indexable by size_t, in which case the test will +// probably time out anyway. We are fine with this limitation as +// std::string has it too. +bool MatchRepetitionAndRegexAtHead( + bool escaped, char c, char repeat, const char* regex, + const char* str) { + const size_t min_count = (repeat == '+') ? 1 : 0; + const size_t max_count = (repeat == '?') ? 1 : + static_cast<size_t>(-1) - 1; + // We cannot call numeric_limits<size_t>::max() as it conflicts with the + // max() macro on Windows. + + for (size_t i = 0; i <= max_count; ++i) { + // We know that the atom matches each of the first i characters in str. + if (i >= min_count && MatchRegexAtHead(regex, str + i)) { + // We have enough matches at the head, and the tail matches too. + // Since we only care about *whether* the pattern matches str + // (as opposed to *how* it matches), there is no need to find a + // greedy match. 
+ return true; + } + if (str[i] == '\0' || !AtomMatchesChar(escaped, c, str[i])) + return false; + } + return false; +} + +// Returns true iff regex matches a prefix of str. regex must be a +// valid simple regular expression and not start with "^", or the +// result is undefined. +bool MatchRegexAtHead(const char* regex, const char* str) { + if (*regex == '\0') // An empty regex matches a prefix of anything. + return true; + + // "$" only matches the end of a string. Note that regex being + // valid guarantees that there's nothing after "$" in it. + if (*regex == '$') + return *str == '\0'; + + // Is the first thing in regex an escape sequence? + const bool escaped = *regex == '\\'; + if (escaped) + ++regex; + if (IsRepeat(regex[1])) { + // MatchRepetitionAndRegexAtHead() calls MatchRegexAtHead(), so + // here's an indirect recursion. It terminates as the regex gets + // shorter in each recursion. + return MatchRepetitionAndRegexAtHead( + escaped, regex[0], regex[1], regex + 2, str); + } else { + // regex isn't empty, isn't "$", and doesn't start with a + // repetition. We match the first atom of regex with the first + // character of str and recurse. + return (*str != '\0') && AtomMatchesChar(escaped, *regex, *str) && + MatchRegexAtHead(regex + 1, str + 1); + } +} + +// Returns true iff regex matches any substring of str. regex must be +// a valid simple regular expression, or the result is undefined. +// +// The algorithm is recursive, but the recursion depth doesn't exceed +// the regex length, so we won't need to worry about running out of +// stack space normally. In rare cases the time complexity can be +// exponential with respect to the regex length + the string length, +// but usually it's much faster (often close to linear). +bool MatchRegexAnywhere(const char* regex, const char* str) { + if (regex == NULL || str == NULL) + return false; + + if (*regex == '^') + return MatchRegexAtHead(regex + 1, str); + + // A successful match can be anywhere in str. 
+ do { + if (MatchRegexAtHead(regex, str)) + return true; + } while (*str++ != '\0'); + return false; +} + +// Implements the RE class. + +RE::~RE() { + free(const_cast<char*>(pattern_)); + free(const_cast<char*>(full_pattern_)); +} + +// Returns true iff regular expression re matches the entire str. +bool RE::FullMatch(const char* str, const RE& re) { + return re.is_valid_ && MatchRegexAnywhere(re.full_pattern_, str); +} + +// Returns true iff regular expression re matches a substring of str +// (including str itself). +bool RE::PartialMatch(const char* str, const RE& re) { + return re.is_valid_ && MatchRegexAnywhere(re.pattern_, str); +} + +// Initializes an RE from its string representation. +void RE::Init(const char* regex) { + pattern_ = full_pattern_ = NULL; + if (regex != NULL) { + pattern_ = posix::StrDup(regex); + } + + is_valid_ = ValidateRegex(regex); + if (!is_valid_) { + // No need to calculate the full pattern when the regex is invalid. + return; + } + + const size_t len = strlen(regex); + // Reserves enough bytes to hold the regular expression used for a + // full match: we need space to prepend a '^', append a '$', and + // terminate the string with '\0'. + char* buffer = static_cast<char*>(malloc(len + 3)); + full_pattern_ = buffer; + + if (*regex != '^') + *buffer++ = '^'; // Makes sure full_pattern_ starts with '^'. + + // We don't use snprintf or strncpy, as they trigger a warning when + // compiled with VC++ 8.0. + memcpy(buffer, regex, len); + buffer += len; + + if (len == 0 || regex[len - 1] != '$') + *buffer++ = '$'; // Makes sure full_pattern_ ends with '$'. + + *buffer = '\0'; +} + +#endif // GTEST_USES_POSIX_RE + +const char kUnknownFile[] = "unknown file"; + +// Formats a source file path and a line number as they would appear +// in an error message from the compiler used to compile this code. +GTEST_API_ ::std::string FormatFileLocation(const char* file, int line) { + const char* const file_name = file == NULL ? 
kUnknownFile : file; + + if (line < 0) { + return String::Format("%s:", file_name).c_str(); + } +#ifdef _MSC_VER + return String::Format("%s(%d):", file_name, line).c_str(); +#else + return String::Format("%s:%d:", file_name, line).c_str(); +#endif // _MSC_VER +} + +// Formats a file location for compiler-independent XML output. +// Although this function is not platform dependent, we put it next to +// FormatFileLocation in order to contrast the two functions. +// Note that FormatCompilerIndependentFileLocation() does NOT append colon +// to the file location it produces, unlike FormatFileLocation(). +GTEST_API_ ::std::string FormatCompilerIndependentFileLocation( + const char* file, int line) { + const char* const file_name = file == NULL ? kUnknownFile : file; + + if (line < 0) + return file_name; + else + return String::Format("%s:%d", file_name, line).c_str(); +} + + +GTestLog::GTestLog(GTestLogSeverity severity, const char* file, int line) + : severity_(severity) { + const char* const marker = + severity == GTEST_INFO ? "[ INFO ]" : + severity == GTEST_WARNING ? "[WARNING]" : + severity == GTEST_ERROR ? "[ ERROR ]" : "[ FATAL ]"; + GetStream() << ::std::endl << marker << " " + << FormatFileLocation(file, line).c_str() << ": "; +} + +// Flushes the buffers and, if severity is GTEST_FATAL, aborts the program. +GTestLog::~GTestLog() { + GetStream() << ::std::endl; + if (severity_ == GTEST_FATAL) { + fflush(stderr); + posix::Abort(); + } +} +// Disable Microsoft deprecation warnings for POSIX functions called from +// this class (creat, dup, dup2, and close) +#ifdef _MSC_VER +# pragma warning(push) +# pragma warning(disable: 4996) +#endif // _MSC_VER + +#if GTEST_HAS_STREAM_REDIRECTION + +// Object that captures an output stream (stdout/stderr). +class CapturedStream { + public: + // The ctor redirects the stream to a temporary file. 
+ CapturedStream(int fd) : fd_(fd), uncaptured_fd_(dup(fd)) { + +# if GTEST_OS_WINDOWS + char temp_dir_path[MAX_PATH + 1] = { '\0' }; // NOLINT + char temp_file_path[MAX_PATH + 1] = { '\0' }; // NOLINT + + ::GetTempPathA(sizeof(temp_dir_path), temp_dir_path); + const UINT success = ::GetTempFileNameA(temp_dir_path, + "gtest_redir", + 0, // Generate unique file name. + temp_file_path); + GTEST_CHECK_(success != 0) + << "Unable to create a temporary file in " << temp_dir_path; + const int captured_fd = creat(temp_file_path, _S_IREAD | _S_IWRITE); + GTEST_CHECK_(captured_fd != -1) << "Unable to open temporary file " + << temp_file_path; + filename_ = temp_file_path; +# else + // There's no guarantee that a test has write access to the + // current directory, so we create the temporary file in the /tmp + // directory instead. + char name_template[] = "/tmp/captured_stream.XXXXXX"; + const int captured_fd = mkstemp(name_template); + filename_ = name_template; +# endif // GTEST_OS_WINDOWS + fflush(NULL); + dup2(captured_fd, fd_); + close(captured_fd); + } + + ~CapturedStream() { + remove(filename_.c_str()); + } + + String GetCapturedString() { + if (uncaptured_fd_ != -1) { + // Restores the original stream. + fflush(NULL); + dup2(uncaptured_fd_, fd_); + close(uncaptured_fd_); + uncaptured_fd_ = -1; + } + + FILE* const file = posix::FOpen(filename_.c_str(), "r"); + const String content = ReadEntireFile(file); + posix::FClose(file); + return content; + } + + private: + // Reads the entire content of a file as a String. + static String ReadEntireFile(FILE* file); + + // Returns the size (in bytes) of a file. + static size_t GetFileSize(FILE* file); + + const int fd_; // A stream to capture. + int uncaptured_fd_; + // Name of the temporary file holding the stderr output. + ::std::string filename_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(CapturedStream); +}; + +// Returns the size (in bytes) of a file. 
+size_t CapturedStream::GetFileSize(FILE* file) { + fseek(file, 0, SEEK_END); + return static_cast<size_t>(ftell(file)); +} + +// Reads the entire content of a file as a string. +String CapturedStream::ReadEntireFile(FILE* file) { + const size_t file_size = GetFileSize(file); + char* const buffer = new char[file_size]; + + size_t bytes_last_read = 0; // # of bytes read in the last fread() + size_t bytes_read = 0; // # of bytes read so far + + fseek(file, 0, SEEK_SET); + + // Keeps reading the file until we cannot read further or the + // pre-determined file size is reached. + do { + bytes_last_read = fread(buffer+bytes_read, 1, file_size-bytes_read, file); + bytes_read += bytes_last_read; + } while (bytes_last_read > 0 && bytes_read < file_size); + + const String content(buffer, bytes_read); + delete[] buffer; + + return content; +} + +# ifdef _MSC_VER +# pragma warning(pop) +# endif // _MSC_VER + +static CapturedStream* g_captured_stderr = NULL; +static CapturedStream* g_captured_stdout = NULL; + +// Starts capturing an output stream (stdout/stderr). +void CaptureStream(int fd, const char* stream_name, CapturedStream** stream) { + if (*stream != NULL) { + GTEST_LOG_(FATAL) << "Only one " << stream_name + << " capturer can exist at a time."; + } + *stream = new CapturedStream(fd); +} + +// Stops capturing the output stream and returns the captured string. +String GetCapturedStream(CapturedStream** captured_stream) { + const String content = (*captured_stream)->GetCapturedString(); + + delete *captured_stream; + *captured_stream = NULL; + + return content; +} + +// Starts capturing stdout. +void CaptureStdout() { + CaptureStream(kStdOutFileno, "stdout", &g_captured_stdout); +} + +// Starts capturing stderr. +void CaptureStderr() { + CaptureStream(kStdErrFileno, "stderr", &g_captured_stderr); +} + +// Stops capturing stdout and returns the captured string. 
+String GetCapturedStdout() { return GetCapturedStream(&g_captured_stdout); } + +// Stops capturing stderr and returns the captured string. +String GetCapturedStderr() { return GetCapturedStream(&g_captured_stderr); } + +#endif // GTEST_HAS_STREAM_REDIRECTION + +#if GTEST_HAS_DEATH_TEST + +// A copy of all command line arguments. Set by InitGoogleTest(). +::std::vector<String> g_argvs; + +// Returns the command line as a vector of strings. +const ::std::vector<String>& GetArgvs() { return g_argvs; } + +#endif // GTEST_HAS_DEATH_TEST + +#if GTEST_OS_WINDOWS_MOBILE +namespace posix { +void Abort() { + DebugBreak(); + TerminateProcess(GetCurrentProcess(), 1); +} +} // namespace posix +#endif // GTEST_OS_WINDOWS_MOBILE + +// Returns the name of the environment variable corresponding to the +// given flag. For example, FlagToEnvVar("foo") will return +// "GTEST_FOO" in the open-source version. +static String FlagToEnvVar(const char* flag) { + const String full_flag = + (Message() << GTEST_FLAG_PREFIX_ << flag).GetString(); + + Message env_var; + for (size_t i = 0; i != full_flag.length(); i++) { + env_var << ToUpper(full_flag.c_str()[i]); + } + + return env_var.GetString(); +} + +// Parses 'str' for a 32-bit signed integer. If successful, writes +// the result to *value and returns true; otherwise leaves *value +// unchanged and returns false. +bool ParseInt32(const Message& src_text, const char* str, Int32* value) { + // Parses the environment variable as a decimal integer. + char* end = NULL; + const long long_value = strtol(str, &end, 10); // NOLINT + + // Has strtol() consumed all characters in the string? + if (*end != '\0') { + // No - an invalid character was encountered. + Message msg; + msg << "WARNING: " << src_text + << " is expected to be a 32-bit integer, but actually" + << " has value \"" << str << "\".\n"; + printf("%s", msg.GetString().c_str()); + fflush(stdout); + return false; + } + + // Is the parsed value in the range of an Int32?
+ const Int32 result = static_cast<Int32>(long_value); + if (long_value == LONG_MAX || long_value == LONG_MIN || + // The parsed value overflows as a long. (strtol() returns + // LONG_MAX or LONG_MIN when the input overflows.) + result != long_value + // The parsed value overflows as an Int32. + ) { + Message msg; + msg << "WARNING: " << src_text + << " is expected to be a 32-bit integer, but actually" + << " has value " << str << ", which overflows.\n"; + printf("%s", msg.GetString().c_str()); + fflush(stdout); + return false; + } + + *value = result; + return true; +} + +// Reads and returns the Boolean environment variable corresponding to +// the given flag; if it's not set, returns default_value. +// +// The value is considered true iff it's not "0". +bool BoolFromGTestEnv(const char* flag, bool default_value) { + const String env_var = FlagToEnvVar(flag); + const char* const string_value = posix::GetEnv(env_var.c_str()); + return string_value == NULL ? + default_value : strcmp(string_value, "0") != 0; +} + +// Reads and returns a 32-bit integer stored in the environment +// variable corresponding to the given flag; if it isn't set or +// doesn't represent a valid 32-bit integer, returns default_value. +Int32 Int32FromGTestEnv(const char* flag, Int32 default_value) { + const String env_var = FlagToEnvVar(flag); + const char* const string_value = posix::GetEnv(env_var.c_str()); + if (string_value == NULL) { + // The environment variable is not set. + return default_value; + } + + Int32 result = default_value; + if (!ParseInt32(Message() << "Environment variable " << env_var, + string_value, &result)) { + printf("The default value %s is used.\n", + (Message() << default_value).GetString().c_str()); + fflush(stdout); + return default_value; + } + + return result; +} + +// Reads and returns the string environment variable corresponding to +// the given flag; if it's not set, returns default_value.
+const char* StringFromGTestEnv(const char* flag, const char* default_value) { + const String env_var = FlagToEnvVar(flag); + const char* const value = posix::GetEnv(env_var.c_str()); + return value == NULL ? default_value : value; +} + +} // namespace internal +} // namespace testing +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+// +// Author: wan@google.com (Zhanyong Wan) + +// Google Test - The Google C++ Testing Framework +// +// This file implements a universal value printer that can print a +// value of any type T: +// +// void ::testing::internal::UniversalPrinter<T>::Print(value, ostream_ptr); +// +// It uses the << operator when possible, and prints the bytes in the +// object otherwise. A user can override its behavior for a class +// type Foo by defining either operator<<(::std::ostream&, const Foo&) +// or void PrintTo(const Foo&, ::std::ostream*) in the namespace that +// defines Foo. + +#include <ctype.h> +#include <stdio.h> +#include <ostream> // NOLINT +#include <string> + +namespace testing { + +namespace { + +using ::std::ostream; + +#if GTEST_OS_WINDOWS_MOBILE // Windows CE does not define _snprintf_s. +# define snprintf _snprintf +#elif _MSC_VER >= 1400 // VC 8.0 and later deprecate snprintf and _snprintf. +# define snprintf _snprintf_s +#elif _MSC_VER +# define snprintf _snprintf +#endif // GTEST_OS_WINDOWS_MOBILE + +// Prints a segment of bytes in the given object. +void PrintByteSegmentInObjectTo(const unsigned char* obj_bytes, size_t start, + size_t count, ostream* os) { + char text[5] = ""; + for (size_t i = 0; i != count; i++) { + const size_t j = start + i; + if (i != 0) { + // Organizes the bytes into groups of 2 for easy parsing by + // human. + if ((j % 2) == 0) + *os << ' '; + else + *os << '-'; + } + snprintf(text, sizeof(text), "%02X", obj_bytes[j]); + *os << text; + } +} + +// Prints the bytes in the given value to the given ostream. +void PrintBytesInObjectToImpl(const unsigned char* obj_bytes, size_t count, + ostream* os) { + // Tells the user how big the object is. + *os << count << "-byte object <"; + + const size_t kThreshold = 132; + const size_t kChunkSize = 64; + // If the object size is bigger than kThreshold, we'll have to omit + // some details by printing only the first and the last kChunkSize + // bytes. + // TODO(wan): let the user control the threshold using a flag.
+ if (count < kThreshold) { + PrintByteSegmentInObjectTo(obj_bytes, 0, count, os); + } else { + PrintByteSegmentInObjectTo(obj_bytes, 0, kChunkSize, os); + *os << " ... "; + // Rounds up to 2-byte boundary. + const size_t resume_pos = (count - kChunkSize + 1)/2*2; + PrintByteSegmentInObjectTo(obj_bytes, resume_pos, count - resume_pos, os); + } + *os << ">"; +} + +} // namespace + +namespace internal2 { + +// Delegates to PrintBytesInObjectToImpl() to print the bytes in the +// given object. The delegation simplifies the implementation, which +// uses the << operator and thus is easier done outside of the +// ::testing::internal namespace, which contains a << operator that +// sometimes conflicts with the one in STL. +void PrintBytesInObjectTo(const unsigned char* obj_bytes, size_t count, + ostream* os) { + PrintBytesInObjectToImpl(obj_bytes, count, os); +} + +} // namespace internal2 + +namespace internal { + +// Depending on the value of a char (or wchar_t), we print it in one +// of three formats: +// - as is if it's a printable ASCII (e.g. 'a', '2', ' '), +// - as a hexadecimal escape sequence (e.g. '\x7F'), or +// - as a special escape sequence (e.g. '\r', '\n'). +enum CharFormat { + kAsIs, + kHexEscape, + kSpecialEscape +}; + +// Returns true if c is a printable ASCII character. We test the +// value of c directly instead of calling isprint(), which is buggy on +// Windows Mobile. +inline bool IsPrintableAscii(wchar_t c) { + return 0x20 <= c && c <= 0x7E; +} + +// Prints a wide or narrow char c as a character literal without the +// quotes, escaping it when necessary; returns how c was formatted. +// The template argument UnsignedChar is the unsigned version of Char, +// which is the type of c.
+template <typename UnsignedChar, typename Char> +static CharFormat PrintAsCharLiteralTo(Char c, ostream* os) { + switch (static_cast<wchar_t>(c)) { + case L'\0': + *os << "\\0"; + break; + case L'\'': + *os << "\\'"; + break; + case L'\\': + *os << "\\\\"; + break; + case L'\a': + *os << "\\a"; + break; + case L'\b': + *os << "\\b"; + break; + case L'\f': + *os << "\\f"; + break; + case L'\n': + *os << "\\n"; + break; + case L'\r': + *os << "\\r"; + break; + case L'\t': + *os << "\\t"; + break; + case L'\v': + *os << "\\v"; + break; + default: + if (IsPrintableAscii(c)) { + *os << static_cast<char>(c); + return kAsIs; + } else { + *os << String::Format("\\x%X", static_cast<UnsignedChar>(c)); + return kHexEscape; + } + } + return kSpecialEscape; +} + +// Prints a char c as if it's part of a string literal, escaping it when +// necessary; returns how c was formatted. +static CharFormat PrintAsWideStringLiteralTo(wchar_t c, ostream* os) { + switch (c) { + case L'\'': + *os << "'"; + return kAsIs; + case L'"': + *os << "\\\""; + return kSpecialEscape; + default: + return PrintAsCharLiteralTo<wchar_t>(c, os); + } +} + +// Prints a char c as if it's part of a string literal, escaping it when +// necessary; returns how c was formatted. +static CharFormat PrintAsNarrowStringLiteralTo(char c, ostream* os) { + return PrintAsWideStringLiteralTo(static_cast<unsigned char>(c), os); +} + +// Prints a wide or narrow character c and its code. '\0' is printed +// as "'\\0'", other unprintable characters are also properly escaped +// using the standard C++ escape sequence. The template argument +// UnsignedChar is the unsigned version of Char, which is the type of c. +template <typename UnsignedChar, typename Char> +void PrintCharAndCodeTo(Char c, ostream* os) { + // First, print c as a literal in the most readable form we can find. + *os << ((sizeof(c) > 1) ? "L'" : "'"); + const CharFormat format = PrintAsCharLiteralTo<UnsignedChar>(c, os); + *os << "'"; + + // To aid user debugging, we also print c's code in decimal, unless + // it's 0 (in which case c was printed as '\\0', making the code + // obvious).
+ if (c == 0) + return; + *os << " (" << String::Format("%d", c).c_str(); + + // For more convenience, we print c's code again in hexadecimal, + // unless c was already printed in the form '\x##' or the code is in + // [1, 9]. + if (format == kHexEscape || (1 <= c && c <= 9)) { + // Do nothing. + } else { + *os << String::Format(", 0x%X", + static_cast<UnsignedChar>(c)).c_str(); + } + *os << ")"; +} + +void PrintTo(unsigned char c, ::std::ostream* os) { + PrintCharAndCodeTo<unsigned char>(c, os); +} +void PrintTo(signed char c, ::std::ostream* os) { + PrintCharAndCodeTo<unsigned char>(c, os); +} + +// Prints a wchar_t as a symbol if it is printable or as its internal +// code otherwise and also as its code. L'\0' is printed as "L'\\0'". +void PrintTo(wchar_t wc, ostream* os) { + PrintCharAndCodeTo<wchar_t>(wc, os); +} + +// Prints the given array of characters to the ostream. +// The array starts at *begin, the length is len, it may include '\0' characters +// and may not be null-terminated. +static void PrintCharsAsStringTo(const char* begin, size_t len, ostream* os) { + *os << "\""; + bool is_previous_hex = false; + for (size_t index = 0; index < len; ++index) { + const char cur = begin[index]; + if (is_previous_hex && IsXDigit(cur)) { + // Previous character is of '\x..' form and this character can be + // interpreted as another hexadecimal digit in its number. Break string to + // disambiguate. + *os << "\" \""; + } + is_previous_hex = PrintAsNarrowStringLiteralTo(cur, os) == kHexEscape; + } + *os << "\""; +} + +// Prints a (const) char array of 'len' elements, starting at address 'begin'. +void UniversalPrintArray(const char* begin, size_t len, ostream* os) { + PrintCharsAsStringTo(begin, len, os); +} + +// Prints the given array of wide characters to the ostream. +// The array starts at *begin, the length is len, it may include L'\0' +// characters and may not be null-terminated.
+static void PrintWideCharsAsStringTo(const wchar_t* begin, size_t len, + ostream* os) { + *os << "L\""; + bool is_previous_hex = false; + for (size_t index = 0; index < len; ++index) { + const wchar_t cur = begin[index]; + if (is_previous_hex && isascii(cur) && IsXDigit(static_cast<char>(cur))) { + // Previous character is of '\x..' form and this character can be + // interpreted as another hexadecimal digit in its number. Break string to + // disambiguate. + *os << "\" L\""; + } + is_previous_hex = PrintAsWideStringLiteralTo(cur, os) == kHexEscape; + } + *os << "\""; +} + +// Prints the given C string to the ostream. +void PrintTo(const char* s, ostream* os) { + if (s == NULL) { + *os << "NULL"; + } else { + *os << ImplicitCast_<const void*>(s) << " pointing to "; + PrintCharsAsStringTo(s, strlen(s), os); + } +} + +// MSVC compiler can be configured to define wchar_t as a typedef +// of unsigned short. Defining an overload for const wchar_t* in that case +// would cause pointers to unsigned shorts be printed as wide strings, +// possibly accessing more memory than intended and causing invalid +// memory accesses. MSVC defines _NATIVE_WCHAR_T_DEFINED symbol when +// wchar_t is implemented as a native type. +#if !defined(_MSC_VER) || defined(_NATIVE_WCHAR_T_DEFINED) +// Prints the given wide C string to the ostream. +void PrintTo(const wchar_t* s, ostream* os) { + if (s == NULL) { + *os << "NULL"; + } else { + *os << ImplicitCast_<const void*>(s) << " pointing to "; + PrintWideCharsAsStringTo(s, wcslen(s), os); + } +} +#endif // wchar_t is native + +// Prints a ::string object. +#if GTEST_HAS_GLOBAL_STRING +void PrintStringTo(const ::string& s, ostream* os) { + PrintCharsAsStringTo(s.data(), s.size(), os); +} +#endif // GTEST_HAS_GLOBAL_STRING + +void PrintStringTo(const ::std::string& s, ostream* os) { + PrintCharsAsStringTo(s.data(), s.size(), os); +} + +// Prints a ::wstring object.
+#if GTEST_HAS_GLOBAL_WSTRING +void PrintWideStringTo(const ::wstring& s, ostream* os) { + PrintWideCharsAsStringTo(s.data(), s.size(), os); +} +#endif // GTEST_HAS_GLOBAL_WSTRING + +#if GTEST_HAS_STD_WSTRING +void PrintWideStringTo(const ::std::wstring& s, ostream* os) { + PrintWideCharsAsStringTo(s.data(), s.size(), os); +} +#endif // GTEST_HAS_STD_WSTRING + +} // namespace internal + +} // namespace testing +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+// +// Author: mheule@google.com (Markus Heule) +// +// The Google C++ Testing Framework (Google Test) + + +// Indicates that this translation unit is part of Google Test's +// implementation. It must come before gtest-internal-inl.h is +// included, or there will be a compiler error. This trick is to +// prevent a user from accidentally including gtest-internal-inl.h in +// his code. +#define GTEST_IMPLEMENTATION_ 1 +#undef GTEST_IMPLEMENTATION_ + +namespace testing { + +using internal::GetUnitTestImpl; + +// Gets the summary of the failure message by omitting the stack trace +// in it. +internal::String TestPartResult::ExtractSummary(const char* message) { + const char* const stack_trace = strstr(message, internal::kStackTraceMarker); + return stack_trace == NULL ? internal::String(message) : + internal::String(message, stack_trace - message); +} + +// Prints a TestPartResult object. +std::ostream& operator<<(std::ostream& os, const TestPartResult& result) { + return os + << result.file_name() << ":" << result.line_number() << ": " + << (result.type() == TestPartResult::kSuccess ? "Success" : + result.type() == TestPartResult::kFatalFailure ? "Fatal failure" : + "Non-fatal failure") << ":\n" + << result.message() << std::endl; +} + +// Appends a TestPartResult to the array. +void TestPartResultArray::Append(const TestPartResult& result) { + array_.push_back(result); +} + +// Returns the TestPartResult at the given index (0-based). +const TestPartResult& TestPartResultArray::GetTestPartResult(int index) const { + if (index < 0 || index >= size()) { + printf("\nInvalid index (%d) into TestPartResultArray.\n", index); + internal::posix::Abort(); + } + + return array_[index]; +} + +// Returns the number of TestPartResult objects in the array. 
+int TestPartResultArray::size() const { + return static_cast<int>(array_.size()); +} + +namespace internal { + +HasNewFatalFailureHelper::HasNewFatalFailureHelper() + : has_new_fatal_failure_(false), + original_reporter_(GetUnitTestImpl()-> + GetTestPartResultReporterForCurrentThread()) { + GetUnitTestImpl()->SetTestPartResultReporterForCurrentThread(this); +} + +HasNewFatalFailureHelper::~HasNewFatalFailureHelper() { + GetUnitTestImpl()->SetTestPartResultReporterForCurrentThread( + original_reporter_); +} + +void HasNewFatalFailureHelper::ReportTestPartResult( + const TestPartResult& result) { + if (result.fatally_failed()) + has_new_fatal_failure_ = true; + original_reporter_->ReportTestPartResult(result); +} + +} // namespace internal + +} // namespace testing +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) + + +namespace testing { +namespace internal { + +#if GTEST_HAS_TYPED_TEST_P + +// Skips to the first non-space char in str. Returns an empty string if str +// contains only whitespace characters. +static const char* SkipSpaces(const char* str) { + while (IsSpace(*str)) + str++; + return str; +} + +// Verifies that registered_tests match the test names in +// defined_test_names_; returns registered_tests if successful, or +// aborts the program otherwise. +const char* TypedTestCasePState::VerifyRegisteredTestNames( + const char* file, int line, const char* registered_tests) { + typedef ::std::set<const char*>::const_iterator DefinedTestIter; + registered_ = true; + + // Skip initial whitespace in registered_tests since some + // preprocessors prefix stringized literals with whitespace.
+ registered_tests = SkipSpaces(registered_tests); + + Message errors; + ::std::set<String> tests; + for (const char* names = registered_tests; names != NULL; + names = SkipComma(names)) { + const String name = GetPrefixUntilComma(names); + if (tests.count(name) != 0) { + errors << "Test " << name << " is listed more than once.\n"; + continue; + } + + bool found = false; + for (DefinedTestIter it = defined_test_names_.begin(); + it != defined_test_names_.end(); + ++it) { + if (name == *it) { + found = true; + break; + } + } + + if (found) { + tests.insert(name); + } else { + errors << "No test named " << name + << " can be found in this test case.\n"; + } + } + + for (DefinedTestIter it = defined_test_names_.begin(); + it != defined_test_names_.end(); + ++it) { + if (tests.count(*it) == 0) { + errors << "You forgot to list test " << *it << ".\n"; + } + } + + const String& errors_str = errors.GetString(); + if (errors_str != "") { + fprintf(stderr, "%s %s", FormatFileLocation(file, line).c_str(), + errors_str.c_str()); + fflush(stderr); + posix::Abort(); + } + + return registered_tests; +} + +#endif // GTEST_HAS_TYPED_TEST_P + +} // namespace internal +} // namespace testing diff --git a/caffe-crfrnn/src/gtest/gtest.h b/caffe-crfrnn/src/gtest/gtest.h new file mode 100644 index 00000000..3143bd67 --- /dev/null +++ b/caffe-crfrnn/src/gtest/gtest.h @@ -0,0 +1,19537 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc.
nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) +// +// The Google C++ Testing Framework (Google Test) +// +// This header file defines the public API for Google Test. It should be +// included by any test program that uses Google Test. +// +// IMPORTANT NOTE: Due to limitation of the C++ language, we have to +// leave some internal implementation details in this header file. +// They are clearly marked by comments like this: +// +// // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +// +// Such code is NOT meant to be used by a user directly, and is subject +// to CHANGE WITHOUT NOTICE. Therefore DO NOT DEPEND ON IT in a user +// program! +// +// Acknowledgment: Google Test borrowed the idea of automatic test +// registration from Barthelemy Dagenais' (barthelemy@prologique.com) +// easyUnit framework. + +#ifndef GTEST_INCLUDE_GTEST_GTEST_H_ +#define GTEST_INCLUDE_GTEST_GTEST_H_ + +#include <limits> +#include <vector> + +// Copyright 2005, Google Inc. +// All rights reserved.
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Authors: wan@google.com (Zhanyong Wan), eefacm@gmail.com (Sean Mcafee) +// +// The Google C++ Testing Framework (Google Test) +// +// This header file declares functions and macros used internally by +// Google Test. They are subject to change without notice. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_INTERNAL_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_INTERNAL_H_ + +// Copyright 2005, Google Inc. +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Authors: wan@google.com (Zhanyong Wan) +// +// Low-level types and utilities for porting Google Test to various +// platforms. They are subject to change without notice. DO NOT USE +// THEM IN USER CODE. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_H_ + +// The user can define the following macros in the build script to +// control Google Test's behavior. 
If the user doesn't define a macro +// in this list, Google Test will define it. +// +// GTEST_HAS_CLONE - Define it to 1/0 to indicate that clone(2) +// is/isn't available. +// GTEST_HAS_EXCEPTIONS - Define it to 1/0 to indicate that exceptions +// are enabled. +// GTEST_HAS_GLOBAL_STRING - Define it to 1/0 to indicate that ::string +// is/isn't available (some systems define +// ::string, which is different to std::string). +// GTEST_HAS_GLOBAL_WSTRING - Define it to 1/0 to indicate that ::wstring +// is/isn't available (some systems define +// ::wstring, which is different to std::wstring). +// GTEST_HAS_POSIX_RE - Define it to 1/0 to indicate that POSIX regular +// expressions are/aren't available. +// GTEST_HAS_PTHREAD - Define it to 1/0 to indicate that <pthread.h> +// is/isn't available. +// GTEST_HAS_RTTI - Define it to 1/0 to indicate that RTTI is/isn't +// enabled. +// GTEST_HAS_STD_WSTRING - Define it to 1/0 to indicate that +// std::wstring does/doesn't work (Google Test can +// be used where std::wstring is unavailable). +// GTEST_HAS_TR1_TUPLE - Define it to 1/0 to indicate tr1::tuple +// is/isn't available. +// GTEST_HAS_SEH - Define it to 1/0 to indicate whether the +// compiler supports Microsoft's "Structured +// Exception Handling". +// GTEST_HAS_STREAM_REDIRECTION +// - Define it to 1/0 to indicate whether the +// platform supports I/O stream redirection using +// dup() and dup2(). +// GTEST_USE_OWN_TR1_TUPLE - Define it to 1/0 to indicate whether Google +// Test's own tr1 tuple implementation should be +// used. Unused when the user sets +// GTEST_HAS_TR1_TUPLE to 0. +// GTEST_LINKED_AS_SHARED_LIBRARY +// - Define to 1 when compiling tests that use +// Google Test as a shared library (known as +// DLL on Windows). +// GTEST_CREATE_SHARED_LIBRARY +// - Define to 1 when compiling Google Test itself +// as a shared library.
+ +// This header defines the following utilities: +// +// Macros indicating the current platform (defined to 1 if compiled on +// the given platform; otherwise undefined): +// GTEST_OS_AIX - IBM AIX +// GTEST_OS_CYGWIN - Cygwin +// GTEST_OS_HPUX - HP-UX +// GTEST_OS_LINUX - Linux +// GTEST_OS_LINUX_ANDROID - Google Android +// GTEST_OS_MAC - Mac OS X +// GTEST_OS_NACL - Google Native Client (NaCl) +// GTEST_OS_SOLARIS - Sun Solaris +// GTEST_OS_SYMBIAN - Symbian +// GTEST_OS_WINDOWS - Windows (Desktop, MinGW, or Mobile) +// GTEST_OS_WINDOWS_DESKTOP - Windows Desktop +// GTEST_OS_WINDOWS_MINGW - MinGW +// GTEST_OS_WINDOWS_MOBILE - Windows Mobile +// GTEST_OS_ZOS - z/OS +// +// Among the platforms, Cygwin, Linux, Mac OS X, and Windows have the +// most stable support. Since core members of the Google Test project +// don't have access to other platforms, support for them may be less +// stable. If you notice any problems on your platform, please notify +// googletestframework@googlegroups.com (patches for fixing them are +// even more welcome!). +// +// Note that it is possible that none of the GTEST_OS_* macros are defined. +// +// Macros indicating available Google Test features (defined to 1 if +// the corresponding feature is supported; otherwise undefined): +// GTEST_HAS_COMBINE - the Combine() function (for value-parameterized +// tests) +// GTEST_HAS_DEATH_TEST - death tests +// GTEST_HAS_PARAM_TEST - value-parameterized tests +// GTEST_HAS_TYPED_TEST - typed tests +// GTEST_HAS_TYPED_TEST_P - type-parameterized tests +// GTEST_USES_POSIX_RE - enhanced POSIX regex is used. Do not confuse with +// GTEST_HAS_POSIX_RE (see above) which users can +// define themselves. +// GTEST_USES_SIMPLE_RE - our own simple regex is used; +// the above two are mutually exclusive. +// GTEST_CAN_COMPARE_NULL - accepts untyped NULL in EXPECT_EQ(). +// +// Macros for basic C++ coding: +// GTEST_AMBIGUOUS_ELSE_BLOCKER_ - for disabling a gcc warning.
+//   GTEST_ATTRIBUTE_UNUSED_  - declares that a class' instances or a
+//                              variable don't have to be used.
+//   GTEST_DISALLOW_ASSIGN_   - disables operator=.
+//   GTEST_DISALLOW_COPY_AND_ASSIGN_ - disables copy ctor and operator=.
+//   GTEST_MUST_USE_RESULT_   - declares that a function's result must be used.
+//
+// Synchronization:
+//   Mutex, MutexLock, ThreadLocal, GetThreadCount()
+//                  - synchronization primitives.
+//   GTEST_IS_THREADSAFE - defined to 1 to indicate that the above
+//                         synchronization primitives have real implementations
+//                         and Google Test is thread-safe; or 0 otherwise.
+//
+// Template meta programming:
+//   is_pointer     - as in TR1; needed on Symbian and IBM XL C/C++ only.
+//   IteratorTraits - partial implementation of std::iterator_traits, which
+//                    is not available in libCstd when compiled with Sun C++.
+//
+// Smart pointers:
+//   scoped_ptr     - as in TR2.
+//
+// Regular expressions:
+//   RE             - a simple regular expression class using the POSIX
+//                    Extended Regular Expression syntax on UNIX-like
+//                    platforms, or a reduced regular expression syntax on
+//                    other platforms, including Windows.
+//
+// Logging:
+//   GTEST_LOG_()   - logs messages at the specified severity level.
+//   LogToStderr()  - directs all log messages to stderr.
+//   FlushInfoLog() - flushes informational log messages.
+//
+// Stdout and stderr capturing:
+//   CaptureStdout()     - starts capturing stdout.
+//   GetCapturedStdout() - stops capturing stdout and returns the captured
+//                         string.
+//   CaptureStderr()     - starts capturing stderr.
+//   GetCapturedStderr() - stops capturing stderr and returns the captured
+//                         string.
+//
+// Integer types:
+//   TypeWithSize   - maps an integer to an int type.
+//   Int32, UInt32, Int64, UInt64, TimeInMillis
+//                  - integers of known sizes.
+//   BiggestInt     - the biggest signed integer type.
+//
+// Command-line utilities:
+//   GTEST_FLAG()       - references a flag.
+//   GTEST_DECLARE_*()  - declares a flag.
+//   GTEST_DEFINE_*()   - defines a flag.
+//   GetArgvs()         - returns the command line as a vector of strings.
+//
+// Environment variable utilities:
+//   GetEnv()             - gets the value of an environment variable.
+//   BoolFromGTestEnv()   - parses a bool environment variable.
+//   Int32FromGTestEnv()  - parses an Int32 environment variable.
+//   StringFromGTestEnv() - parses a string environment variable.
+
+#include <ctype.h>   // for isspace, etc
+#include <stddef.h>  // for ptrdiff_t
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+#ifndef _WIN32_WCE
+# include <sys/types.h>
+# include <sys/stat.h>
+#endif  // !_WIN32_WCE
+
+#include <iostream>  // NOLINT
+#include <sstream>  // NOLINT
+#include <string>  // NOLINT
+
+#define GTEST_DEV_EMAIL_ "googletestframework@@googlegroups.com"
+#define GTEST_FLAG_PREFIX_ "gtest_"
+#define GTEST_FLAG_PREFIX_DASH_ "gtest-"
+#define GTEST_FLAG_PREFIX_UPPER_ "GTEST_"
+#define GTEST_NAME_ "Google Test"
+#define GTEST_PROJECT_URL_ "http://code.google.com/p/googletest/"
+
+// Determines the version of gcc that is used to compile this.
+#ifdef __GNUC__
+// 40302 means version 4.3.2.
+# define GTEST_GCC_VER_ \
+    (__GNUC__*10000 + __GNUC_MINOR__*100 + __GNUC_PATCHLEVEL__)
+#endif  // __GNUC__
+
+// Determines the platform on which Google Test is compiled.
+#ifdef __CYGWIN__
+# define GTEST_OS_CYGWIN 1
+#elif defined __SYMBIAN32__
+# define GTEST_OS_SYMBIAN 1
+#elif defined _WIN32
+# define GTEST_OS_WINDOWS 1
+# ifdef _WIN32_WCE
+#  define GTEST_OS_WINDOWS_MOBILE 1
+# elif defined(__MINGW__) || defined(__MINGW32__)
+#  define GTEST_OS_WINDOWS_MINGW 1
+# else
+#  define GTEST_OS_WINDOWS_DESKTOP 1
+# endif  // _WIN32_WCE
+#elif defined __APPLE__
+# define GTEST_OS_MAC 1
+#elif defined __linux__
+# define GTEST_OS_LINUX 1
+# ifdef ANDROID
+#  define GTEST_OS_LINUX_ANDROID 1
+# endif  // ANDROID
+#elif defined __MVS__
+# define GTEST_OS_ZOS 1
+#elif defined(__sun) && defined(__SVR4)
+# define GTEST_OS_SOLARIS 1
+#elif defined(_AIX)
+# define GTEST_OS_AIX 1
+#elif defined(__hpux)
+# define GTEST_OS_HPUX 1
+#elif defined __native_client__
+# define GTEST_OS_NACL 1
+#endif  // __CYGWIN__
+
+// Brings in definitions for functions used in the testing::internal::posix
+// namespace (read, write, close, chdir, isatty, stat). We do not currently
+// use them on Windows Mobile.
+#if !GTEST_OS_WINDOWS
+// This assumes that non-Windows OSes provide unistd.h. For OSes where this
+// is not the case, we need to include headers that provide the functions
+// mentioned above.
+# include <unistd.h>
+# if !GTEST_OS_NACL
+// TODO(vladl@google.com): Remove this condition when Native Client SDK adds
+// strings.h (tracked in
+// http://code.google.com/p/nativeclient/issues/detail?id=1175).
+#  include <strings.h>  // Native Client doesn't provide strings.h.
+# endif
+#elif !GTEST_OS_WINDOWS_MOBILE
+# include <direct.h>
+# include <io.h>
+#endif
+
+// Defines this to true iff Google Test can use POSIX regular expressions.
+#ifndef GTEST_HAS_POSIX_RE
+# define GTEST_HAS_POSIX_RE (!GTEST_OS_WINDOWS)
+#endif
+
+#if GTEST_HAS_POSIX_RE
+
+// On some platforms, <regex.h> needs someone to define size_t, and
+// won't compile otherwise.  We can #include it here as we already
+// included <stdlib.h>, which is guaranteed to define size_t through
+// <stddef.h>.
+# include <regex.h>  // NOLINT
+
+# define GTEST_USES_POSIX_RE 1
+
+#elif GTEST_OS_WINDOWS
+
+// <regex.h> is not available on Windows.  Use our own simple regex
+// implementation instead.
+# define GTEST_USES_SIMPLE_RE 1
+
+#else
+
+// <regex.h> may not be available on this platform.  Use our own
+// simple regex implementation instead.
+# define GTEST_USES_SIMPLE_RE 1
+
+#endif  // GTEST_HAS_POSIX_RE
+
+#ifndef GTEST_HAS_EXCEPTIONS
+// The user didn't tell us whether exceptions are enabled, so we need
+// to figure it out.
+# if defined(_MSC_VER) || defined(__BORLANDC__)
+// MSVC's and C++Builder's implementations of the STL use the _HAS_EXCEPTIONS
+// macro to enable exceptions, so we'll do the same.
+// Assumes that exceptions are enabled by default.
+#  ifndef _HAS_EXCEPTIONS
+#   define _HAS_EXCEPTIONS 1
+#  endif  // _HAS_EXCEPTIONS
+#  define GTEST_HAS_EXCEPTIONS _HAS_EXCEPTIONS
+# elif defined(__GNUC__) && __EXCEPTIONS
+// gcc defines __EXCEPTIONS to 1 iff exceptions are enabled.
+#  define GTEST_HAS_EXCEPTIONS 1
+# elif defined(__SUNPRO_CC)
+// Sun Pro CC supports exceptions.  However, there is no compile-time way of
+// detecting whether they are enabled or not.  Therefore, we assume that
+// they are enabled unless the user tells us otherwise.
+#  define GTEST_HAS_EXCEPTIONS 1
+# elif defined(__IBMCPP__) && __EXCEPTIONS
+// xlC defines __EXCEPTIONS to 1 iff exceptions are enabled.
+#  define GTEST_HAS_EXCEPTIONS 1
+# elif defined(__HP_aCC)
+// Exception handling is in effect by default in HP aCC compiler. It has to
+// be turned off by +noeh compiler option if desired.
+#  define GTEST_HAS_EXCEPTIONS 1
+# else
+// For other compilers, we assume exceptions are disabled to be
+// conservative.
+#  define GTEST_HAS_EXCEPTIONS 0
+# endif  // defined(_MSC_VER) || defined(__BORLANDC__)
+#endif  // GTEST_HAS_EXCEPTIONS
+
+#if !defined(GTEST_HAS_STD_STRING)
+// Even though we don't use this macro any longer, we keep it in case
+// some clients still depend on it.
+# define GTEST_HAS_STD_STRING 1
+#elif !GTEST_HAS_STD_STRING
+// The user told us that ::std::string isn't available.
+# error "Google Test cannot be used where ::std::string isn't available."
+#endif  // !defined(GTEST_HAS_STD_STRING)
+
+#ifndef GTEST_HAS_GLOBAL_STRING
+// The user didn't tell us whether ::string is available, so we need
+// to figure it out.
+
+# define GTEST_HAS_GLOBAL_STRING 0
+
+#endif  // GTEST_HAS_GLOBAL_STRING
+
+#ifndef GTEST_HAS_STD_WSTRING
+// The user didn't tell us whether ::std::wstring is available, so we need
+// to figure it out.
+// TODO(wan@google.com): uses autoconf to detect whether ::std::wstring
+//   is available.
+
+// Cygwin 1.7 and below doesn't support ::std::wstring.
+// Solaris' libc++ doesn't support it either.  Android has
+// no support for it at least as recent as Froyo (2.2).
+# define GTEST_HAS_STD_WSTRING \
+    (!(GTEST_OS_LINUX_ANDROID || GTEST_OS_CYGWIN || GTEST_OS_SOLARIS))
+
+#endif  // GTEST_HAS_STD_WSTRING
+
+#ifndef GTEST_HAS_GLOBAL_WSTRING
+// The user didn't tell us whether ::wstring is available, so we need
+// to figure it out.
+# define GTEST_HAS_GLOBAL_WSTRING \
+    (GTEST_HAS_STD_WSTRING && GTEST_HAS_GLOBAL_STRING)
+#endif  // GTEST_HAS_GLOBAL_WSTRING
+
+// Determines whether RTTI is available.
+#ifndef GTEST_HAS_RTTI
+// The user didn't tell us whether RTTI is enabled, so we need to
+// figure it out.
+
+# ifdef _MSC_VER
+
+#  ifdef _CPPRTTI  // MSVC defines this macro iff RTTI is enabled.
+#   define GTEST_HAS_RTTI 1
+#  else
+#   define GTEST_HAS_RTTI 0
+#  endif
+
+// Starting with version 4.3.2, gcc defines __GXX_RTTI iff RTTI is enabled.
+# elif defined(__GNUC__) && (GTEST_GCC_VER_ >= 40302)
+
+#  ifdef __GXX_RTTI
+#   define GTEST_HAS_RTTI 1
+#  else
+#   define GTEST_HAS_RTTI 0
+#  endif  // __GXX_RTTI
+
+// Starting with version 9.0 IBM Visual Age defines __RTTI_ALL__ to 1 if
+// both the typeid and dynamic_cast features are present.
+# elif defined(__IBMCPP__) && (__IBMCPP__ >= 900)
+
+#  ifdef __RTTI_ALL__
+#   define GTEST_HAS_RTTI 1
+#  else
+#   define GTEST_HAS_RTTI 0
+#  endif
+
+# else
+
+// For all other compilers, we assume RTTI is enabled.
+#  define GTEST_HAS_RTTI 1
+
+# endif  // _MSC_VER
+
+#endif  // GTEST_HAS_RTTI
+
+// It's this header's responsibility to #include <typeinfo> when RTTI
+// is enabled.
+#if GTEST_HAS_RTTI
+# include <typeinfo>
+#endif
+
+// Determines whether Google Test can use the pthreads library.
+#ifndef GTEST_HAS_PTHREAD
+// The user didn't tell us explicitly, so we assume pthreads support is
+// available on Linux and Mac.
+//
+// To disable threading support in Google Test, add -DGTEST_HAS_PTHREAD=0
+// to your compiler flags.
+# define GTEST_HAS_PTHREAD (GTEST_OS_LINUX || GTEST_OS_MAC || GTEST_OS_HPUX)
+#endif  // GTEST_HAS_PTHREAD
+
+#if GTEST_HAS_PTHREAD
+// gtest-port.h guarantees to #include <pthread.h> when GTEST_HAS_PTHREAD is
+// true.
+# include <pthread.h>  // NOLINT
+
+// For timespec and nanosleep, used below.
+# include <time.h>  // NOLINT
+#endif
+
+// Determines whether Google Test can use tr1/tuple.  You can define
+// this macro to 0 to prevent Google Test from using tuple (any
+// feature depending on tuple will be disabled in this mode).
+#ifndef GTEST_HAS_TR1_TUPLE
+// The user didn't tell us not to do it, so we assume it's OK.
+# define GTEST_HAS_TR1_TUPLE 1
+#endif  // GTEST_HAS_TR1_TUPLE
+
+// Determines whether Google Test's own tr1 tuple implementation
+// should be used.
+#ifndef GTEST_USE_OWN_TR1_TUPLE
+// The user didn't tell us, so we need to figure it out.
+
+// We use our own TR1 tuple if we aren't sure the user has an
+// implementation of it already.  At this time, GCC 4.0.0+ and MSVC
+// 2010 are the only mainstream compilers that come with a TR1 tuple
+// implementation.  NVIDIA's CUDA NVCC compiler pretends to be GCC by
+// defining __GNUC__ and friends, but cannot compile GCC's tuple
+// implementation.
MSVC 2008 (9.0) provides TR1 tuple in a 323 MB
+// Feature Pack download, which we cannot assume the user has.
+# if (defined(__GNUC__) && !defined(__CUDACC__) && (GTEST_GCC_VER_ >= 40000)) \
+    || _MSC_VER >= 1600
+#  define GTEST_USE_OWN_TR1_TUPLE 0
+# else
+#  define GTEST_USE_OWN_TR1_TUPLE 1
+# endif
+
+#endif  // GTEST_USE_OWN_TR1_TUPLE
+
+// To avoid conditional compilation everywhere, we make it
+// gtest-port.h's responsibility to #include the header implementing
+// tr1/tuple.
+#if GTEST_HAS_TR1_TUPLE
+
+# if GTEST_USE_OWN_TR1_TUPLE
+// This file was GENERATED by a script.  DO NOT EDIT BY HAND!!!
+
+// Copyright 2009 Google Inc.
+// All Rights Reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: wan@google.com (Zhanyong Wan)
+
+// Implements a subset of TR1 tuple needed by Google Test and Google Mock.
+
+#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TUPLE_H_
+#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TUPLE_H_
+
+#include <utility>  // For ::std::pair.
+
+// The compiler used in Symbian has a bug that prevents us from declaring the
+// tuple template as a friend (it complains that tuple is redefined).  This
+// hack bypasses the bug by declaring the members that should otherwise be
+// private as public.
+// Sun Studio versions < 12 also have the above bug.
+#if defined(__SYMBIAN32__) || (defined(__SUNPRO_CC) && __SUNPRO_CC < 0x590)
+# define GTEST_DECLARE_TUPLE_AS_FRIEND_ public:
+#else
+# define GTEST_DECLARE_TUPLE_AS_FRIEND_ \
+    template <GTEST_10_TYPENAMES_(U)> friend class tuple; \
+   private:
+#endif
+
+// GTEST_n_TUPLE_(T) is the type of an n-tuple.
+#define GTEST_0_TUPLE_(T) tuple<>
+#define GTEST_1_TUPLE_(T) tuple<T##0, void, void, void, void, void, void, \
+    void, void, void>
+#define GTEST_2_TUPLE_(T) tuple<T##0, T##1, void, void, void, void, void, \
+    void, void, void>
+#define GTEST_3_TUPLE_(T) tuple<T##0, T##1, T##2, void, void, void, void, \
+    void, void, void>
+#define GTEST_4_TUPLE_(T) tuple<T##0, T##1, T##2, T##3, void, void, void, \
+    void, void, void>
+#define GTEST_5_TUPLE_(T) tuple<T##0, T##1, T##2, T##3, T##4, void, void, \
+    void, void, void>
+#define GTEST_6_TUPLE_(T) tuple<T##0, T##1, T##2, T##3, T##4, T##5, void, \
+    void, void, void>
+#define GTEST_7_TUPLE_(T) tuple<T##0, T##1, T##2, T##3, T##4, T##5, T##6, \
+    void, void, void>
+#define GTEST_8_TUPLE_(T) tuple<T##0, T##1, T##2, T##3, T##4, T##5, T##6, \
+    T##7, void, void>
+#define GTEST_9_TUPLE_(T) tuple<T##0, T##1, T##2, T##3, T##4, T##5, T##6, \
+    T##7, T##8, void>
+#define GTEST_10_TUPLE_(T) tuple<T##0, T##1, T##2, T##3, T##4, T##5, T##6, \
+    T##7, T##8, T##9>
+
+// GTEST_n_TYPENAMES_(T) declares a list of n typenames.
+#define GTEST_0_TYPENAMES_(T)
+#define GTEST_1_TYPENAMES_(T) typename T##0
+#define GTEST_2_TYPENAMES_(T) typename T##0, typename T##1
+#define GTEST_3_TYPENAMES_(T) typename T##0, typename T##1, typename T##2
+#define GTEST_4_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \
+    typename T##3
+#define GTEST_5_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \
+    typename T##3, typename T##4
+#define GTEST_6_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \
+    typename T##3, typename T##4, typename T##5
+#define GTEST_7_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \
+    typename T##3, typename T##4, typename T##5, typename T##6
+#define GTEST_8_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \
+    typename T##3, typename T##4, typename T##5, typename T##6, typename T##7
+#define GTEST_9_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \
+    typename T##3, typename T##4, typename T##5, typename T##6, \
+    typename T##7, typename T##8
+#define GTEST_10_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \
+    typename T##3, typename T##4, typename T##5, typename T##6, \
+    typename T##7, typename T##8, typename T##9
+
+// In theory, defining stuff in the ::std namespace is undefined
+// behavior.  We can do this as we are playing the role of a standard
+// library vendor.
+namespace std {
+namespace tr1 {
+
+template <typename T0 = void, typename T1 = void, typename T2 = void,
+    typename T3 = void, typename T4 = void, typename T5 = void,
+    typename T6 = void, typename T7 = void, typename T8 = void,
+    typename T9 = void>
+class tuple;
+
+// Anything in namespace gtest_internal is Google Test's INTERNAL
+// IMPLEMENTATION DETAIL and MUST NOT BE USED DIRECTLY in user code.
+namespace gtest_internal {
+
+// ByRef<T>::type is T if T is a reference; otherwise it's const T&.
+template <typename T>
+struct ByRef { typedef const T& type; };  // NOLINT
+template <typename T>
+struct ByRef<T&> { typedef T& type; };  // NOLINT
+
+// A handy wrapper for ByRef.
+#define GTEST_BY_REF_(T) typename ::std::tr1::gtest_internal::ByRef<T>::type
+
+// AddRef<T>::type is T if T is a reference; otherwise it's T&.  This
+// is the same as tr1::add_reference<T>::type.
+template <typename T>
+struct AddRef { typedef T& type; };  // NOLINT
+template <typename T>
+struct AddRef<T&> { typedef T& type; };  // NOLINT
+
+// A handy wrapper for AddRef.
+#define GTEST_ADD_REF_(T) typename ::std::tr1::gtest_internal::AddRef<T>::type
+
+// A helper for implementing get<k>().
+template <int k> class Get;
+
+// A helper for implementing tuple_element<k, T>.  kIndexValid is true
+// iff k < the number of fields in tuple type T.
+template <bool kIndexValid, int kIndex, class Tuple>
+struct TupleElement;
+
+template <GTEST_10_TYPENAMES_(T)>
+struct TupleElement<true, 0, GTEST_10_TUPLE_(T)> { typedef T0 type; };
+
+template <GTEST_10_TYPENAMES_(T)>
+struct TupleElement<true, 1, GTEST_10_TUPLE_(T)> { typedef T1 type; };
+
+template <GTEST_10_TYPENAMES_(T)>
+struct TupleElement<true, 2, GTEST_10_TUPLE_(T)> { typedef T2 type; };
+
+template <GTEST_10_TYPENAMES_(T)>
+struct TupleElement<true, 3, GTEST_10_TUPLE_(T)> { typedef T3 type; };
+
+template <GTEST_10_TYPENAMES_(T)>
+struct TupleElement<true, 4, GTEST_10_TUPLE_(T)> { typedef T4 type; };
+
+template <GTEST_10_TYPENAMES_(T)>
+struct TupleElement<true, 5, GTEST_10_TUPLE_(T)> { typedef T5 type; };
+
+template <GTEST_10_TYPENAMES_(T)>
+struct TupleElement<true, 6, GTEST_10_TUPLE_(T)> { typedef T6 type; };
+
+template <GTEST_10_TYPENAMES_(T)>
+struct TupleElement<true, 7, GTEST_10_TUPLE_(T)> { typedef T7 type; };
+
+template <GTEST_10_TYPENAMES_(T)>
+struct TupleElement<true, 8, GTEST_10_TUPLE_(T)> { typedef T8 type; };
+
+template <GTEST_10_TYPENAMES_(T)>
+struct TupleElement<true, 9, GTEST_10_TUPLE_(T)> { typedef T9 type; };
+
+}  // namespace gtest_internal
+
+template <>
+class tuple<> {
+ public:
+  tuple() {}
+  tuple(const tuple& /* t */)  {}
+  tuple& operator=(const tuple& /* t */) { return *this; }
+};
+
+template <GTEST_1_TYPENAMES_(T)>
+class GTEST_1_TUPLE_(T) {
+ public:
+  template <int k> friend class gtest_internal::Get;
+
+  tuple() : f0_() {}
+
+  explicit tuple(GTEST_BY_REF_(T0) f0) : f0_(f0) {}
+
+  tuple(const tuple& t) : f0_(t.f0_) {}
+
+  template <GTEST_1_TYPENAMES_(U)>
+  tuple(const GTEST_1_TUPLE_(U)& t) : f0_(t.f0_) {}
+
+  tuple& operator=(const tuple& t) { return CopyFrom(t); }
+
+  template <GTEST_1_TYPENAMES_(U)>
+  tuple& operator=(const GTEST_1_TUPLE_(U)& t) {
+    return CopyFrom(t);
+  }
+
+  GTEST_DECLARE_TUPLE_AS_FRIEND_
+
+  template <GTEST_1_TYPENAMES_(U)>
+  tuple& CopyFrom(const GTEST_1_TUPLE_(U)& t) {
+    f0_ = t.f0_;
+    return *this;
+  }
+
+  T0 f0_;
+};
+
+template <GTEST_2_TYPENAMES_(T)>
+class GTEST_2_TUPLE_(T) {
+ public:
+  template <int k> friend class gtest_internal::Get;
+
+  tuple() : f0_(), f1_() {}
+
+  explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1) : f0_(f0),
+      f1_(f1) {}
+
+  tuple(const tuple& t) : f0_(t.f0_),
f1_(t.f1_) {}
+
+  template <GTEST_2_TYPENAMES_(U)>
+  tuple(const GTEST_2_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_) {}
+  template <typename U0, typename U1>
+  tuple(const ::std::pair<U0, U1>& p) : f0_(p.first), f1_(p.second) {}
+
+  tuple& operator=(const tuple& t) { return CopyFrom(t); }
+
+  template <GTEST_2_TYPENAMES_(U)>
+  tuple& operator=(const GTEST_2_TUPLE_(U)& t) {
+    return CopyFrom(t);
+  }
+  template <typename U0, typename U1>
+  tuple& operator=(const ::std::pair<U0, U1>& p) {
+    f0_ = p.first;
+    f1_ = p.second;
+    return *this;
+  }
+
+  GTEST_DECLARE_TUPLE_AS_FRIEND_
+
+  template <GTEST_2_TYPENAMES_(U)>
+  tuple& CopyFrom(const GTEST_2_TUPLE_(U)& t) {
+    f0_ = t.f0_;
+    f1_ = t.f1_;
+    return *this;
+  }
+
+  T0 f0_;
+  T1 f1_;
+};
+
+template <GTEST_3_TYPENAMES_(T)>
+class GTEST_3_TUPLE_(T) {
+ public:
+  template <int k> friend class gtest_internal::Get;
+
+  tuple() : f0_(), f1_(), f2_() {}
+
+  explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1,
+      GTEST_BY_REF_(T2) f2) : f0_(f0), f1_(f1), f2_(f2) {}
+
+  tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_) {}
+
+  template <GTEST_3_TYPENAMES_(U)>
+  tuple(const GTEST_3_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_) {}
+
+  tuple& operator=(const tuple& t) { return CopyFrom(t); }
+
+  template <GTEST_3_TYPENAMES_(U)>
+  tuple& operator=(const GTEST_3_TUPLE_(U)& t) {
+    return CopyFrom(t);
+  }
+
+  GTEST_DECLARE_TUPLE_AS_FRIEND_
+
+  template <GTEST_3_TYPENAMES_(U)>
+  tuple& CopyFrom(const GTEST_3_TUPLE_(U)& t) {
+    f0_ = t.f0_;
+    f1_ = t.f1_;
+    f2_ = t.f2_;
+    return *this;
+  }
+
+  T0 f0_;
+  T1 f1_;
+  T2 f2_;
+};
+
+template <GTEST_4_TYPENAMES_(T)>
+class GTEST_4_TUPLE_(T) {
+ public:
+  template <int k> friend class gtest_internal::Get;
+
+  tuple() : f0_(), f1_(), f2_(), f3_() {}
+
+  explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1,
+      GTEST_BY_REF_(T2) f2, GTEST_BY_REF_(T3) f3) : f0_(f0), f1_(f1), f2_(f2),
+      f3_(f3) {}
+
+  tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_) {}
+
+  template <GTEST_4_TYPENAMES_(U)>
+  tuple(const GTEST_4_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_),
+      f3_(t.f3_) {}
+
+  tuple& operator=(const tuple& t) { return CopyFrom(t); }
+
+  template <GTEST_4_TYPENAMES_(U)>
+  tuple& operator=(const GTEST_4_TUPLE_(U)& t) {
+    return CopyFrom(t);
+  }
+
+  GTEST_DECLARE_TUPLE_AS_FRIEND_
+
+  template <GTEST_4_TYPENAMES_(U)>
+  tuple& CopyFrom(const GTEST_4_TUPLE_(U)& t) {
+    f0_ = t.f0_;
+    f1_ = t.f1_;
+    f2_ = t.f2_;
+    f3_ = t.f3_;
+    return *this;
+  }
+
+  T0 f0_;
+  T1 f1_;
+  T2 f2_;
+  T3 f3_;
+};
+
+template <GTEST_5_TYPENAMES_(T)>
+class GTEST_5_TUPLE_(T) {
+ public:
+  template <int k> friend class gtest_internal::Get;
+
+  tuple() : f0_(), f1_(), f2_(), f3_(), f4_() {}
+
+  explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1,
+      GTEST_BY_REF_(T2) f2, GTEST_BY_REF_(T3) f3,
+      GTEST_BY_REF_(T4) f4) : f0_(f0), f1_(f1), f2_(f2), f3_(f3), f4_(f4) {}
+
+  tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_),
+      f4_(t.f4_) {}
+
+  template <GTEST_5_TYPENAMES_(U)>
+  tuple(const GTEST_5_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_),
+      f3_(t.f3_), f4_(t.f4_) {}
+
+  tuple& operator=(const tuple& t) { return CopyFrom(t); }
+
+  template <GTEST_5_TYPENAMES_(U)>
+  tuple& operator=(const GTEST_5_TUPLE_(U)& t) {
+    return CopyFrom(t);
+  }
+
+  GTEST_DECLARE_TUPLE_AS_FRIEND_
+
+  template <GTEST_5_TYPENAMES_(U)>
+  tuple& CopyFrom(const GTEST_5_TUPLE_(U)& t) {
+    f0_ = t.f0_;
+    f1_ = t.f1_;
+    f2_ = t.f2_;
+    f3_ = t.f3_;
+    f4_ = t.f4_;
+    return *this;
+  }
+
+  T0 f0_;
+  T1 f1_;
+  T2 f2_;
+  T3 f3_;
+  T4 f4_;
+};
+
+template <GTEST_6_TYPENAMES_(T)>
+class GTEST_6_TUPLE_(T) {
+ public:
+  template <int k> friend class gtest_internal::Get;
+
+  tuple() : f0_(), f1_(), f2_(), f3_(), f4_(), f5_() {}
+
+  explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1,
+      GTEST_BY_REF_(T2) f2, GTEST_BY_REF_(T3) f3, GTEST_BY_REF_(T4) f4,
+      GTEST_BY_REF_(T5) f5) : f0_(f0), f1_(f1), f2_(f2), f3_(f3), f4_(f4),
+      f5_(f5) {}
+
+  tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_),
+      f4_(t.f4_), f5_(t.f5_) {}
+
+  template <GTEST_6_TYPENAMES_(U)>
+  tuple(const GTEST_6_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_),
+      f3_(t.f3_), f4_(t.f4_), f5_(t.f5_) {}
+
+  tuple& operator=(const tuple& t) { return CopyFrom(t); }
+
+  template <GTEST_6_TYPENAMES_(U)>
+  tuple& operator=(const GTEST_6_TUPLE_(U)& t) {
+    return CopyFrom(t);
+  }
+
+  GTEST_DECLARE_TUPLE_AS_FRIEND_
+
+  template <GTEST_6_TYPENAMES_(U)>
+  tuple& CopyFrom(const GTEST_6_TUPLE_(U)& t) {
+    f0_ = t.f0_;
+    f1_ = t.f1_;
+    f2_ =
t.f2_;
+    f3_ = t.f3_;
+    f4_ = t.f4_;
+    f5_ = t.f5_;
+    return *this;
+  }
+
+  T0 f0_;
+  T1 f1_;
+  T2 f2_;
+  T3 f3_;
+  T4 f4_;
+  T5 f5_;
+};
+
+template <GTEST_7_TYPENAMES_(T)>
+class GTEST_7_TUPLE_(T) {
+ public:
+  template <int k> friend class gtest_internal::Get;
+
+  tuple() : f0_(), f1_(), f2_(), f3_(), f4_(), f5_(), f6_() {}
+
+  explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1,
+      GTEST_BY_REF_(T2) f2, GTEST_BY_REF_(T3) f3, GTEST_BY_REF_(T4) f4,
+      GTEST_BY_REF_(T5) f5, GTEST_BY_REF_(T6) f6) : f0_(f0), f1_(f1), f2_(f2),
+      f3_(f3), f4_(f4), f5_(f5), f6_(f6) {}
+
+  tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_),
+      f4_(t.f4_), f5_(t.f5_), f6_(t.f6_) {}
+
+  template <GTEST_7_TYPENAMES_(U)>
+  tuple(const GTEST_7_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_),
+      f3_(t.f3_), f4_(t.f4_), f5_(t.f5_), f6_(t.f6_) {}
+
+  tuple& operator=(const tuple& t) { return CopyFrom(t); }
+
+  template <GTEST_7_TYPENAMES_(U)>
+  tuple& operator=(const GTEST_7_TUPLE_(U)& t) {
+    return CopyFrom(t);
+  }
+
+  GTEST_DECLARE_TUPLE_AS_FRIEND_
+
+  template <GTEST_7_TYPENAMES_(U)>
+  tuple& CopyFrom(const GTEST_7_TUPLE_(U)& t) {
+    f0_ = t.f0_;
+    f1_ = t.f1_;
+    f2_ = t.f2_;
+    f3_ = t.f3_;
+    f4_ = t.f4_;
+    f5_ = t.f5_;
+    f6_ = t.f6_;
+    return *this;
+  }
+
+  T0 f0_;
+  T1 f1_;
+  T2 f2_;
+  T3 f3_;
+  T4 f4_;
+  T5 f5_;
+  T6 f6_;
+};
+
+template <GTEST_8_TYPENAMES_(T)>
+class GTEST_8_TUPLE_(T) {
+ public:
+  template <int k> friend class gtest_internal::Get;
+
+  tuple() : f0_(), f1_(), f2_(), f3_(), f4_(), f5_(), f6_(), f7_() {}
+
+  explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1,
+      GTEST_BY_REF_(T2) f2, GTEST_BY_REF_(T3) f3, GTEST_BY_REF_(T4) f4,
+      GTEST_BY_REF_(T5) f5, GTEST_BY_REF_(T6) f6,
+      GTEST_BY_REF_(T7) f7) : f0_(f0), f1_(f1), f2_(f2), f3_(f3), f4_(f4),
+      f5_(f5), f6_(f6), f7_(f7) {}
+
+  tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_),
+      f4_(t.f4_), f5_(t.f5_), f6_(t.f6_), f7_(t.f7_) {}
+
+  template <GTEST_8_TYPENAMES_(U)>
+  tuple(const GTEST_8_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_),
+      f3_(t.f3_), f4_(t.f4_), f5_(t.f5_), f6_(t.f6_), f7_(t.f7_) {}
+
+  tuple& operator=(const
tuple& t) { return CopyFrom(t); }
+
+  template <GTEST_8_TYPENAMES_(U)>
+  tuple& operator=(const GTEST_8_TUPLE_(U)& t) {
+    return CopyFrom(t);
+  }
+
+  GTEST_DECLARE_TUPLE_AS_FRIEND_
+
+  template <GTEST_8_TYPENAMES_(U)>
+  tuple& CopyFrom(const GTEST_8_TUPLE_(U)& t) {
+    f0_ = t.f0_;
+    f1_ = t.f1_;
+    f2_ = t.f2_;
+    f3_ = t.f3_;
+    f4_ = t.f4_;
+    f5_ = t.f5_;
+    f6_ = t.f6_;
+    f7_ = t.f7_;
+    return *this;
+  }
+
+  T0 f0_;
+  T1 f1_;
+  T2 f2_;
+  T3 f3_;
+  T4 f4_;
+  T5 f5_;
+  T6 f6_;
+  T7 f7_;
+};
+
+template <GTEST_9_TYPENAMES_(T)>
+class GTEST_9_TUPLE_(T) {
+ public:
+  template <int k> friend class gtest_internal::Get;
+
+  tuple() : f0_(), f1_(), f2_(), f3_(), f4_(), f5_(), f6_(), f7_(), f8_() {}
+
+  explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1,
+      GTEST_BY_REF_(T2) f2, GTEST_BY_REF_(T3) f3, GTEST_BY_REF_(T4) f4,
+      GTEST_BY_REF_(T5) f5, GTEST_BY_REF_(T6) f6, GTEST_BY_REF_(T7) f7,
+      GTEST_BY_REF_(T8) f8) : f0_(f0), f1_(f1), f2_(f2), f3_(f3), f4_(f4),
+      f5_(f5), f6_(f6), f7_(f7), f8_(f8) {}
+
+  tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_),
+      f4_(t.f4_), f5_(t.f5_), f6_(t.f6_), f7_(t.f7_), f8_(t.f8_) {}
+
+  template <GTEST_9_TYPENAMES_(U)>
+  tuple(const GTEST_9_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_),
+      f3_(t.f3_), f4_(t.f4_), f5_(t.f5_), f6_(t.f6_), f7_(t.f7_), f8_(t.f8_) {}
+
+  tuple& operator=(const tuple& t) { return CopyFrom(t); }
+
+  template <GTEST_9_TYPENAMES_(U)>
+  tuple& operator=(const GTEST_9_TUPLE_(U)& t) {
+    return CopyFrom(t);
+  }
+
+  GTEST_DECLARE_TUPLE_AS_FRIEND_
+
+  template <GTEST_9_TYPENAMES_(U)>
+  tuple& CopyFrom(const GTEST_9_TUPLE_(U)& t) {
+    f0_ = t.f0_;
+    f1_ = t.f1_;
+    f2_ = t.f2_;
+    f3_ = t.f3_;
+    f4_ = t.f4_;
+    f5_ = t.f5_;
+    f6_ = t.f6_;
+    f7_ = t.f7_;
+    f8_ = t.f8_;
+    return *this;
+  }
+
+  T0 f0_;
+  T1 f1_;
+  T2 f2_;
+  T3 f3_;
+  T4 f4_;
+  T5 f5_;
+  T6 f6_;
+  T7 f7_;
+  T8 f8_;
+};
+
+template <GTEST_10_TYPENAMES_(T)>
+class tuple {
+ public:
+  template <int k> friend class gtest_internal::Get;
+
+  tuple() : f0_(), f1_(), f2_(), f3_(), f4_(), f5_(), f6_(), f7_(), f8_(),
+      f9_() {}
+
+  explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1,
+      GTEST_BY_REF_(T2) f2,
GTEST_BY_REF_(T3) f3, GTEST_BY_REF_(T4) f4,
+      GTEST_BY_REF_(T5) f5, GTEST_BY_REF_(T6) f6, GTEST_BY_REF_(T7) f7,
+      GTEST_BY_REF_(T8) f8, GTEST_BY_REF_(T9) f9) : f0_(f0), f1_(f1), f2_(f2),
+      f3_(f3), f4_(f4), f5_(f5), f6_(f6), f7_(f7), f8_(f8), f9_(f9) {}
+
+  tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_),
+      f4_(t.f4_), f5_(t.f5_), f6_(t.f6_), f7_(t.f7_), f8_(t.f8_), f9_(t.f9_) {}
+
+  template <GTEST_10_TYPENAMES_(U)>
+  tuple(const GTEST_10_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_),
+      f3_(t.f3_), f4_(t.f4_), f5_(t.f5_), f6_(t.f6_), f7_(t.f7_), f8_(t.f8_),
+      f9_(t.f9_) {}
+
+  tuple& operator=(const tuple& t) { return CopyFrom(t); }
+
+  template <GTEST_10_TYPENAMES_(U)>
+  tuple& operator=(const GTEST_10_TUPLE_(U)& t) {
+    return CopyFrom(t);
+  }
+
+  GTEST_DECLARE_TUPLE_AS_FRIEND_
+
+  template <GTEST_10_TYPENAMES_(U)>
+  tuple& CopyFrom(const GTEST_10_TUPLE_(U)& t) {
+    f0_ = t.f0_;
+    f1_ = t.f1_;
+    f2_ = t.f2_;
+    f3_ = t.f3_;
+    f4_ = t.f4_;
+    f5_ = t.f5_;
+    f6_ = t.f6_;
+    f7_ = t.f7_;
+    f8_ = t.f8_;
+    f9_ = t.f9_;
+    return *this;
+  }
+
+  T0 f0_;
+  T1 f1_;
+  T2 f2_;
+  T3 f3_;
+  T4 f4_;
+  T5 f5_;
+  T6 f6_;
+  T7 f7_;
+  T8 f8_;
+  T9 f9_;
+};
+
+// 6.1.3.2 Tuple creation functions.
+
+// Known limitations: we don't support passing an
+// std::tr1::reference_wrapper<T> to make_tuple().  And we don't
+// implement tie().
+
+inline tuple<> make_tuple() { return tuple<>(); }
+
+template <GTEST_1_TYPENAMES_(T)>
+inline GTEST_1_TUPLE_(T) make_tuple(const T0& f0) {
+  return GTEST_1_TUPLE_(T)(f0);
+}
+
+template <GTEST_2_TYPENAMES_(T)>
+inline GTEST_2_TUPLE_(T) make_tuple(const T0& f0, const T1& f1) {
+  return GTEST_2_TUPLE_(T)(f0, f1);
+}
+
+template <GTEST_3_TYPENAMES_(T)>
+inline GTEST_3_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2) {
+  return GTEST_3_TUPLE_(T)(f0, f1, f2);
+}
+
+template <GTEST_4_TYPENAMES_(T)>
+inline GTEST_4_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2,
+    const T3& f3) {
+  return GTEST_4_TUPLE_(T)(f0, f1, f2, f3);
+}
+
+template <GTEST_5_TYPENAMES_(T)>
+inline GTEST_5_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2,
+    const T3& f3, const T4& f4) {
+  return GTEST_5_TUPLE_(T)(f0, f1, f2, f3, f4);
+}
+
+template <GTEST_6_TYPENAMES_(T)>
+inline GTEST_6_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2,
+    const T3& f3, const T4& f4, const T5& f5) {
+  return GTEST_6_TUPLE_(T)(f0, f1, f2, f3, f4, f5);
+}
+
+template <GTEST_7_TYPENAMES_(T)>
+inline GTEST_7_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2,
+    const T3& f3, const T4& f4, const T5& f5, const T6& f6) {
+  return GTEST_7_TUPLE_(T)(f0, f1, f2, f3, f4, f5, f6);
+}
+
+template <GTEST_8_TYPENAMES_(T)>
+inline GTEST_8_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2,
+    const T3& f3, const T4& f4, const T5& f5, const T6& f6, const T7& f7) {
+  return GTEST_8_TUPLE_(T)(f0, f1, f2, f3, f4, f5, f6, f7);
+}
+
+template <GTEST_9_TYPENAMES_(T)>
+inline GTEST_9_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2,
+    const T3& f3, const T4& f4, const T5& f5, const T6& f6, const T7& f7,
+    const T8& f8) {
+  return GTEST_9_TUPLE_(T)(f0, f1, f2, f3, f4, f5, f6, f7, f8);
+}
+
+template <GTEST_10_TYPENAMES_(T)>
+inline GTEST_10_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2,
+    const T3& f3, const T4& f4, const T5& f5, const T6& f6, const T7& f7,
+    const T8& f8, const T9& f9) {
+  return GTEST_10_TUPLE_(T)(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9);
+}
+
+// 6.1.3.3 Tuple helper classes.
+
+template <typename Tuple> struct tuple_size;
+
+template <GTEST_0_TYPENAMES_(T)>
+struct tuple_size<GTEST_0_TUPLE_(T)> { static const int value = 0; };
+
+template <GTEST_1_TYPENAMES_(T)>
+struct tuple_size<GTEST_1_TUPLE_(T)> { static const int value = 1; };
+
+template <GTEST_2_TYPENAMES_(T)>
+struct tuple_size<GTEST_2_TUPLE_(T)> { static const int value = 2; };
+
+template <GTEST_3_TYPENAMES_(T)>
+struct tuple_size<GTEST_3_TUPLE_(T)> { static const int value = 3; };
+
+template <GTEST_4_TYPENAMES_(T)>
+struct tuple_size<GTEST_4_TUPLE_(T)> { static const int value = 4; };
+
+template <GTEST_5_TYPENAMES_(T)>
+struct tuple_size<GTEST_5_TUPLE_(T)> { static const int value = 5; };
+
+template <GTEST_6_TYPENAMES_(T)>
+struct tuple_size<GTEST_6_TUPLE_(T)> { static const int value = 6; };
+
+template <GTEST_7_TYPENAMES_(T)>
+struct tuple_size<GTEST_7_TUPLE_(T)> { static const int value = 7; };
+
+template <GTEST_8_TYPENAMES_(T)>
+struct tuple_size<GTEST_8_TUPLE_(T)> { static const int value = 8; };
+
+template <GTEST_9_TYPENAMES_(T)>
+struct tuple_size<GTEST_9_TUPLE_(T)> { static const int value = 9; };
+
+template <GTEST_10_TYPENAMES_(T)>
+struct tuple_size<GTEST_10_TUPLE_(T)> { static const int value = 10; };
+
+template <int k, class Tuple>
+struct tuple_element {
+  typedef typename gtest_internal::TupleElement<
+      k < (tuple_size<Tuple>::value), k, Tuple>::type type;
+};
+
+#define GTEST_TUPLE_ELEMENT_(k, Tuple) typename tuple_element<k, Tuple >::type
+
+// 6.1.3.4 Element access.
+
+namespace gtest_internal {
+
+template <>
+class Get<0> {
+ public:
+  template <class Tuple>
+  static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(0, Tuple))
+  Field(Tuple& t) { return t.f0_; }  // NOLINT
+
+  template <class Tuple>
+  static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(0, Tuple))
+  ConstField(const Tuple& t) { return t.f0_; }
+};
+
+template <>
+class Get<1> {
+ public:
+  template <class Tuple>
+  static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(1, Tuple))
+  Field(Tuple& t) { return t.f1_; }  // NOLINT
+
+  template <class Tuple>
+  static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(1, Tuple))
+  ConstField(const Tuple& t) { return t.f1_; }
+};
+
+template <>
+class Get<2> {
+ public:
+  template <class Tuple>
+  static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(2, Tuple))
+  Field(Tuple& t) { return t.f2_; }  // NOLINT
+
+  template <class Tuple>
+  static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(2, Tuple))
+  ConstField(const Tuple& t) { return t.f2_; }
+};
+
+template <>
+class Get<3> {
+ public:
+  template <class Tuple>
+  static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(3, Tuple))
+  Field(Tuple& t) { return t.f3_; }  // NOLINT
+
+  template <class Tuple>
+  static
GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(3, Tuple))
+  ConstField(const Tuple& t) { return t.f3_; }
+};
+
+template <>
+class Get<4> {
+ public:
+  template <class Tuple>
+  static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(4, Tuple))
+  Field(Tuple& t) { return t.f4_; }  // NOLINT
+
+  template <class Tuple>
+  static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(4, Tuple))
+  ConstField(const Tuple& t) { return t.f4_; }
+};
+
+template <>
+class Get<5> {
+ public:
+  template <class Tuple>
+  static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(5, Tuple))
+  Field(Tuple& t) { return t.f5_; }  // NOLINT
+
+  template <class Tuple>
+  static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(5, Tuple))
+  ConstField(const Tuple& t) { return t.f5_; }
+};
+
+template <>
+class Get<6> {
+ public:
+  template <class Tuple>
+  static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(6, Tuple))
+  Field(Tuple& t) { return t.f6_; }  // NOLINT
+
+  template <class Tuple>
+  static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(6, Tuple))
+  ConstField(const Tuple& t) { return t.f6_; }
+};
+
+template <>
+class Get<7> {
+ public:
+  template <class Tuple>
+  static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(7, Tuple))
+  Field(Tuple& t) { return t.f7_; }  // NOLINT
+
+  template <class Tuple>
+  static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(7, Tuple))
+  ConstField(const Tuple& t) { return t.f7_; }
+};
+
+template <>
+class Get<8> {
+ public:
+  template <class Tuple>
+  static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(8, Tuple))
+  Field(Tuple& t) { return t.f8_; }  // NOLINT
+
+  template <class Tuple>
+  static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(8, Tuple))
+  ConstField(const Tuple& t) { return t.f8_; }
+};
+
+template <>
+class Get<9> {
+ public:
+  template <class Tuple>
+  static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(9, Tuple))
+  Field(Tuple& t) { return t.f9_; }  // NOLINT
+
+  template <class Tuple>
+  static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(9, Tuple))
+  ConstField(const Tuple& t) { return t.f9_; }
+};
+
+}  // namespace gtest_internal
+
+template <int k, GTEST_10_TYPENAMES_(T)>
+GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(k, GTEST_10_TUPLE_(T)))
+get(GTEST_10_TUPLE_(T)& t) {
+  return gtest_internal::Get<k>::Field(t);
+}
+
+template <int k, GTEST_10_TYPENAMES_(T)>
+GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(k, GTEST_10_TUPLE_(T)))
+get(const
GTEST_10_TUPLE_(T)& t) { + return gtest_internal::Get::ConstField(t); +} + +// 6.1.3.5 Relational operators + +// We only implement == and !=, as we don't have a need for the rest yet. + +namespace gtest_internal { + +// SameSizeTuplePrefixComparator::Eq(t1, t2) returns true if the +// first k fields of t1 equals the first k fields of t2. +// SameSizeTuplePrefixComparator(k1, k2) would be a compiler error if +// k1 != k2. +template +struct SameSizeTuplePrefixComparator; + +template <> +struct SameSizeTuplePrefixComparator<0, 0> { + template + static bool Eq(const Tuple1& /* t1 */, const Tuple2& /* t2 */) { + return true; + } +}; + +template +struct SameSizeTuplePrefixComparator { + template + static bool Eq(const Tuple1& t1, const Tuple2& t2) { + return SameSizeTuplePrefixComparator::Eq(t1, t2) && + ::std::tr1::get(t1) == ::std::tr1::get(t2); + } +}; + +} // namespace gtest_internal + +template +inline bool operator==(const GTEST_10_TUPLE_(T)& t, + const GTEST_10_TUPLE_(U)& u) { + return gtest_internal::SameSizeTuplePrefixComparator< + tuple_size::value, + tuple_size::value>::Eq(t, u); +} + +template +inline bool operator!=(const GTEST_10_TUPLE_(T)& t, + const GTEST_10_TUPLE_(U)& u) { return !(t == u); } + +// 6.1.4 Pairs. +// Unimplemented. 
+ +} // namespace tr1 +} // namespace std + +#undef GTEST_0_TUPLE_ +#undef GTEST_1_TUPLE_ +#undef GTEST_2_TUPLE_ +#undef GTEST_3_TUPLE_ +#undef GTEST_4_TUPLE_ +#undef GTEST_5_TUPLE_ +#undef GTEST_6_TUPLE_ +#undef GTEST_7_TUPLE_ +#undef GTEST_8_TUPLE_ +#undef GTEST_9_TUPLE_ +#undef GTEST_10_TUPLE_ + +#undef GTEST_0_TYPENAMES_ +#undef GTEST_1_TYPENAMES_ +#undef GTEST_2_TYPENAMES_ +#undef GTEST_3_TYPENAMES_ +#undef GTEST_4_TYPENAMES_ +#undef GTEST_5_TYPENAMES_ +#undef GTEST_6_TYPENAMES_ +#undef GTEST_7_TYPENAMES_ +#undef GTEST_8_TYPENAMES_ +#undef GTEST_9_TYPENAMES_ +#undef GTEST_10_TYPENAMES_ + +#undef GTEST_DECLARE_TUPLE_AS_FRIEND_ +#undef GTEST_BY_REF_ +#undef GTEST_ADD_REF_ +#undef GTEST_TUPLE_ELEMENT_ + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TUPLE_H_ +# elif GTEST_OS_SYMBIAN + +// On Symbian, BOOST_HAS_TR1_TUPLE causes Boost's TR1 tuple library to +// use STLport's tuple implementation, which unfortunately doesn't +// work as the copy of STLport distributed with Symbian is incomplete. +// By making sure BOOST_HAS_TR1_TUPLE is undefined, we force Boost to +// use its own tuple implementation. +# ifdef BOOST_HAS_TR1_TUPLE +# undef BOOST_HAS_TR1_TUPLE +# endif // BOOST_HAS_TR1_TUPLE + +// This prevents <boost/tr1/detail/config.hpp>, which defines +// BOOST_HAS_TR1_TUPLE, from being #included by Boost's <tuple>. +# define BOOST_TR1_DETAIL_CONFIG_HPP_INCLUDED +# include <tuple> + +# elif defined(__GNUC__) && (GTEST_GCC_VER_ >= 40000) +// GCC 4.0+ implements tr1/tuple in the <tr1/tuple> header. This does +// not conform to the TR1 spec, which requires the header to be <tuple>. + +# if !GTEST_HAS_RTTI && GTEST_GCC_VER_ < 40302 +// Until version 4.3.2, gcc has a bug that causes <tr1/functional>, +// which is #included by <tr1/tuple>, to not compile when RTTI is +// disabled. _TR1_FUNCTIONAL is the header guard for +// <tr1/functional>. Hence the following #define is a hack to prevent +// <tr1/functional> from being included. +# define _TR1_FUNCTIONAL 1 +# include <tr1/tuple> +# undef _TR1_FUNCTIONAL // Allows the user to #include <tr1/functional> + // if he chooses to. 
+# else +# include // NOLINT +# endif // !GTEST_HAS_RTTI && GTEST_GCC_VER_ < 40302 + +# else +// If the compiler is not GCC 4.0+, we assume the user is using a +// spec-conforming TR1 implementation. +# include // NOLINT +# endif // GTEST_USE_OWN_TR1_TUPLE + +#endif // GTEST_HAS_TR1_TUPLE + +// Determines whether clone(2) is supported. +// Usually it will only be available on Linux, excluding +// Linux on the Itanium architecture. +// Also see http://linux.die.net/man/2/clone. +#ifndef GTEST_HAS_CLONE +// The user didn't tell us, so we need to figure it out. + +# if GTEST_OS_LINUX && !defined(__ia64__) +# define GTEST_HAS_CLONE 1 +# else +# define GTEST_HAS_CLONE 0 +# endif // GTEST_OS_LINUX && !defined(__ia64__) + +#endif // GTEST_HAS_CLONE + +// Determines whether to support stream redirection. This is used to test +// output correctness and to implement death tests. +#ifndef GTEST_HAS_STREAM_REDIRECTION +// By default, we assume that stream redirection is supported on all +// platforms except known mobile ones. +# if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_SYMBIAN +# define GTEST_HAS_STREAM_REDIRECTION 0 +# else +# define GTEST_HAS_STREAM_REDIRECTION 1 +# endif // !GTEST_OS_WINDOWS_MOBILE && !GTEST_OS_SYMBIAN +#endif // GTEST_HAS_STREAM_REDIRECTION + +// Determines whether to support death tests. +// Google Test does not support death tests for VC 7.1 and earlier as +// abort() in a VC 7.1 application compiled as GUI in debug config +// pops up a dialog window that cannot be suppressed programmatically. +#if (GTEST_OS_LINUX || GTEST_OS_MAC || GTEST_OS_CYGWIN || GTEST_OS_SOLARIS || \ + (GTEST_OS_WINDOWS_DESKTOP && _MSC_VER >= 1400) || \ + GTEST_OS_WINDOWS_MINGW || GTEST_OS_AIX || GTEST_OS_HPUX) +# define GTEST_HAS_DEATH_TEST 1 +# include // NOLINT +#endif + +// We don't support MSVC 7.1 with exceptions disabled now. Therefore +// all the compilers we care about are adequate for supporting +// value-parameterized tests. 
+#define GTEST_HAS_PARAM_TEST 1 + +// Determines whether to support type-driven tests. + +// Typed tests need and variadic macros, which GCC, VC++ 8.0, +// Sun Pro CC, IBM Visual Age, and HP aCC support. +#if defined(__GNUC__) || (_MSC_VER >= 1400) || defined(__SUNPRO_CC) || \ + defined(__IBMCPP__) || defined(__HP_aCC) +# define GTEST_HAS_TYPED_TEST 1 +# define GTEST_HAS_TYPED_TEST_P 1 +#endif + +// Determines whether to support Combine(). This only makes sense when +// value-parameterized tests are enabled. The implementation doesn't +// work on Sun Studio since it doesn't understand templated conversion +// operators. +#if GTEST_HAS_PARAM_TEST && GTEST_HAS_TR1_TUPLE && !defined(__SUNPRO_CC) +# define GTEST_HAS_COMBINE 1 +#endif + +// Determines whether the system compiler uses UTF-16 for encoding wide strings. +#define GTEST_WIDE_STRING_USES_UTF16_ \ + (GTEST_OS_WINDOWS || GTEST_OS_CYGWIN || GTEST_OS_SYMBIAN || GTEST_OS_AIX) + +// Determines whether test results can be streamed to a socket. +#if GTEST_OS_LINUX +# define GTEST_CAN_STREAM_RESULTS_ 1 +#endif + +// Defines some utility macros. + +// The GNU compiler emits a warning if nested "if" statements are followed by +// an "else" statement and braces are not used to explicitly disambiguate the +// "else" binding. This leads to problems with code like: +// +// if (gate) +// ASSERT_*(condition) << "Some message"; +// +// The "switch (0) case 0:" idiom is used to suppress this. +#ifdef __INTEL_COMPILER +# define GTEST_AMBIGUOUS_ELSE_BLOCKER_ +#else +# define GTEST_AMBIGUOUS_ELSE_BLOCKER_ switch (0) case 0: default: // NOLINT +#endif + +// Use this annotation at the end of a struct/class definition to +// prevent the compiler from optimizing away instances that are never +// used. This is useful when all interesting logic happens inside the +// c'tor and / or d'tor. Example: +// +// struct Foo { +// Foo() { ... 
} +// } GTEST_ATTRIBUTE_UNUSED_; +// +// Also use it after a variable or parameter declaration to tell the +// compiler the variable/parameter does not have to be used. +#if defined(__GNUC__) && !defined(COMPILER_ICC) +# define GTEST_ATTRIBUTE_UNUSED_ __attribute__ ((unused)) +#else +# define GTEST_ATTRIBUTE_UNUSED_ +#endif + +// A macro to disallow operator= +// This should be used in the private: declarations for a class. +#define GTEST_DISALLOW_ASSIGN_(type)\ + void operator=(type const &) + +// A macro to disallow copy constructor and operator= +// This should be used in the private: declarations for a class. +#define GTEST_DISALLOW_COPY_AND_ASSIGN_(type)\ + type(type const &);\ + GTEST_DISALLOW_ASSIGN_(type) + +// Tell the compiler to warn about unused return values for functions declared +// with this macro. The macro should be used on function declarations +// following the argument list: +// +// Sprocket* AllocateSprocket() GTEST_MUST_USE_RESULT_; +#if defined(__GNUC__) && (GTEST_GCC_VER_ >= 30400) && !defined(COMPILER_ICC) +# define GTEST_MUST_USE_RESULT_ __attribute__ ((warn_unused_result)) +#else +# define GTEST_MUST_USE_RESULT_ +#endif // __GNUC__ && (GTEST_GCC_VER_ >= 30400) && !COMPILER_ICC + +// Determine whether the compiler supports Microsoft's Structured Exception +// Handling. This is supported by several Windows compilers but generally +// does not exist on any other system. +#ifndef GTEST_HAS_SEH +// The user didn't tell us, so we need to figure it out. + +# if defined(_MSC_VER) || defined(__BORLANDC__) +// These two compilers are known to support SEH. +# define GTEST_HAS_SEH 1 +# else +// Assume no SEH. 
+# define GTEST_HAS_SEH 0 +# endif + +#endif // GTEST_HAS_SEH + +#ifdef _MSC_VER + +# if GTEST_LINKED_AS_SHARED_LIBRARY +# define GTEST_API_ __declspec(dllimport) +# elif GTEST_CREATE_SHARED_LIBRARY +# define GTEST_API_ __declspec(dllexport) +# endif + +#endif // _MSC_VER + +#ifndef GTEST_API_ +# define GTEST_API_ +#endif + +#ifdef __GNUC__ +// Ask the compiler to never inline a given function. +# define GTEST_NO_INLINE_ __attribute__((noinline)) +#else +# define GTEST_NO_INLINE_ +#endif + +namespace testing { + +class Message; + +namespace internal { + +class String; + +// The GTEST_COMPILE_ASSERT_ macro can be used to verify that a compile time +// expression is true. For example, you could use it to verify the +// size of a static array: +// +// GTEST_COMPILE_ASSERT_(ARRAYSIZE(content_type_names) == CONTENT_NUM_TYPES, +// content_type_names_incorrect_size); +// +// or to make sure a struct is smaller than a certain size: +// +// GTEST_COMPILE_ASSERT_(sizeof(foo) < 128, foo_too_large); +// +// The second argument to the macro is the name of the variable. If +// the expression is false, most compilers will issue a warning/error +// containing the name of the variable. + +template +struct CompileAssert { +}; + +#define GTEST_COMPILE_ASSERT_(expr, msg) \ + typedef ::testing::internal::CompileAssert<(bool(expr))> \ + msg[bool(expr) ? 1 : -1] + +// Implementation details of GTEST_COMPILE_ASSERT_: +// +// - GTEST_COMPILE_ASSERT_ works by defining an array type that has -1 +// elements (and thus is invalid) when the expression is false. +// +// - The simpler definition +// +// #define GTEST_COMPILE_ASSERT_(expr, msg) typedef char msg[(expr) ? 1 : -1] +// +// does not work, as gcc supports variable-length arrays whose sizes +// are determined at run-time (this is gcc's extension and not part +// of the C++ standard). 
As a result, gcc fails to reject the +// following code with the simple definition: +// +// int foo; +// GTEST_COMPILE_ASSERT_(foo, msg); // not supposed to compile as foo is +// // not a compile-time constant. +// +// - By using the type CompileAssert<(bool(expr))>, we ensure that +// expr is a compile-time constant. (Template arguments must be +// determined at compile-time.) +// +// - The outer parentheses in CompileAssert<(bool(expr))> are necessary +// to work around a bug in gcc 3.4.4 and 4.0.1. If we had written +// +// CompileAssert<bool(expr)> +// +// instead, these compilers would refuse to compile +// +// GTEST_COMPILE_ASSERT_(5 > 0, some_message); +// +// (They seem to think the ">" in "5 > 0" marks the end of the +// template argument list.) +// +// - The array size is (bool(expr) ? 1 : -1), instead of simply +// +// ((expr) ? 1 : -1). +// +// This is to avoid running into a bug in MS VC 7.1, which +// causes ((0.0) ? 1 : -1) to incorrectly evaluate to 1. + +// StaticAssertTypeEqHelper is used by StaticAssertTypeEq defined in gtest.h. +// +// This template is declared, but intentionally undefined. +template <typename T1, typename T2> +struct StaticAssertTypeEqHelper; + +template <typename T> +struct StaticAssertTypeEqHelper<T, T> {}; + +#if GTEST_HAS_GLOBAL_STRING +typedef ::string string; +#else +typedef ::std::string string; +#endif // GTEST_HAS_GLOBAL_STRING + +#if GTEST_HAS_GLOBAL_WSTRING +typedef ::wstring wstring; +#elif GTEST_HAS_STD_WSTRING +typedef ::std::wstring wstring; +#endif // GTEST_HAS_GLOBAL_WSTRING + +// A helper for suppressing warnings on constant conditions. It just +// returns 'condition'. +GTEST_API_ bool IsTrue(bool condition); + +// Defines scoped_ptr. + +// This implementation of scoped_ptr is PARTIAL - it only contains +// enough stuff to satisfy Google Test's needs. 
+template <typename T> +class scoped_ptr { + public: + typedef T element_type; + + explicit scoped_ptr(T* p = NULL) : ptr_(p) {} + ~scoped_ptr() { reset(); } + + T& operator*() const { return *ptr_; } + T* operator->() const { return ptr_; } + T* get() const { return ptr_; } + + T* release() { + T* const ptr = ptr_; + ptr_ = NULL; + return ptr; + } + + void reset(T* p = NULL) { + if (p != ptr_) { + if (IsTrue(sizeof(T) > 0)) { // Makes sure T is a complete type. + delete ptr_; + } + ptr_ = p; + } + } + private: + T* ptr_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(scoped_ptr); +}; + +// Defines RE. + +// A simple C++ wrapper for <regex.h>. It uses the POSIX Extended +// Regular Expression syntax. +class GTEST_API_ RE { + public: + // A copy constructor is required by the Standard to initialize object + // references from r-values. + RE(const RE& other) { Init(other.pattern()); } + + // Constructs an RE from a string. + RE(const ::std::string& regex) { Init(regex.c_str()); } // NOLINT + +#if GTEST_HAS_GLOBAL_STRING + + RE(const ::string& regex) { Init(regex.c_str()); } // NOLINT + +#endif // GTEST_HAS_GLOBAL_STRING + + RE(const char* regex) { Init(regex); } // NOLINT + ~RE(); + + // Returns the string representation of the regex. + const char* pattern() const { return pattern_; } + + // FullMatch(str, re) returns true iff regular expression re matches + // the entire str. + // PartialMatch(str, re) returns true iff regular expression re + // matches a substring of str (including str itself). + // + // TODO(wan@google.com): make FullMatch() and PartialMatch() work + // when str contains NUL characters. 
+ static bool FullMatch(const ::std::string& str, const RE& re) { + return FullMatch(str.c_str(), re); + } + static bool PartialMatch(const ::std::string& str, const RE& re) { + return PartialMatch(str.c_str(), re); + } + +#if GTEST_HAS_GLOBAL_STRING + + static bool FullMatch(const ::string& str, const RE& re) { + return FullMatch(str.c_str(), re); + } + static bool PartialMatch(const ::string& str, const RE& re) { + return PartialMatch(str.c_str(), re); + } + +#endif // GTEST_HAS_GLOBAL_STRING + + static bool FullMatch(const char* str, const RE& re); + static bool PartialMatch(const char* str, const RE& re); + + private: + void Init(const char* regex); + + // We use a const char* instead of a string, as Google Test may be used + // where string is not available. We also do not use Google Test's own + // String type here, in order to simplify dependencies between the + // files. + const char* pattern_; + bool is_valid_; + +#if GTEST_USES_POSIX_RE + + regex_t full_regex_; // For FullMatch(). + regex_t partial_regex_; // For PartialMatch(). + +#else // GTEST_USES_SIMPLE_RE + + const char* full_pattern_; // For FullMatch(); + +#endif + + GTEST_DISALLOW_ASSIGN_(RE); +}; + +// Formats a source file path and a line number as they would appear +// in an error message from the compiler used to compile this code. +GTEST_API_ ::std::string FormatFileLocation(const char* file, int line); + +// Formats a file location for compiler-independent XML output. +// Although this function is not platform dependent, we put it next to +// FormatFileLocation in order to contrast the two functions. +GTEST_API_ ::std::string FormatCompilerIndependentFileLocation(const char* file, + int line); + +// Defines logging utilities: +// GTEST_LOG_(severity) - logs messages at the specified severity level. The +// message itself is streamed into the macro. +// LogToStderr() - directs all log messages to stderr. +// FlushInfoLog() - flushes informational log messages. 
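The streaming-macro pattern described in the comment above can be sketched in isolation. The block below is a minimal illustration, not Google Test's implementation: `SketchLog`, `SKETCH_LOG_`, `SketchSink`, and `LogOnce` are hypothetical names, and the sink is a string buffer (an assumption made so the result can be inspected) rather than stderr. It shows how a temporary object writes a severity prefix in its constructor, exposes a stream for `<<`, and terminates the message with a newline in its destructor.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical names throughout; this only mirrors the shape of a
// severity-tagged streaming log macro.
enum SketchSeverity { SKETCH_INFO, SKETCH_WARNING, SKETCH_ERROR };

// Shared in-memory sink (an assumption for this sketch; a real logger
// would more likely write to stderr).
inline std::ostringstream& SketchSink() {
  static std::ostringstream sink;
  return sink;
}

class SketchLog {
 public:
  // Writes the severity marker and source location before the message.
  SketchLog(SketchSeverity severity, const char* file, int line) {
    static const char* const kMarkers[] = {"[ INFO ]", "[WARNING]",
                                           "[ ERROR ]"};
    SketchSink() << kMarkers[severity] << " " << file << ":" << line << ": ";
  }

  // Runs at the end of the full expression, i.e. after every operand of
  // << has been written, so the newline always comes last.
  ~SketchLog() { SketchSink() << "\n"; }

  std::ostream& GetStream() { return SketchSink(); }
};

// The message itself is streamed into the macro.
#define SKETCH_LOG_(severity) \
  SketchLog(SKETCH_##severity, __FILE__, __LINE__).GetStream()

// Clears the sink, logs one message, and returns what was written.
inline std::string LogOnce() {
  SketchSink().str("");
  SKETCH_LOG_(WARNING) << "disk at " << 93 << "%";
  return SketchSink().str();
}
```

Because the temporary lives until the end of the statement, callers never have to terminate the line themselves; the same lifetime rule is what would let a destructor abort after a fatal message has been fully written.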
+ +enum GTestLogSeverity { + GTEST_INFO, + GTEST_WARNING, + GTEST_ERROR, + GTEST_FATAL +}; + +// Formats log entry severity, provides a stream object for streaming the +// log message, and terminates the message with a newline when going out of +// scope. +class GTEST_API_ GTestLog { + public: + GTestLog(GTestLogSeverity severity, const char* file, int line); + + // Flushes the buffers and, if severity is GTEST_FATAL, aborts the program. + ~GTestLog(); + + ::std::ostream& GetStream() { return ::std::cerr; } + + private: + const GTestLogSeverity severity_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTestLog); +}; + +#define GTEST_LOG_(severity) \ + ::testing::internal::GTestLog(::testing::internal::GTEST_##severity, \ + __FILE__, __LINE__).GetStream() + +inline void LogToStderr() {} +inline void FlushInfoLog() { fflush(NULL); } + +// INTERNAL IMPLEMENTATION - DO NOT USE. +// +// GTEST_CHECK_ is an all-mode assert. It aborts the program if the condition +// is not satisfied. +// Synopsis: +// GTEST_CHECK_(boolean_condition); +// or +// GTEST_CHECK_(boolean_condition) << "Additional message"; +// +// This checks the condition and, if the condition is not satisfied, +// prints a message about the condition violation, including the +// condition itself, plus any additional message streamed into it, +// and then aborts the program. It aborts the program irrespective of +// whether it is built in the debug mode or not. +#define GTEST_CHECK_(condition) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::IsTrue(condition)) \ + ; \ + else \ + GTEST_LOG_(FATAL) << "Condition " #condition " failed. " + +// An all-mode assert to verify that the given POSIX-style function +// call returns 0 (indicating success). Known limitation: this +// doesn't expand to a balanced 'if' statement, so enclose the macro +// in {} if you need to use it as the only statement in an 'if' +// branch. 
+#define GTEST_CHECK_POSIX_SUCCESS_(posix_call) \ + if (const int gtest_error = (posix_call)) \ + GTEST_LOG_(FATAL) << #posix_call << " failed with error " \ + << gtest_error + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Use ImplicitCast_ as a safe version of static_cast for upcasting in +// the type hierarchy (e.g. casting a Foo* to a SuperclassOfFoo* or a +// const Foo*). When you use ImplicitCast_, the compiler checks that +// the cast is safe. Such explicit ImplicitCast_s are necessary in +// surprisingly many situations where C++ demands an exact type match +// instead of an argument type convertible to a target type. +// +// The syntax for using ImplicitCast_ is the same as for static_cast: +// +// ImplicitCast_<ToType>(expr) +// +// ImplicitCast_ would have been part of the C++ standard library, +// but the proposal was submitted too late. It will probably make +// its way into the language in the future. +// +// This relatively ugly name is intentional. It prevents clashes with +// similar functions users may have (e.g., implicit_cast). The internal +// namespace alone is not enough because the function can be found by ADL. +template<typename To> +inline To ImplicitCast_(To x) { return x; } + +// When you upcast (that is, cast a pointer from type Foo to type +// SuperclassOfFoo), it's fine to use ImplicitCast_<>, since upcasts +// always succeed. When you downcast (that is, cast a pointer from +// type Foo to type SubclassOfFoo), static_cast<> isn't safe, because +// how do you know the pointer is really of type SubclassOfFoo? It +// could be a bare Foo, or of type DifferentSubclassOfFoo. Thus, +// when you downcast, you should use this macro. In debug mode, we +// use dynamic_cast<> to double-check the downcast is legal (we die +// if it's not). In normal mode, we do the efficient static_cast<> +// instead. Thus, it's important to test in debug mode to make sure +// the cast is legal! +// This is the only place in the code we should use dynamic_cast<>. 
+// In particular, you SHOULDN'T be using dynamic_cast<> in order to +// do RTTI, e.g. with code like this: +// if (dynamic_cast<Subclass1>(foo)) HandleASubclass1Object(foo); +// if (dynamic_cast<Subclass2>(foo)) HandleASubclass2Object(foo); +// You should design the code some other way not to need this. +// +// This relatively ugly name is intentional. It prevents clashes with +// similar functions users may have (e.g., down_cast). The internal +// namespace alone is not enough because the function can be found by ADL. +template<typename To, typename From> // use like this: DownCast_<T*>(foo); +inline To DownCast_(From* f) { // so we only accept pointers + // Ensures that To is a sub-type of From *. This test is here only + // for compile-time type checking, and has no overhead in an + // optimized build at run-time, as it will be optimized away + // completely. + if (false) { + const To to = NULL; + ::testing::internal::ImplicitCast_<From*>(to); + } + +#if GTEST_HAS_RTTI + // RTTI: debug mode only! + GTEST_CHECK_(f == NULL || dynamic_cast<To>(f) != NULL); +#endif + return static_cast<To>(f); +} + +// Downcasts the pointer of type Base to Derived. +// Derived must be a subclass of Base. The parameter MUST +// point to an object of type Derived, not any subclass of it. +// When RTTI is available, the function performs a runtime +// check to enforce this. +template <class Derived, class Base> +Derived* CheckedDowncastToActualType(Base* base) { +#if GTEST_HAS_RTTI + GTEST_CHECK_(typeid(*base) == typeid(Derived)); + return dynamic_cast<Derived*>(base); // NOLINT +#else + return static_cast<Derived*>(base); // Poor man's downcast. +#endif +} + +#if GTEST_HAS_STREAM_REDIRECTION + +// Defines the stdout and stderr capturers: +// CaptureStdout - starts capturing stdout. +// GetCapturedStdout - stops capturing stdout and returns the captured string. +// CaptureStderr - starts capturing stderr. +// GetCapturedStderr - stops capturing stderr and returns the captured string. 
+// +GTEST_API_ void CaptureStdout(); +GTEST_API_ String GetCapturedStdout(); +GTEST_API_ void CaptureStderr(); +GTEST_API_ String GetCapturedStderr(); + +#endif // GTEST_HAS_STREAM_REDIRECTION + + +#if GTEST_HAS_DEATH_TEST + +// A copy of all command line arguments. Set by InitGoogleTest(). +extern ::std::vector g_argvs; + +// GTEST_HAS_DEATH_TEST implies we have ::std::string. +const ::std::vector& GetArgvs(); + +#endif // GTEST_HAS_DEATH_TEST + +// Defines synchronization primitives. + +#if GTEST_HAS_PTHREAD + +// Sleeps for (roughly) n milli-seconds. This function is only for +// testing Google Test's own constructs. Don't use it in user tests, +// either directly or indirectly. +inline void SleepMilliseconds(int n) { + const timespec time = { + 0, // 0 seconds. + n * 1000L * 1000L, // And n ms. + }; + nanosleep(&time, NULL); +} + +// Allows a controller thread to pause execution of newly created +// threads until notified. Instances of this class must be created +// and destroyed in the controller thread. +// +// This class is only for testing Google Test's own constructs. Do not +// use it in user tests, either directly or indirectly. +class Notification { + public: + Notification() : notified_(false) {} + + // Notifies all threads created with this notification to start. Must + // be called from the controller thread. + void Notify() { notified_ = true; } + + // Blocks until the controller thread notifies. Must be called from a test + // thread. + void WaitForNotification() { + while(!notified_) { + SleepMilliseconds(10); + } + } + + private: + volatile bool notified_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(Notification); +}; + +// As a C-function, ThreadFuncWithCLinkage cannot be templated itself. +// Consequently, it cannot select a correct instantiation of ThreadWithParam +// in order to call its Run(). Introducing ThreadWithParamBase as a +// non-templated base class for ThreadWithParam allows us to bypass this +// problem. 
+class ThreadWithParamBase { + public: + virtual ~ThreadWithParamBase() {} + virtual void Run() = 0; +}; + +// pthread_create() accepts a pointer to a function type with the C linkage. +// According to the Standard (7.5/1), function types with different linkages +// are different even if they are otherwise identical. Some compilers (for +// example, SunStudio) treat them as different types. Since class methods +// cannot be defined with C-linkage we need to define a free C-function to +// pass into pthread_create(). +extern "C" inline void* ThreadFuncWithCLinkage(void* thread) { + static_cast<ThreadWithParamBase*>(thread)->Run(); + return NULL; +} + +// Helper class for testing Google Test's multi-threading constructs. +// To use it, write: +// +// void ThreadFunc(int param) { /* Do things with param */ } +// Notification thread_can_start; +// ... +// // The thread_can_start parameter is optional; you can supply NULL. +// ThreadWithParam<int> thread(&ThreadFunc, 5, &thread_can_start); +// thread_can_start.Notify(); +// +// These classes are only for testing Google Test's own constructs. Do +// not use them in user tests, either directly or indirectly. +template <typename T> +class ThreadWithParam : public ThreadWithParamBase { + public: + typedef void (*UserThreadFunc)(T); + + ThreadWithParam( + UserThreadFunc func, T param, Notification* thread_can_start) + : func_(func), + param_(param), + thread_can_start_(thread_can_start), + finished_(false) { + ThreadWithParamBase* const base = this; + // The thread can be created only after all fields except thread_ + // have been initialized. 
+ GTEST_CHECK_POSIX_SUCCESS_( + pthread_create(&thread_, 0, &ThreadFuncWithCLinkage, base)); + } + ~ThreadWithParam() { Join(); } + + void Join() { + if (!finished_) { + GTEST_CHECK_POSIX_SUCCESS_(pthread_join(thread_, 0)); + finished_ = true; + } + } + + virtual void Run() { + if (thread_can_start_ != NULL) + thread_can_start_->WaitForNotification(); + func_(param_); + } + + private: + const UserThreadFunc func_; // User-supplied thread function. + const T param_; // User-supplied parameter to the thread function. + // When non-NULL, used to block execution until the controller thread + // notifies. + Notification* const thread_can_start_; + bool finished_; // true iff we know that the thread function has finished. + pthread_t thread_; // The native thread object. + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadWithParam); +}; + +// MutexBase and Mutex implement mutex on pthreads-based platforms. They +// are used in conjunction with class MutexLock: +// +// Mutex mutex; +// ... +// MutexLock lock(&mutex); // Acquires the mutex and releases it at the end +// // of the current scope. +// +// MutexBase implements behavior for both statically and dynamically +// allocated mutexes. Do not use MutexBase directly. Instead, write +// the following to define a static mutex: +// +// GTEST_DEFINE_STATIC_MUTEX_(g_some_mutex); +// +// You can forward declare a static mutex like this: +// +// GTEST_DECLARE_STATIC_MUTEX_(g_some_mutex); +// +// To create a dynamic mutex, just define an object of type Mutex. +class MutexBase { + public: + // Acquires this mutex. + void Lock() { + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_lock(&mutex_)); + owner_ = pthread_self(); + } + + // Releases this mutex. + void Unlock() { + // We don't protect writing to owner_ here, as it's the caller's + // responsibility to ensure that the current thread holds the + // mutex when this is called. 
+ owner_ = 0; + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_unlock(&mutex_)); + } + + // Does nothing if the current thread holds the mutex. Otherwise, crashes + // with high probability. + void AssertHeld() const { + GTEST_CHECK_(owner_ == pthread_self()) + << "The current thread is not holding the mutex @" << this; + } + + // A static mutex may be used before main() is entered. It may even + // be used before the dynamic initialization stage. Therefore we + // must be able to initialize a static mutex object at link time. + // This means MutexBase has to be a POD and its member variables + // have to be public. + public: + pthread_mutex_t mutex_; // The underlying pthread mutex. + pthread_t owner_; // The thread holding the mutex; 0 means no one holds it. +}; + +// Forward-declares a static mutex. +# define GTEST_DECLARE_STATIC_MUTEX_(mutex) \ + extern ::testing::internal::MutexBase mutex + +// Defines and statically (i.e. at link time) initializes a static mutex. +# define GTEST_DEFINE_STATIC_MUTEX_(mutex) \ + ::testing::internal::MutexBase mutex = { PTHREAD_MUTEX_INITIALIZER, 0 } + +// The Mutex class can only be used for mutexes created at runtime. It +// shares its API with MutexBase otherwise. +class Mutex : public MutexBase { + public: + Mutex() { + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_init(&mutex_, NULL)); + owner_ = 0; + } + ~Mutex() { + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_destroy(&mutex_)); + } + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(Mutex); +}; + +// We cannot name this class MutexLock as the ctor declaration would +// conflict with a macro named MutexLock, which is defined on some +// platforms. Hence the typedef trick below. 
+class GTestMutexLock { + public: + explicit GTestMutexLock(MutexBase* mutex) + : mutex_(mutex) { mutex_->Lock(); } + + ~GTestMutexLock() { mutex_->Unlock(); } + + private: + MutexBase* const mutex_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTestMutexLock); +}; + +typedef GTestMutexLock MutexLock; + +// Helpers for ThreadLocal. + +// pthread_key_create() requires DeleteThreadLocalValue() to have +// C-linkage. Therefore it cannot be templatized to access +// ThreadLocal. Hence the need for class +// ThreadLocalValueHolderBase. +class ThreadLocalValueHolderBase { + public: + virtual ~ThreadLocalValueHolderBase() {} +}; + +// Called by pthread to delete thread-local data stored by +// pthread_setspecific(). +extern "C" inline void DeleteThreadLocalValue(void* value_holder) { + delete static_cast(value_holder); +} + +// Implements thread-local storage on pthreads-based systems. +// +// // Thread 1 +// ThreadLocal tl(100); // 100 is the default value for each thread. +// +// // Thread 2 +// tl.set(150); // Changes the value for thread 2 only. +// EXPECT_EQ(150, tl.get()); +// +// // Thread 1 +// EXPECT_EQ(100, tl.get()); // In thread 1, tl has the original value. +// tl.set(200); +// EXPECT_EQ(200, tl.get()); +// +// The template type argument T must have a public copy constructor. +// In addition, the default ThreadLocal constructor requires T to have +// a public default constructor. +// +// An object managed for a thread by a ThreadLocal instance is deleted +// when the thread exits. Or, if the ThreadLocal instance dies in +// that thread, when the ThreadLocal dies. It's the user's +// responsibility to ensure that all other threads using a ThreadLocal +// have exited when it dies, or the per-thread objects for those +// threads will not be deleted. +// +// Google Test only uses global ThreadLocal objects. That means they +// will die after main() has returned. 
Therefore, no per-thread +// object managed by Google Test will be leaked as long as all threads +// using Google Test have exited when main() returns. +template +class ThreadLocal { + public: + ThreadLocal() : key_(CreateKey()), + default_() {} + explicit ThreadLocal(const T& value) : key_(CreateKey()), + default_(value) {} + + ~ThreadLocal() { + // Destroys the managed object for the current thread, if any. + DeleteThreadLocalValue(pthread_getspecific(key_)); + + // Releases resources associated with the key. This will *not* + // delete managed objects for other threads. + GTEST_CHECK_POSIX_SUCCESS_(pthread_key_delete(key_)); + } + + T* pointer() { return GetOrCreateValue(); } + const T* pointer() const { return GetOrCreateValue(); } + const T& get() const { return *pointer(); } + void set(const T& value) { *pointer() = value; } + + private: + // Holds a value of type T. + class ValueHolder : public ThreadLocalValueHolderBase { + public: + explicit ValueHolder(const T& value) : value_(value) {} + + T* pointer() { return &value_; } + + private: + T value_; + GTEST_DISALLOW_COPY_AND_ASSIGN_(ValueHolder); + }; + + static pthread_key_t CreateKey() { + pthread_key_t key; + // When a thread exits, DeleteThreadLocalValue() will be called on + // the object managed for that thread. + GTEST_CHECK_POSIX_SUCCESS_( + pthread_key_create(&key, &DeleteThreadLocalValue)); + return key; + } + + T* GetOrCreateValue() const { + ThreadLocalValueHolderBase* const holder = + static_cast(pthread_getspecific(key_)); + if (holder != NULL) { + return CheckedDowncastToActualType(holder)->pointer(); + } + + ValueHolder* const new_holder = new ValueHolder(default_); + ThreadLocalValueHolderBase* const holder_base = new_holder; + GTEST_CHECK_POSIX_SUCCESS_(pthread_setspecific(key_, holder_base)); + return new_holder->pointer(); + } + + // A key pthreads uses for looking up per-thread values. + const pthread_key_t key_; + const T default_; // The default value for each thread. 
+ + GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadLocal); +}; + +# define GTEST_IS_THREADSAFE 1 + +#else // GTEST_HAS_PTHREAD + +// A dummy implementation of synchronization primitives (mutex, lock, +// and thread-local variable). Necessary for compiling Google Test where +// mutex is not supported - using Google Test in multiple threads is not +// supported on such platforms. + +class Mutex { + public: + Mutex() {} + void AssertHeld() const {} +}; + +# define GTEST_DECLARE_STATIC_MUTEX_(mutex) \ + extern ::testing::internal::Mutex mutex + +# define GTEST_DEFINE_STATIC_MUTEX_(mutex) ::testing::internal::Mutex mutex + +class GTestMutexLock { + public: + explicit GTestMutexLock(Mutex*) {} // NOLINT +}; + +typedef GTestMutexLock MutexLock; + +template <typename T> +class ThreadLocal { + public: + ThreadLocal() : value_() {} + explicit ThreadLocal(const T& value) : value_(value) {} + T* pointer() { return &value_; } + const T* pointer() const { return &value_; } + const T& get() const { return value_; } + void set(const T& value) { value_ = value; } + private: + T value_; +}; + +// The above synchronization primitives have dummy implementations. +// Therefore Google Test is not thread-safe. +# define GTEST_IS_THREADSAFE 0 + +#endif // GTEST_HAS_PTHREAD + +// Returns the number of threads running in the process, or 0 to indicate that +// we cannot detect it. +GTEST_API_ size_t GetThreadCount(); + +// Passing non-POD classes through ellipsis (...) crashes the ARM +// compiler and generates a warning in Sun Studio. The Nokia Symbian +// and the IBM XL C/C++ compiler try to instantiate a copy constructor +// for objects passed through ellipsis (...), failing for uncopyable +// objects. We define this to ensure that only POD is passed through +// ellipsis on these systems. +#if defined(__SYMBIAN32__) || defined(__IBMCPP__) || defined(__SUNPRO_CC) +// We lose support for NULL detection where the compiler doesn't like +// passing non-POD classes through ellipsis (...). 
+# define GTEST_ELLIPSIS_NEEDS_POD_ 1 +#else +# define GTEST_CAN_COMPARE_NULL 1 +#endif + +// The Nokia Symbian and IBM XL C/C++ compilers cannot decide between +// const T& and const T* in a function template. These compilers +// _can_ decide between class template specializations for T and T*, +// so a tr1::type_traits-like is_pointer works. +#if defined(__SYMBIAN32__) || defined(__IBMCPP__) +# define GTEST_NEEDS_IS_POINTER_ 1 +#endif + +template <bool bool_value> +struct bool_constant { + typedef bool_constant<bool_value> type; + static const bool value = bool_value; +}; +template <bool bool_value> const bool bool_constant<bool_value>::value; + +typedef bool_constant<false> false_type; +typedef bool_constant<true> true_type; + +template <typename T> +struct is_pointer : public false_type {}; + +template <typename T> +struct is_pointer<T*> : public true_type {}; + +template <typename Iterator> +struct IteratorTraits { + typedef typename Iterator::value_type value_type; +}; + +template <typename T> +struct IteratorTraits<T*> { + typedef T value_type; +}; + +template <typename T> +struct IteratorTraits<const T*> { + typedef T value_type; +}; + +#if GTEST_OS_WINDOWS +# define GTEST_PATH_SEP_ "\\" +# define GTEST_HAS_ALT_PATH_SEP_ 1 +// The biggest signed integer type the compiler supports. +typedef __int64 BiggestInt; +#else +# define GTEST_PATH_SEP_ "/" +# define GTEST_HAS_ALT_PATH_SEP_ 0 +typedef long long BiggestInt; // NOLINT +#endif // GTEST_OS_WINDOWS + +// Utilities for char. + +// isspace(int ch) and friends accept an unsigned char or EOF. char +// may be signed, depending on the compiler (or compiler flags). +// Therefore we need to cast a char to unsigned char before calling +// isspace(), etc. 
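// Illustrative sketch only, not part of Google Test: the comment above can be
// made concrete with a small standalone pair of helpers. The names
// IsSpaceSafe and ToUpperSafe are hypothetical and exist only for this
// example; they mirror the cast-to-unsigned-char pattern, assuming a hosted
// C++ implementation that provides <cctype>.

```cpp
#include <cctype>

// Hypothetical helper: classifies a plain char as whitespace.
// Converting to unsigned char first keeps the argument inside the range
// that std::isspace() accepts, even on platforms where char is signed and
// a byte such as 0xE9 would otherwise become a negative int.
inline bool IsSpaceSafe(char ch) {
  return std::isspace(static_cast<unsigned char>(ch)) != 0;
}

// Hypothetical helper: upper-cases a plain char using the same cast,
// then narrows the int result of std::toupper() back to char.
inline char ToUpperSafe(char ch) {
  return static_cast<char>(std::toupper(static_cast<unsigned char>(ch)));
}
```

// Calling std::isspace(ch) directly with a negative char value other than
// EOF is undefined behavior, which is why every wrapper casts first.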
+ +inline bool IsAlpha(char ch) { + return isalpha(static_cast<unsigned char>(ch)) != 0; +} +inline bool IsAlNum(char ch) { + return isalnum(static_cast<unsigned char>(ch)) != 0; +} +inline bool IsDigit(char ch) { + return isdigit(static_cast<unsigned char>(ch)) != 0; +} +inline bool IsLower(char ch) { + return islower(static_cast<unsigned char>(ch)) != 0; +} +inline bool IsSpace(char ch) { + return isspace(static_cast<unsigned char>(ch)) != 0; +} +inline bool IsUpper(char ch) { + return isupper(static_cast<unsigned char>(ch)) != 0; +} +inline bool IsXDigit(char ch) { + return isxdigit(static_cast<unsigned char>(ch)) != 0; +} + +inline char ToLower(char ch) { + return static_cast<char>(tolower(static_cast<unsigned char>(ch))); +} +inline char ToUpper(char ch) { + return static_cast<char>(toupper(static_cast<unsigned char>(ch))); +} + +// The testing::internal::posix namespace holds wrappers for common +// POSIX functions. These wrappers hide the differences between +// Windows/MSVC and POSIX systems. Since some compilers define these +// standard functions as macros, the wrapper cannot have the same name +// as the wrapped function. + +namespace posix { + +// Functions with a different name on Windows. + +#if GTEST_OS_WINDOWS + +typedef struct _stat StatStruct; + +# ifdef __BORLANDC__ +inline int IsATTY(int fd) { return isatty(fd); } +inline int StrCaseCmp(const char* s1, const char* s2) { + return stricmp(s1, s2); +} +inline char* StrDup(const char* src) { return strdup(src); } +# else // !__BORLANDC__ +# if GTEST_OS_WINDOWS_MOBILE +inline int IsATTY(int /* fd */) { return 0; } +# else +inline int IsATTY(int fd) { return _isatty(fd); } +# endif // GTEST_OS_WINDOWS_MOBILE +inline int StrCaseCmp(const char* s1, const char* s2) { + return _stricmp(s1, s2); +} +inline char* StrDup(const char* src) { return _strdup(src); } +# endif // __BORLANDC__ + +# if GTEST_OS_WINDOWS_MOBILE +inline int FileNo(FILE* file) { return reinterpret_cast<int>(_fileno(file)); } +// Stat(), RmDir(), and IsDir() are not needed on Windows CE at this +// time and thus not defined there. 
+# else +inline int FileNo(FILE* file) { return _fileno(file); } +inline int Stat(const char* path, StatStruct* buf) { return _stat(path, buf); } +inline int RmDir(const char* dir) { return _rmdir(dir); } +inline bool IsDir(const StatStruct& st) { + return (_S_IFDIR & st.st_mode) != 0; +} +# endif // GTEST_OS_WINDOWS_MOBILE + +#else + +typedef struct stat StatStruct; + +inline int FileNo(FILE* file) { return fileno(file); } +inline int IsATTY(int fd) { return isatty(fd); } +inline int Stat(const char* path, StatStruct* buf) { return stat(path, buf); } +inline int StrCaseCmp(const char* s1, const char* s2) { + return strcasecmp(s1, s2); +} +inline char* StrDup(const char* src) { return strdup(src); } +inline int RmDir(const char* dir) { return rmdir(dir); } +inline bool IsDir(const StatStruct& st) { return S_ISDIR(st.st_mode); } + +#endif // GTEST_OS_WINDOWS + +// Functions deprecated by MSVC 8.0. + +#ifdef _MSC_VER +// Temporarily disable warning 4996 (deprecated function). +# pragma warning(push) +# pragma warning(disable:4996) +#endif + +inline const char* StrNCpy(char* dest, const char* src, size_t n) { + return strncpy(dest, src, n); +} + +// ChDir(), FReopen(), FDOpen(), Read(), Write(), Close(), and +// StrError() aren't needed on Windows CE at this time and thus not +// defined there. 
+ +#if !GTEST_OS_WINDOWS_MOBILE +inline int ChDir(const char* dir) { return chdir(dir); } +#endif +inline FILE* FOpen(const char* path, const char* mode) { + return fopen(path, mode); +} +#if !GTEST_OS_WINDOWS_MOBILE +inline FILE *FReopen(const char* path, const char* mode, FILE* stream) { + return freopen(path, mode, stream); +} +inline FILE* FDOpen(int fd, const char* mode) { return fdopen(fd, mode); } +#endif +inline int FClose(FILE* fp) { return fclose(fp); } +#if !GTEST_OS_WINDOWS_MOBILE +inline int Read(int fd, void* buf, unsigned int count) { + return static_cast<int>(read(fd, buf, count)); +} +inline int Write(int fd, const void* buf, unsigned int count) { + return static_cast<int>(write(fd, buf, count)); +} +inline int Close(int fd) { return close(fd); } +inline const char* StrError(int errnum) { return strerror(errnum); } +#endif +inline const char* GetEnv(const char* name) { +#if GTEST_OS_WINDOWS_MOBILE + // We are on Windows CE, which has no environment variables. + return NULL; +#elif defined(__BORLANDC__) || defined(__SunOS_5_8) || defined(__SunOS_5_9) + // Environment variables which we programmatically clear will be set to the + // empty string rather than unset (NULL). Handle that case. + const char* const env = getenv(name); + return (env != NULL && env[0] != '\0') ? env : NULL; +#else + return getenv(name); +#endif +} + +#ifdef _MSC_VER +# pragma warning(pop) // Restores the warning state. +#endif + +#if GTEST_OS_WINDOWS_MOBILE +// Windows CE has no C library. The abort() function is used in +// several places in Google Test. This implementation provides a reasonable +// imitation of standard behaviour. +void Abort(); +#else +inline void Abort() { abort(); } +#endif // GTEST_OS_WINDOWS_MOBILE + +} // namespace posix + +// The maximum number a BiggestInt can represent. This definition +// works whether BiggestInt is represented in one's complement or +// two's complement. 
+// +// We cannot rely on numeric_limits in STL, as __int64 and long long +// are not part of standard C++ and numeric_limits doesn't need to be +// defined for them. +const BiggestInt kMaxBiggestInt = + ~(static_cast<BiggestInt>(1) << (8*sizeof(BiggestInt) - 1)); + +// This template class serves as a compile-time function from size to +// type. It maps a size in bytes to a primitive type with that +// size. e.g. +// +// TypeWithSize<4>::UInt +// +// is typedef-ed to be unsigned int (unsigned integer made up of 4 +// bytes). +// +// Such functionality should belong to STL, but I cannot find it +// there. +// +// Google Test uses this class in the implementation of floating-point +// comparison. +// +// For now it only handles UInt (unsigned int) as that's all Google Test +// needs. Other types can be easily added in the future if need +// arises. +template <size_t size> +class TypeWithSize { + public: + // This prevents the user from using TypeWithSize<N> with incorrect + // values of N. + typedef void UInt; +}; + +// The specialization for size 4. +template <> +class TypeWithSize<4> { + public: + // unsigned int has size 4 in both gcc and MSVC. + // + // As base/basictypes.h doesn't compile on Windows, we cannot use + // uint32, uint64, etc. here. + typedef int Int; + typedef unsigned int UInt; +}; + +// The specialization for size 8. +template <> +class TypeWithSize<8> { + public: + +#if GTEST_OS_WINDOWS + typedef __int64 Int; + typedef unsigned __int64 UInt; +#else + typedef long long Int; // NOLINT + typedef unsigned long long UInt; // NOLINT +#endif // GTEST_OS_WINDOWS +}; + +// Integer types of known sizes. +typedef TypeWithSize<4>::Int Int32; +typedef TypeWithSize<4>::UInt UInt32; +typedef TypeWithSize<8>::Int Int64; +typedef TypeWithSize<8>::UInt UInt64; +typedef TypeWithSize<8>::Int TimeInMillis; // Represents time in milliseconds. + +// Utilities for command line flags and environment variables. + +// Macro for referencing flags. 
+#define GTEST_FLAG(name) FLAGS_gtest_##name + +// Macros for declaring flags. +#define GTEST_DECLARE_bool_(name) GTEST_API_ extern bool GTEST_FLAG(name) +#define GTEST_DECLARE_int32_(name) \ + GTEST_API_ extern ::testing::internal::Int32 GTEST_FLAG(name) +#define GTEST_DECLARE_string_(name) \ + GTEST_API_ extern ::testing::internal::String GTEST_FLAG(name) + +// Macros for defining flags. +#define GTEST_DEFINE_bool_(name, default_val, doc) \ + GTEST_API_ bool GTEST_FLAG(name) = (default_val) +#define GTEST_DEFINE_int32_(name, default_val, doc) \ + GTEST_API_ ::testing::internal::Int32 GTEST_FLAG(name) = (default_val) +#define GTEST_DEFINE_string_(name, default_val, doc) \ + GTEST_API_ ::testing::internal::String GTEST_FLAG(name) = (default_val) + +// Parses 'str' for a 32-bit signed integer. If successful, writes the result +// to *value and returns true; otherwise leaves *value unchanged and returns +// false. +// TODO(chandlerc): Find a better way to refactor flag and environment parsing +// out of both gtest-port.cc and gtest.cc to avoid exporting this utility +// function. +bool ParseInt32(const Message& src_text, const char* str, Int32* value); + +// Parses a bool/Int32/string from the environment variable +// corresponding to the given Google Test flag. +bool BoolFromGTestEnv(const char* flag, bool default_val); +GTEST_API_ Int32 Int32FromGTestEnv(const char* flag, Int32 default_val); +const char* StringFromGTestEnv(const char* flag, const char* default_val); + +} // namespace internal +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_H_ + +#if GTEST_OS_LINUX +# include +# include +# include +# include +#endif // GTEST_OS_LINUX + +#include +#include +#include +#include +#include + +// Copyright 2005, Google Inc. +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Authors: wan@google.com (Zhanyong Wan), eefacm@gmail.com (Sean Mcafee) +// +// The Google C++ Testing Framework (Google Test) +// +// This header file declares the String class and functions used internally by +// Google Test. They are subject to change without notice. They should not be used +// by code external to Google Test. +// +// This header file is #included by . +// It should not be #included by other files. 
+ +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_STRING_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_STRING_H_ + +#ifdef __BORLANDC__ +// string.h is not guaranteed to provide strcpy on C++ Builder. +# include +#endif + +#include + +#include + +namespace testing { +namespace internal { + +// String - a UTF-8 string class. +// +// For historic reasons, we don't use std::string. +// +// TODO(wan@google.com): replace this class with std::string or +// implement it in terms of the latter. +// +// Note that String can represent both NULL and the empty string, +// while std::string cannot represent NULL. +// +// NULL and the empty string are considered different. NULL is less +// than anything (including the empty string) except itself. +// +// This class only provides minimum functionality necessary for +// implementing Google Test. We do not intend to implement a full-fledged +// string class here. +// +// Since the purpose of this class is to provide a substitute for +// std::string on platforms where it cannot be used, we define a copy +// constructor and assignment operators such that we don't need +// conditional compilation in a lot of places. +// +// In order to make the representation efficient, the d'tor of String +// is not virtual. Therefore DO NOT INHERIT FROM String. +class GTEST_API_ String { + public: + // Static utility methods + + // Returns the input enclosed in double quotes if it's not NULL; + // otherwise returns "(null)". For example, "\"Hello\"" is returned + // for input "Hello". + // + // This is useful for printing a C string in the syntax of a literal. + // + // Known issue: escape sequences are not handled yet. + static String ShowCStringQuoted(const char* c_str); + + // Clones a 0-terminated C string, allocating memory using new. The + // caller is responsible for deleting the return value using + // delete[]. Returns the cloned string, or NULL if the input is + // NULL. 
+ // + // This is different from strdup() in string.h, which allocates + // memory using malloc(). + static const char* CloneCString(const char* c_str); + +#if GTEST_OS_WINDOWS_MOBILE + // Windows CE does not have the 'ANSI' versions of Win32 APIs. To be + // able to pass strings to Win32 APIs on CE we need to convert them + // to 'Unicode', UTF-16. + + // Creates a UTF-16 wide string from the given ANSI string, allocating + // memory using new. The caller is responsible for deleting the return + // value using delete[]. Returns the wide string, or NULL if the + // input is NULL. + // + // The wide string is created using the ANSI codepage (CP_ACP) to + // match the behaviour of the ANSI versions of Win32 calls and the + // C runtime. + static LPCWSTR AnsiToUtf16(const char* c_str); + + // Creates an ANSI string from the given wide string, allocating + // memory using new. The caller is responsible for deleting the return + // value using delete[]. Returns the ANSI string, or NULL if the + // input is NULL. + // + // The returned string is created using the ANSI codepage (CP_ACP) to + // match the behaviour of the ANSI versions of Win32 calls and the + // C runtime. + static const char* Utf16ToAnsi(LPCWSTR utf16_str); +#endif + + // Compares two C strings. Returns true iff they have the same content. + // + // Unlike strcmp(), this function can handle NULL argument(s). A + // NULL C string is considered different to any non-NULL C string, + // including the empty string. + static bool CStringEquals(const char* lhs, const char* rhs); + + // Converts a wide C string to a String using the UTF-8 encoding. + // NULL will be converted to "(null)". If an error occurred during + // the conversion, "(failed to convert from wide string)" is + // returned. + static String ShowWideCString(const wchar_t* wide_c_str); + + // Similar to ShowWideCString(), except that this function encloses + // the converted string in double quotes. 
+ static String ShowWideCStringQuoted(const wchar_t* wide_c_str); + + // Compares two wide C strings. Returns true iff they have the same + // content. + // + // Unlike wcscmp(), this function can handle NULL argument(s). A + // NULL C string is considered different to any non-NULL C string, + // including the empty string. + static bool WideCStringEquals(const wchar_t* lhs, const wchar_t* rhs); + + // Compares two C strings, ignoring case. Returns true iff they + // have the same content. + // + // Unlike strcasecmp(), this function can handle NULL argument(s). + // A NULL C string is considered different to any non-NULL C string, + // including the empty string. + static bool CaseInsensitiveCStringEquals(const char* lhs, + const char* rhs); + + // Compares two wide C strings, ignoring case. Returns true iff they + // have the same content. + // + // Unlike wcscasecmp(), this function can handle NULL argument(s). + // A NULL C string is considered different to any non-NULL wide C string, + // including the empty string. + // NB: The implementations on different platforms slightly differ. + // On Windows, this method uses _wcsicmp which compares according to LC_CTYPE + // environment variable. On GNU platform this method uses wcscasecmp + // which compares according to LC_CTYPE category of the current locale. + // On MacOS X, it uses towlower, which also uses LC_CTYPE category of the + // current locale. + static bool CaseInsensitiveWideCStringEquals(const wchar_t* lhs, + const wchar_t* rhs); + + // Formats a list of arguments to a String, using the same format + // spec string as for printf. + // + // We do not use the StringPrintf class as it is not universally + // available. + // + // The result is limited to 4096 characters (including the trailing + // 0). If 4096 characters are not enough to format the input, + // "" is returned. + static String Format(const char* format, ...); + + // C'tors + + // The default c'tor constructs a NULL string. 
+ String() : c_str_(NULL), length_(0) {} + + // Constructs a String by cloning a 0-terminated C string. + String(const char* a_c_str) { // NOLINT + if (a_c_str == NULL) { + c_str_ = NULL; + length_ = 0; + } else { + ConstructNonNull(a_c_str, strlen(a_c_str)); + } + } + + // Constructs a String by copying a given number of chars from a + // buffer. E.g. String("hello", 3) creates the string "hel", + // String("a\0bcd", 4) creates "a\0bc", String(NULL, 0) creates "", + // and String(NULL, 1) results in access violation. + String(const char* buffer, size_t a_length) { + ConstructNonNull(buffer, a_length); + } + + // The copy c'tor creates a new copy of the string. The two + // String objects do not share content. + String(const String& str) : c_str_(NULL), length_(0) { *this = str; } + + // D'tor. String is intended to be a final class, so the d'tor + // doesn't need to be virtual. + ~String() { delete[] c_str_; } + + // Allows a String to be implicitly converted to an ::std::string or + // ::string, and vice versa. Converting a String containing a NULL + // pointer to ::std::string or ::string is undefined behavior. + // Converting a ::std::string or ::string containing an embedded NUL + // character to a String will result in the prefix up to the first + // NUL character. + String(const ::std::string& str) { + ConstructNonNull(str.c_str(), str.length()); + } + + operator ::std::string() const { return ::std::string(c_str(), length()); } + +#if GTEST_HAS_GLOBAL_STRING + String(const ::string& str) { + ConstructNonNull(str.c_str(), str.length()); + } + + operator ::string() const { return ::string(c_str(), length()); } +#endif // GTEST_HAS_GLOBAL_STRING + + // Returns true iff this is an empty string (i.e. ""). + bool empty() const { return (c_str() != NULL) && (length() == 0); } + + // Compares this with another String. + // Returns < 0 if this is less than rhs, 0 if this is equal to rhs, or > 0 + // if this is greater than rhs. 
+ int Compare(const String& rhs) const; + + // Returns true iff this String equals the given C string. A NULL + // string and a non-NULL string are considered not equal. + bool operator==(const char* a_c_str) const { return Compare(a_c_str) == 0; } + + // Returns true iff this String is less than the given String. A + // NULL string is considered less than "". + bool operator<(const String& rhs) const { return Compare(rhs) < 0; } + + // Returns true iff this String doesn't equal the given C string. A NULL + // string and a non-NULL string are considered not equal. + bool operator!=(const char* a_c_str) const { return !(*this == a_c_str); } + + // Returns true iff this String ends with the given suffix. *Any* + // String is considered to end with a NULL or empty suffix. + bool EndsWith(const char* suffix) const; + + // Returns true iff this String ends with the given suffix, not considering + // case. Any String is considered to end with a NULL or empty suffix. + bool EndsWithCaseInsensitive(const char* suffix) const; + + // Returns the length of the encapsulated string, or 0 if the + // string is NULL. + size_t length() const { return length_; } + + // Gets the 0-terminated C string this String object represents. + // The String object still owns the string. Therefore the caller + // should NOT delete the return value. + const char* c_str() const { return c_str_; } + + // Assigns a C string to this object. Self-assignment works. + const String& operator=(const char* a_c_str) { + return *this = String(a_c_str); + } + + // Assigns a String object to this object. Self-assignment works. + const String& operator=(const String& rhs) { + if (this != &rhs) { + delete[] c_str_; + if (rhs.c_str() == NULL) { + c_str_ = NULL; + length_ = 0; + } else { + ConstructNonNull(rhs.c_str(), rhs.length()); + } + } + + return *this; + } + + private: + // Constructs a non-NULL String from the given content. This + // function can only be called when c_str_ has not been allocated. 
+ // ConstructNonNull(NULL, 0) results in an empty string (""). + // ConstructNonNull(NULL, non_zero) is undefined behavior. + void ConstructNonNull(const char* buffer, size_t a_length) { + char* const str = new char[a_length + 1]; + memcpy(str, buffer, a_length); + str[a_length] = '\0'; + c_str_ = str; + length_ = a_length; + } + + const char* c_str_; + size_t length_; +}; // class String + +// Streams a String to an ostream. Each '\0' character in the String +// is replaced with "\\0". +inline ::std::ostream& operator<<(::std::ostream& os, const String& str) { + if (str.c_str() == NULL) { + os << "(null)"; + } else { + const char* const c_str = str.c_str(); + for (size_t i = 0; i != str.length(); i++) { + if (c_str[i] == '\0') { + os << "\\0"; + } else { + os << c_str[i]; + } + } + } + return os; +} + +// Gets the content of the stringstream's buffer as a String. Each '\0' +// character in the buffer is replaced with "\\0". +GTEST_API_ String StringStreamToString(::std::stringstream* stream); + +// Converts a streamable value to a String. A NULL pointer is +// converted to "(null)". When the input value is a ::string, +// ::std::string, ::wstring, or ::std::wstring object, each NUL +// character in it is replaced with "\\0". + +// Declared here but defined in gtest.h, so that it has access +// to the definition of the Message class, required by the ARM +// compiler. +template <typename T> +String StreamableToString(const T& streamable); + +} // namespace internal +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_STRING_H_ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: keith.ray@gmail.com (Keith Ray) +// +// Google Test filepath utilities +// +// This header file declares classes and functions used internally by +// Google Test. They are subject to change without notice. +// +// This file is #included in . +// Do not include this header file separately! + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_FILEPATH_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_FILEPATH_H_ + + +namespace testing { +namespace internal { + +// FilePath - a class for file and directory pathname manipulation which +// handles platform-specific conventions (like the pathname separator). +// Used for helper functions for naming files in a directory for xml output. 
+// Except for Set methods, all methods are const or static, which provides an +// "immutable value object" -- useful for peace of mind. +// A FilePath with a value ending in a path separator ("like/this/") represents +// a directory, otherwise it is assumed to represent a file. In either case, +// it may or may not represent an actual file or directory in the file system. +// Names are NOT checked for syntax correctness -- no checking for illegal +// characters, malformed paths, etc. + +class GTEST_API_ FilePath { + public: + FilePath() : pathname_("") { } + FilePath(const FilePath& rhs) : pathname_(rhs.pathname_) { } + + explicit FilePath(const char* pathname) : pathname_(pathname) { + Normalize(); + } + + explicit FilePath(const String& pathname) : pathname_(pathname) { + Normalize(); + } + + FilePath& operator=(const FilePath& rhs) { + Set(rhs); + return *this; + } + + void Set(const FilePath& rhs) { + pathname_ = rhs.pathname_; + } + + String ToString() const { return pathname_; } + const char* c_str() const { return pathname_.c_str(); } + + // Returns the current working directory, or "" if unsuccessful. + static FilePath GetCurrentDir(); + + // Given directory = "dir", base_name = "test", number = 0, + // extension = "xml", returns "dir/test.xml". If number is greater + // than zero (e.g., 12), returns "dir/test_12.xml". + // On Windows platform, uses \ as the separator rather than /. + static FilePath MakeFileName(const FilePath& directory, + const FilePath& base_name, + int number, + const char* extension); + + // Given directory = "dir", relative_path = "test.xml", + // returns "dir/test.xml". + // On Windows, uses \ as the separator rather than /. + static FilePath ConcatPaths(const FilePath& directory, + const FilePath& relative_path); + + // Returns a pathname for a file that does not currently exist. The pathname + // will be directory/base_name.extension or + // directory/base_name_.extension if directory/base_name.extension + // already exists. 
The number will be incremented until a pathname is found + // that does not already exist. + // Examples: 'dir/foo_test.xml' or 'dir/foo_test_1.xml'. + // There could be a race condition if two or more processes are calling this + // function at the same time -- they could both pick the same filename. + static FilePath GenerateUniqueFileName(const FilePath& directory, + const FilePath& base_name, + const char* extension); + + // Returns true iff the path is NULL or "". + bool IsEmpty() const { return c_str() == NULL || *c_str() == '\0'; } + + // If input name has a trailing separator character, removes it and returns + // the name, otherwise return the name string unmodified. + // On Windows platform, uses \ as the separator, other platforms use /. + FilePath RemoveTrailingPathSeparator() const; + + // Returns a copy of the FilePath with the directory part removed. + // Example: FilePath("path/to/file").RemoveDirectoryName() returns + // FilePath("file"). If there is no directory part ("just_a_file"), it returns + // the FilePath unmodified. If there is no file part ("just_a_dir/") it + // returns an empty FilePath (""). + // On Windows platform, '\' is the path separator, otherwise it is '/'. + FilePath RemoveDirectoryName() const; + + // RemoveFileName returns the directory path with the filename removed. + // Example: FilePath("path/to/file").RemoveFileName() returns "path/to/". + // If the FilePath is "a_file" or "/a_file", RemoveFileName returns + // FilePath("./") or, on Windows, FilePath(".\\"). If the filepath does + // not have a file, like "just/a/dir/", it returns the FilePath unmodified. + // On Windows platform, '\' is the path separator, otherwise it is '/'. + FilePath RemoveFileName() const; + + // Returns a copy of the FilePath with the case-insensitive extension removed. + // Example: FilePath("dir/file.exe").RemoveExtension("EXE") returns + // FilePath("dir/file"). 
If a case-insensitive extension is not + // found, returns a copy of the original FilePath. + FilePath RemoveExtension(const char* extension) const; + + // Creates directories so that path exists. Returns true if successful or if + // the directories already exist; returns false if unable to create + // directories for any reason. Will also return false if the FilePath does + // not represent a directory (that is, it doesn't end with a path separator). + bool CreateDirectoriesRecursively() const; + + // Create the directory so that path exists. Returns true if successful or + // if the directory already exists; returns false if unable to create the + // directory for any reason, including if the parent directory does not + // exist. Not named "CreateDirectory" because that's a macro on Windows. + bool CreateFolder() const; + + // Returns true if FilePath describes something in the file-system, + // either a file, directory, or whatever, and that something exists. + bool FileOrDirectoryExists() const; + + // Returns true if pathname describes a directory in the file-system + // that exists. + bool DirectoryExists() const; + + // Returns true if FilePath ends with a path separator, which indicates that + // it is intended to represent a directory. Returns false otherwise. + // This does NOT check that a directory (or file) actually exists. + bool IsDirectory() const; + + // Returns true if pathname describes a root directory. (Windows has one + // root directory per disk drive.) + bool IsRootDirectory() const; + + // Returns true if pathname describes an absolute path. + bool IsAbsolutePath() const; + + private: + // Replaces multiple consecutive separators with a single separator. + // For example, "bar///foo" becomes "bar/foo". Does not eliminate other + // redundancies that might be in a pathname involving "." or "..". 
+ // + // A pathname with multiple consecutive separators may occur either through + // user error or as a result of some scripts or APIs that generate a pathname + // with a trailing separator. On other platforms the same API or script + // may NOT generate a pathname with a trailing "/". Then elsewhere that + // pathname may have another "/" and pathname components added to it, + // without checking for the separator already being there. + // The script language and operating system may allow paths like "foo//bar" + // but some of the functions in FilePath will not handle that correctly. In + // particular, RemoveTrailingPathSeparator() only removes one separator, and + // it is called in CreateDirectoriesRecursively() assuming that it will change + // a pathname from directory syntax (trailing separator) to filename syntax. + // + // On Windows this method also replaces the alternate path separator '/' with + // the primary path separator '\\', so that for example "bar\\/\\foo" becomes + // "bar\\foo". + + void Normalize(); + + // Returns a pointer to the last occurrence of a valid path separator in + // the FilePath. On Windows, for example, both '/' and '\' are valid path + // separators. Returns NULL if no path separator was found. + const char* FindLastPathSeparator() const; + + String pathname_; +}; // class FilePath + +} // namespace internal +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_FILEPATH_H_ +// This file was GENERATED by command: +// pump.py gtest-type-util.h.pump +// DO NOT EDIT BY HAND!!! + +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) + +// Type utilities needed for implementing typed and type-parameterized +// tests. This file is generated by a SCRIPT. DO NOT EDIT BY HAND! +// +// Currently we support at most 50 types in a list, and at most 50 +// type-parameterized tests in one type-parameterized test case. +// Please contact googletestframework@googlegroups.com if you need +// more. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TYPE_UTIL_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TYPE_UTIL_H_ + + +// #ifdef __GNUC__ is too general here. It is possible to use gcc without using +// libstdc++ (which is where cxxabi.h comes from). 
+# ifdef __GLIBCXX__ +# include <cxxabi.h> +# elif defined(__HP_aCC) +# include <acxx_demangle.h> +# endif // __GLIBCXX__ + +namespace testing { +namespace internal { + +// GetTypeName<T>() returns a human-readable name of type T. +// NB: This function is also used in Google Mock, so don't move it inside of +// the typed-test-only section below. +template <typename T> +String GetTypeName() { +# if GTEST_HAS_RTTI + + const char* const name = typeid(T).name(); +# if defined(__GLIBCXX__) || defined(__HP_aCC) + int status = 0; + // gcc's implementation of typeid(T).name() mangles the type name, + // so we have to demangle it. +# ifdef __GLIBCXX__ + using abi::__cxa_demangle; +# endif // __GLIBCXX__ + char* const readable_name = __cxa_demangle(name, 0, 0, &status); + const String name_str(status == 0 ? readable_name : name); + free(readable_name); + return name_str; +# else + return name; +# endif // __GLIBCXX__ || __HP_aCC + +# else + + return "<type>"; + +# endif // GTEST_HAS_RTTI +} + +#if GTEST_HAS_TYPED_TEST || GTEST_HAS_TYPED_TEST_P + +// AssertTypeEq<T1, T2>::type is defined iff T1 and T2 are the same +// type. This can be used as a compile-time assertion to ensure that +// two types are equal. + +template <typename T1, typename T2> +struct AssertTypeEq; + +template <typename T> +struct AssertTypeEq<T, T> { + typedef bool type; +}; + +// A unique type used as the default value for the arguments of class +// template Types. This allows us to simulate variadic templates +// (e.g. Types<int>, Type<int, double>, and etc), which C++ doesn't +// support directly. +struct None {}; + +// The following family of struct and struct templates are used to +// represent type lists. In particular, TypesN<T1, T2, ..., TN> +// represents a type list with N types (T1, T2, ..., and TN) in it. +// Except for Types0, every struct in the family has two member types: +// Head for the first type in the list, and Tail for the rest of the +// list. + +// The empty type list. +struct Types0 {}; + +// Type lists of length 1, 2, 3, and so on. 
+ +template +struct Types1 { + typedef T1 Head; + typedef Types0 Tail; +}; +template +struct Types2 { + typedef T1 Head; + typedef Types1 Tail; +}; + +template +struct Types3 { + typedef T1 Head; + typedef Types2 Tail; +}; + +template +struct Types4 { + typedef T1 Head; + typedef Types3 Tail; +}; + +template +struct Types5 { + typedef T1 Head; + typedef Types4 Tail; +}; + +template +struct Types6 { + typedef T1 Head; + typedef Types5 Tail; +}; + +template +struct Types7 { + typedef T1 Head; + typedef Types6 Tail; +}; + +template +struct Types8 { + typedef T1 Head; + typedef Types7 Tail; +}; + +template +struct Types9 { + typedef T1 Head; + typedef Types8 Tail; +}; + +template +struct Types10 { + typedef T1 Head; + typedef Types9 Tail; +}; + +template +struct Types11 { + typedef T1 Head; + typedef Types10 Tail; +}; + +template +struct Types12 { + typedef T1 Head; + typedef Types11 Tail; +}; + +template +struct Types13 { + typedef T1 Head; + typedef Types12 Tail; +}; + +template +struct Types14 { + typedef T1 Head; + typedef Types13 Tail; +}; + +template +struct Types15 { + typedef T1 Head; + typedef Types14 Tail; +}; + +template +struct Types16 { + typedef T1 Head; + typedef Types15 Tail; +}; + +template +struct Types17 { + typedef T1 Head; + typedef Types16 Tail; +}; + +template +struct Types18 { + typedef T1 Head; + typedef Types17 Tail; +}; + +template +struct Types19 { + typedef T1 Head; + typedef Types18 Tail; +}; + +template +struct Types20 { + typedef T1 Head; + typedef Types19 Tail; +}; + +template +struct Types21 { + typedef T1 Head; + typedef Types20 Tail; +}; + +template +struct Types22 { + typedef T1 Head; + typedef Types21 Tail; +}; + +template +struct Types23 { + typedef T1 Head; + typedef Types22 Tail; +}; + +template +struct Types24 { + typedef T1 Head; + typedef Types23 Tail; +}; + +template +struct Types25 { + typedef T1 Head; + typedef Types24 Tail; +}; + +template +struct Types26 { + typedef T1 Head; + typedef Types25 Tail; +}; + +template 
+struct Types27 { + typedef T1 Head; + typedef Types26 Tail; +}; + +template +struct Types28 { + typedef T1 Head; + typedef Types27 Tail; +}; + +template +struct Types29 { + typedef T1 Head; + typedef Types28 Tail; +}; + +template +struct Types30 { + typedef T1 Head; + typedef Types29 Tail; +}; + +template +struct Types31 { + typedef T1 Head; + typedef Types30 Tail; +}; + +template +struct Types32 { + typedef T1 Head; + typedef Types31 Tail; +}; + +template +struct Types33 { + typedef T1 Head; + typedef Types32 Tail; +}; + +template +struct Types34 { + typedef T1 Head; + typedef Types33 Tail; +}; + +template +struct Types35 { + typedef T1 Head; + typedef Types34 Tail; +}; + +template +struct Types36 { + typedef T1 Head; + typedef Types35 Tail; +}; + +template +struct Types37 { + typedef T1 Head; + typedef Types36 Tail; +}; + +template +struct Types38 { + typedef T1 Head; + typedef Types37 Tail; +}; + +template +struct Types39 { + typedef T1 Head; + typedef Types38 Tail; +}; + +template +struct Types40 { + typedef T1 Head; + typedef Types39 Tail; +}; + +template +struct Types41 { + typedef T1 Head; + typedef Types40 Tail; +}; + +template +struct Types42 { + typedef T1 Head; + typedef Types41 Tail; +}; + +template +struct Types43 { + typedef T1 Head; + typedef Types42 Tail; +}; + +template +struct Types44 { + typedef T1 Head; + typedef Types43 Tail; +}; + +template +struct Types45 { + typedef T1 Head; + typedef Types44 Tail; +}; + +template +struct Types46 { + typedef T1 Head; + typedef Types45 Tail; +}; + +template +struct Types47 { + typedef T1 Head; + typedef Types46 Tail; +}; + +template +struct Types48 { + typedef T1 Head; + typedef Types47 Tail; +}; + +template +struct Types49 { + typedef T1 Head; + typedef Types48 Tail; +}; + +template +struct Types50 { + typedef T1 Head; + typedef Types49 Tail; +}; + + +} // namespace internal + +// We don't want to require the users to write TypesN<...> directly, +// as that would require them to count the length. 
Types<...> is much +// easier to write, but generates horrible messages when there is a +// compiler error, as gcc insists on printing out each template +// argument, even if it has the default value (this means Types +// will appear as Types in the compiler +// errors). +// +// Our solution is to combine the best part of the two approaches: a +// user would write Types, and Google Test will translate +// that to TypesN internally to make error messages +// readable. The translation is done by the 'type' member of the +// Types template. +template +struct Types { + typedef internal::Types50 type; +}; + +template <> +struct Types { + typedef internal::Types0 type; +}; +template +struct Types { + typedef internal::Types1 type; +}; +template +struct Types { + typedef internal::Types2 type; +}; +template +struct Types { + typedef internal::Types3 type; +}; +template +struct Types { + typedef internal::Types4 type; +}; +template +struct Types { + typedef internal::Types5 type; +}; +template +struct Types { + typedef internal::Types6 type; +}; +template +struct Types { + typedef internal::Types7 type; +}; +template +struct Types { + typedef internal::Types8 type; +}; +template +struct Types { + typedef internal::Types9 type; +}; +template +struct Types { + typedef internal::Types10 type; +}; +template +struct Types { + typedef internal::Types11 type; +}; +template +struct Types { + typedef internal::Types12 type; +}; +template +struct Types { + typedef internal::Types13 type; +}; +template +struct Types { + typedef internal::Types14 type; +}; +template +struct Types { + typedef internal::Types15 type; +}; +template +struct Types { + typedef internal::Types16 type; +}; +template +struct Types { + typedef internal::Types17 type; +}; +template +struct Types { + typedef internal::Types18 type; +}; +template +struct Types { + typedef internal::Types19 type; +}; +template +struct Types { + typedef internal::Types20 type; +}; +template +struct Types { + typedef 
internal::Types21 type; +}; +template +struct Types { + typedef internal::Types22 type; +}; +template +struct Types { + typedef internal::Types23 type; +}; +template +struct Types { + typedef internal::Types24 type; +}; +template +struct Types { + typedef internal::Types25 type; +}; +template +struct Types { + typedef internal::Types26 type; +}; +template +struct Types { + typedef internal::Types27 type; +}; +template +struct Types { + typedef internal::Types28 type; +}; +template +struct Types { + typedef internal::Types29 type; +}; +template +struct Types { + typedef internal::Types30 type; +}; +template +struct Types { + typedef internal::Types31 type; +}; +template +struct Types { + typedef internal::Types32 type; +}; +template +struct Types { + typedef internal::Types33 type; +}; +template +struct Types { + typedef internal::Types34 type; +}; +template +struct Types { + typedef internal::Types35 type; +}; +template +struct Types { + typedef internal::Types36 type; +}; +template +struct Types { + typedef internal::Types37 type; +}; +template +struct Types { + typedef internal::Types38 type; +}; +template +struct Types { + typedef internal::Types39 type; +}; +template +struct Types { + typedef internal::Types40 type; +}; +template +struct Types { + typedef internal::Types41 type; +}; +template +struct Types { + typedef internal::Types42 type; +}; +template +struct Types { + typedef internal::Types43 type; +}; +template +struct Types { + typedef internal::Types44 type; +}; +template +struct Types { + typedef internal::Types45 type; +}; +template +struct Types { + typedef internal::Types46 type; +}; +template +struct Types { + typedef internal::Types47 type; +}; +template +struct Types { + typedef internal::Types48 type; +}; +template +struct Types { + typedef internal::Types49 type; +}; + +namespace internal { + +# define GTEST_TEMPLATE_ template class + +// The template "selector" struct TemplateSel is used to +// represent Tmpl, which must be a class template 
with one type +// parameter, as a type. TemplateSel<Tmpl>::Bind<T>::type is defined +// as the type Tmpl<T>. This allows us to actually instantiate the +// template "selected" by TemplateSel<Tmpl>. +// +// This trick is necessary for simulating typedef for class templates, +// which C++ doesn't support directly. +template <GTEST_TEMPLATE_ Tmpl> +struct TemplateSel { + template <typename T> + struct Bind { + typedef Tmpl<T> type; + }; +}; + +# define GTEST_BIND_(TmplSel, T) \ + TmplSel::template Bind<T>::type + +// A unique struct template used as the default value for the +// arguments of class template Templates. This allows us to simulate +// variadic templates (e.g. Templates<int>, Templates<int, double>, +// and etc), which C++ doesn't support directly. +template <typename T> +struct NoneT {}; + +// The following family of struct and struct templates are used to +// represent template lists. In particular, TemplatesN<T1, T2, ..., +// TN> represents a list of N templates (T1, T2, ..., and TN). Except +// for Templates0, every struct in the family has two member types: +// Head for the selector of the first template in the list, and Tail +// for the rest of the list. + +// The empty template list. +struct Templates0 {}; + +// Template lists of length 1, 2, 3, and so on. 
+ +template +struct Templates1 { + typedef TemplateSel Head; + typedef Templates0 Tail; +}; +template +struct Templates2 { + typedef TemplateSel Head; + typedef Templates1 Tail; +}; + +template +struct Templates3 { + typedef TemplateSel Head; + typedef Templates2 Tail; +}; + +template +struct Templates4 { + typedef TemplateSel Head; + typedef Templates3 Tail; +}; + +template +struct Templates5 { + typedef TemplateSel Head; + typedef Templates4 Tail; +}; + +template +struct Templates6 { + typedef TemplateSel Head; + typedef Templates5 Tail; +}; + +template +struct Templates7 { + typedef TemplateSel Head; + typedef Templates6 Tail; +}; + +template +struct Templates8 { + typedef TemplateSel Head; + typedef Templates7 Tail; +}; + +template +struct Templates9 { + typedef TemplateSel Head; + typedef Templates8 Tail; +}; + +template +struct Templates10 { + typedef TemplateSel Head; + typedef Templates9 Tail; +}; + +template +struct Templates11 { + typedef TemplateSel Head; + typedef Templates10 Tail; +}; + +template +struct Templates12 { + typedef TemplateSel Head; + typedef Templates11 Tail; +}; + +template +struct Templates13 { + typedef TemplateSel Head; + typedef Templates12 Tail; +}; + +template +struct Templates14 { + typedef TemplateSel Head; + typedef Templates13 Tail; +}; + +template +struct Templates15 { + typedef TemplateSel Head; + typedef Templates14 Tail; +}; + +template +struct Templates16 { + typedef TemplateSel Head; + typedef Templates15 Tail; +}; + +template +struct Templates17 { + typedef TemplateSel Head; + typedef Templates16 Tail; +}; + +template +struct Templates18 { + typedef TemplateSel Head; + typedef Templates17 Tail; +}; + +template +struct Templates19 { + typedef TemplateSel Head; + typedef Templates18 Tail; +}; + +template +struct Templates20 { + typedef TemplateSel Head; + typedef Templates19 Tail; +}; + +template +struct Templates21 { + typedef TemplateSel Head; + typedef Templates20 Tail; +}; + +template +struct Templates22 { + typedef 
TemplateSel Head; + typedef Templates21 Tail; +}; + +template +struct Templates23 { + typedef TemplateSel Head; + typedef Templates22 Tail; +}; + +template +struct Templates24 { + typedef TemplateSel Head; + typedef Templates23 Tail; +}; + +template +struct Templates25 { + typedef TemplateSel Head; + typedef Templates24 Tail; +}; + +template +struct Templates26 { + typedef TemplateSel Head; + typedef Templates25 Tail; +}; + +template +struct Templates27 { + typedef TemplateSel Head; + typedef Templates26 Tail; +}; + +template +struct Templates28 { + typedef TemplateSel Head; + typedef Templates27 Tail; +}; + +template +struct Templates29 { + typedef TemplateSel Head; + typedef Templates28 Tail; +}; + +template +struct Templates30 { + typedef TemplateSel Head; + typedef Templates29 Tail; +}; + +template +struct Templates31 { + typedef TemplateSel Head; + typedef Templates30 Tail; +}; + +template +struct Templates32 { + typedef TemplateSel Head; + typedef Templates31 Tail; +}; + +template +struct Templates33 { + typedef TemplateSel Head; + typedef Templates32 Tail; +}; + +template +struct Templates34 { + typedef TemplateSel Head; + typedef Templates33 Tail; +}; + +template +struct Templates35 { + typedef TemplateSel Head; + typedef Templates34 Tail; +}; + +template +struct Templates36 { + typedef TemplateSel Head; + typedef Templates35 Tail; +}; + +template +struct Templates37 { + typedef TemplateSel Head; + typedef Templates36 Tail; +}; + +template +struct Templates38 { + typedef TemplateSel Head; + typedef Templates37 Tail; +}; + +template +struct Templates39 { + typedef TemplateSel Head; + typedef Templates38 Tail; +}; + +template +struct Templates40 { + typedef TemplateSel Head; + typedef Templates39 Tail; +}; + +template +struct Templates41 { + typedef TemplateSel Head; + typedef Templates40 Tail; +}; + +template +struct Templates42 { + typedef TemplateSel Head; + typedef Templates41 Tail; +}; + +template +struct Templates43 { + typedef TemplateSel Head; + 
typedef Templates42 Tail; +}; + +template +struct Templates44 { + typedef TemplateSel Head; + typedef Templates43 Tail; +}; + +template +struct Templates45 { + typedef TemplateSel Head; + typedef Templates44 Tail; +}; + +template +struct Templates46 { + typedef TemplateSel Head; + typedef Templates45 Tail; +}; + +template +struct Templates47 { + typedef TemplateSel Head; + typedef Templates46 Tail; +}; + +template +struct Templates48 { + typedef TemplateSel Head; + typedef Templates47 Tail; +}; + +template +struct Templates49 { + typedef TemplateSel Head; + typedef Templates48 Tail; +}; + +template +struct Templates50 { + typedef TemplateSel Head; + typedef Templates49 Tail; +}; + + +// We don't want to require the users to write TemplatesN<...> directly, +// as that would require them to count the length. Templates<...> is much +// easier to write, but generates horrible messages when there is a +// compiler error, as gcc insists on printing out each template +// argument, even if it has the default value (this means Templates +// will appear as Templates in the compiler +// errors). +// +// Our solution is to combine the best part of the two approaches: a +// user would write Templates, and Google Test will translate +// that to TemplatesN internally to make error messages +// readable. The translation is done by the 'type' member of the +// Templates template. 
+template +struct Templates { + typedef Templates50 type; +}; + +template <> +struct Templates { + typedef Templates0 type; +}; +template +struct Templates { + typedef Templates1 type; +}; +template +struct Templates { + typedef Templates2 type; +}; +template +struct Templates { + typedef Templates3 type; +}; +template +struct Templates { + typedef Templates4 type; +}; +template +struct Templates { + typedef Templates5 type; +}; +template +struct Templates { + typedef Templates6 type; +}; +template +struct Templates { + typedef Templates7 type; +}; +template +struct Templates { + typedef Templates8 type; +}; +template +struct Templates { + typedef Templates9 type; +}; +template +struct Templates { + typedef Templates10 type; +}; +template +struct Templates { + typedef Templates11 type; +}; +template +struct Templates { + typedef Templates12 type; +}; +template +struct Templates { + typedef Templates13 type; +}; +template +struct Templates { + typedef Templates14 type; +}; +template +struct Templates { + typedef Templates15 type; +}; +template +struct Templates { + typedef Templates16 type; +}; +template +struct Templates { + typedef Templates17 type; +}; +template +struct Templates { + typedef Templates18 type; +}; +template +struct Templates { + typedef Templates19 type; +}; +template +struct Templates { + typedef Templates20 type; +}; +template +struct Templates { + typedef Templates21 type; +}; +template +struct Templates { + typedef Templates22 type; +}; +template +struct Templates { + typedef Templates23 type; +}; +template +struct Templates { + typedef Templates24 type; +}; +template +struct Templates { + typedef Templates25 type; +}; +template +struct Templates { + typedef Templates26 type; +}; +template +struct Templates { + typedef Templates27 type; +}; +template +struct Templates { + typedef Templates28 type; +}; +template +struct Templates { + typedef Templates29 type; +}; +template +struct Templates { + typedef Templates30 type; +}; +template +struct 
Templates { + typedef Templates31 type; +}; +template +struct Templates { + typedef Templates32 type; +}; +template +struct Templates { + typedef Templates33 type; +}; +template +struct Templates { + typedef Templates34 type; +}; +template +struct Templates { + typedef Templates35 type; +}; +template +struct Templates { + typedef Templates36 type; +}; +template +struct Templates { + typedef Templates37 type; +}; +template +struct Templates { + typedef Templates38 type; +}; +template +struct Templates { + typedef Templates39 type; +}; +template +struct Templates { + typedef Templates40 type; +}; +template +struct Templates { + typedef Templates41 type; +}; +template +struct Templates { + typedef Templates42 type; +}; +template +struct Templates { + typedef Templates43 type; +}; +template +struct Templates { + typedef Templates44 type; +}; +template +struct Templates { + typedef Templates45 type; +}; +template +struct Templates { + typedef Templates46 type; +}; +template +struct Templates { + typedef Templates47 type; +}; +template +struct Templates { + typedef Templates48 type; +}; +template +struct Templates { + typedef Templates49 type; +}; + +// The TypeList template makes it possible to use either a single type +// or a Types<...> list in TYPED_TEST_CASE() and +// INSTANTIATE_TYPED_TEST_CASE_P(). + +template +struct TypeList { typedef Types1 type; }; + +template +struct TypeList > { + typedef typename Types::type type; +}; + +#endif // GTEST_HAS_TYPED_TEST || GTEST_HAS_TYPED_TEST_P + +} // namespace internal +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TYPE_UTIL_H_ + +// Due to C++ preprocessor weirdness, we need double indirection to +// concatenate two tokens when one of them is __LINE__. Writing +// +// foo ## __LINE__ +// +// will result in the token foo__LINE__, instead of foo followed by +// the current line number. 
For more details, see +// http://www.parashift.com/c++-faq-lite/misc-technical-issues.html#faq-39.6 +#define GTEST_CONCAT_TOKEN_(foo, bar) GTEST_CONCAT_TOKEN_IMPL_(foo, bar) +#define GTEST_CONCAT_TOKEN_IMPL_(foo, bar) foo ## bar + +// Google Test defines the testing::Message class to allow construction of +// test messages via the << operator. The idea is that anything +// streamable to std::ostream can be streamed to a testing::Message. +// This allows users to use their own types in Google Test assertions by +// overloading the << operator. +// +// util/gtl/stl_logging-inl.h overloads << for STL containers. These +// overloads cannot be defined in the std namespace, as that would be +// undefined behavior. Therefore, they are defined in the global +// namespace instead. +// +// C++'s symbol lookup rule (i.e. Koenig lookup) says that these +// overloads are visible in either the std namespace or the global +// namespace, but not other namespaces, including the testing +// namespace which Google Test's Message class is in. +// +// To allow STL containers (and other types that have a << operator +// defined in the global namespace) to be used in Google Test assertions, +// testing::Message must access the custom << operator from the global +// namespace. Hence this helper function. +// +// Note: Jeffrey Yasskin suggested an alternative fix by "using +// ::operator<<;" in the definition of Message's operator<<. That fix +// doesn't require a helper function, but unfortunately doesn't +// compile with MSVC. +template <typename T> +inline void GTestStreamToHelper(std::ostream* os, const T& val) { + *os << val; +} + +class ProtocolMessage; +namespace proto2 { class Message; } + +namespace testing { + +// Forward declarations. + +class AssertionResult; // Result of an assertion. +class Message; // Represents a failure message. +class Test; // Represents a test. +class TestInfo; // Information about a test. +class TestPartResult; // Result of a test part. 
+class UnitTest; // A collection of test cases. + +template <typename T> +::std::string PrintToString(const T& value); + +namespace internal { + +struct TraceInfo; // Information about a trace point. +class ScopedTrace; // Implements scoped trace. +class TestInfoImpl; // Opaque implementation of TestInfo +class UnitTestImpl; // Opaque implementation of UnitTest + +// How many times InitGoogleTest() has been called. +extern int g_init_gtest_count; + +// The text used in failure messages to indicate the start of the +// stack trace. +GTEST_API_ extern const char kStackTraceMarker[]; + +// A secret type that Google Test users don't know about. It has no +// definition on purpose. Therefore it's impossible to create a +// Secret object, which is what we want. +class Secret; + +// Two overloaded helpers for checking at compile time whether an +// expression is a null pointer literal (i.e. NULL or any 0-valued +// compile-time integral constant). Their return values have +// different sizes, so we can use sizeof() to test which version is +// picked by the compiler. These helpers have no implementations, as +// we only need their signatures. +// +// Given IsNullLiteralHelper(x), the compiler will pick the first +// version if x can be implicitly converted to Secret*, and pick the +// second version otherwise. Since Secret is a secret and incomplete +// type, the only expression a user can write that has type Secret* is +// a null pointer literal. Therefore, we know that x is a null +// pointer literal if and only if the first version is picked by the +// compiler. +char IsNullLiteralHelper(Secret* p); +char (&IsNullLiteralHelper(...))[2]; // NOLINT + +// A compile-time bool constant that is true if and only if x is a +// null pointer literal (i.e. NULL or any 0-valued compile-time +// integral constant). +#ifdef GTEST_ELLIPSIS_NEEDS_POD_ +// We lose support for NULL detection where the compiler doesn't like +// passing non-POD classes through ellipsis (...). 
+# define GTEST_IS_NULL_LITERAL_(x) false +#else +# define GTEST_IS_NULL_LITERAL_(x) \ + (sizeof(::testing::internal::IsNullLiteralHelper(x)) == 1) +#endif // GTEST_ELLIPSIS_NEEDS_POD_ + +// Appends the user-supplied message to the Google-Test-generated message. +GTEST_API_ String AppendUserMessage(const String& gtest_msg, + const Message& user_msg); + +// A helper class for creating scoped traces in user programs. +class GTEST_API_ ScopedTrace { + public: + // The c'tor pushes the given source file location and message onto + // a trace stack maintained by Google Test. + ScopedTrace(const char* file, int line, const Message& message); + + // The d'tor pops the info pushed by the c'tor. + // + // Note that the d'tor is not virtual in order to be efficient. + // Don't inherit from ScopedTrace! + ~ScopedTrace(); + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(ScopedTrace); +} GTEST_ATTRIBUTE_UNUSED_; // A ScopedTrace object does its job in its + // c'tor and d'tor. Therefore it doesn't + // need to be used otherwise. + +// Converts a streamable value to a String. A NULL pointer is +// converted to "(null)". When the input value is a ::string, +// ::std::string, ::wstring, or ::std::wstring object, each NUL +// character in it is replaced with "\\0". +// Declared here but defined in gtest.h, so that it has access +// to the definition of the Message class, required by the ARM +// compiler. +template <typename T> +String StreamableToString(const T& streamable); + +// The Symbian compiler has a bug that prevents it from selecting the +// correct overload of FormatForComparisonFailureMessage (see below) +// unless we pass the first argument by reference. If we do that, +// however, Visual Age C++ 10.1 generates a compiler error. Therefore +// we only apply the work-around for Symbian. 
+#if defined(__SYMBIAN32__) +# define GTEST_CREF_WORKAROUND_ const& +#else +# define GTEST_CREF_WORKAROUND_ +#endif + +// When this operand is a const char* or char*, if the other operand +// is a ::std::string or ::string, we print this operand as a C string +// rather than a pointer (we do the same for wide strings); otherwise +// we print it as a pointer to be safe. + +// This internal macro is used to avoid duplicated code. +#define GTEST_FORMAT_IMPL_(operand2_type, operand1_printer)\ +inline String FormatForComparisonFailureMessage(\ + operand2_type::value_type* GTEST_CREF_WORKAROUND_ str, \ + const operand2_type& /*operand2*/) {\ + return operand1_printer(str);\ +}\ +inline String FormatForComparisonFailureMessage(\ + const operand2_type::value_type* GTEST_CREF_WORKAROUND_ str, \ + const operand2_type& /*operand2*/) {\ + return operand1_printer(str);\ +} + +GTEST_FORMAT_IMPL_(::std::string, String::ShowCStringQuoted) +#if GTEST_HAS_STD_WSTRING +GTEST_FORMAT_IMPL_(::std::wstring, String::ShowWideCStringQuoted) +#endif // GTEST_HAS_STD_WSTRING + +#if GTEST_HAS_GLOBAL_STRING +GTEST_FORMAT_IMPL_(::string, String::ShowCStringQuoted) +#endif // GTEST_HAS_GLOBAL_STRING +#if GTEST_HAS_GLOBAL_WSTRING +GTEST_FORMAT_IMPL_(::wstring, String::ShowWideCStringQuoted) +#endif // GTEST_HAS_GLOBAL_WSTRING + +#undef GTEST_FORMAT_IMPL_ + +// The next four overloads handle the case where the operand being +// printed is a char/wchar_t pointer and the other operand is not a +// string/wstring object. In such cases, we just print the operand as +// a pointer to be safe. 
+#define GTEST_FORMAT_CHAR_PTR_IMPL_(CharType) \ + template <typename T> \ + String FormatForComparisonFailureMessage(CharType* GTEST_CREF_WORKAROUND_ p, \ + const T&) { \ + return PrintToString(static_cast<const void*>(p)); \ + } + +GTEST_FORMAT_CHAR_PTR_IMPL_(char) +GTEST_FORMAT_CHAR_PTR_IMPL_(const char) +GTEST_FORMAT_CHAR_PTR_IMPL_(wchar_t) +GTEST_FORMAT_CHAR_PTR_IMPL_(const wchar_t) + +#undef GTEST_FORMAT_CHAR_PTR_IMPL_ + +// Constructs and returns the message for an equality assertion +// (e.g. ASSERT_EQ, EXPECT_STREQ, etc) failure. +// +// The first four parameters are the expressions used in the assertion +// and their values, as strings. For example, for ASSERT_EQ(foo, bar) +// where foo is 5 and bar is 6, we have: +// +// expected_expression: "foo" +// actual_expression: "bar" +// expected_value: "5" +// actual_value: "6" +// +// The ignoring_case parameter is true iff the assertion is a +// *_STRCASEEQ*. When it's true, the string " (ignoring case)" will +// be inserted into the message. +GTEST_API_ AssertionResult EqFailure(const char* expected_expression, + const char* actual_expression, + const String& expected_value, + const String& actual_value, + bool ignoring_case); + +// Constructs a failure message for Boolean assertions such as EXPECT_TRUE. +GTEST_API_ String GetBoolAssertionFailureMessage( + const AssertionResult& assertion_result, + const char* expression_text, + const char* actual_predicate_value, + const char* expected_predicate_value); + +// This template class represents an IEEE floating-point number +// (either single-precision or double-precision, depending on the +// template parameters). +// +// The purpose of this class is to do more sophisticated number +// comparison. (Due to round-off error, etc, it's very unlikely that +// two floating-points will be equal exactly. Hence a naive +// comparison by the == operation often doesn't work.) 
+// +// Format of IEEE floating-point: +// +// The most-significant bit being the leftmost, an IEEE +// floating-point looks like +// +// sign_bit exponent_bits fraction_bits +// +// Here, sign_bit is a single bit that designates the sign of the +// number. +// +// For float, there are 8 exponent bits and 23 fraction bits. +// +// For double, there are 11 exponent bits and 52 fraction bits. +// +// More details can be found at +// http://en.wikipedia.org/wiki/IEEE_floating-point_standard. +// +// Template parameter: +// +// RawType: the raw floating-point type (either float or double) +template <typename RawType> +class FloatingPoint { + public: + // Defines the unsigned integer type that has the same size as the + // floating point number. + typedef typename TypeWithSize<sizeof(RawType)>::UInt Bits; + + // Constants. + + // # of bits in a number. + static const size_t kBitCount = 8*sizeof(RawType); + + // # of fraction bits in a number. + static const size_t kFractionBitCount = + std::numeric_limits<RawType>::digits - 1; + + // # of exponent bits in a number. + static const size_t kExponentBitCount = kBitCount - 1 - kFractionBitCount; + + // The mask for the sign bit. + static const Bits kSignBitMask = static_cast<Bits>(1) << (kBitCount - 1); + + // The mask for the fraction bits. + static const Bits kFractionBitMask = + ~static_cast<Bits>(0) >> (kExponentBitCount + 1); + + // The mask for the exponent bits. + static const Bits kExponentBitMask = ~(kSignBitMask | kFractionBitMask); + + // How many ULP's (Units in the Last Place) we want to tolerate when + // comparing two numbers. The larger the value, the more error we + // allow. A 0 value means that two numbers must be exactly the same + // to be considered equal. + // + // The maximum error of a single floating-point operation is 0.5 + // units in the last place. On Intel CPU's, all floating-point + // calculations are done with 80-bit precision, while double has 64 + // bits. Therefore, 4 should be enough for ordinary use.
+ // + // See the following article for more details on ULP: + // http://www.cygnus-software.com/papers/comparingfloats/comparingfloats.htm. + static const size_t kMaxUlps = 4; + + // Constructs a FloatingPoint from a raw floating-point number. + // + // On an Intel CPU, passing a non-normalized NAN (Not a Number) + // around may change its bits, although the new value is guaranteed + // to be also a NAN. Therefore, don't expect this constructor to + // preserve the bits in x when x is a NAN. + explicit FloatingPoint(const RawType& x) { u_.value_ = x; } + + // Static methods + + // Reinterprets a bit pattern as a floating-point number. + // + // This function is needed to test the AlmostEquals() method. + static RawType ReinterpretBits(const Bits bits) { + FloatingPoint fp(0); + fp.u_.bits_ = bits; + return fp.u_.value_; + } + + // Returns the floating-point number that represents positive infinity. + static RawType Infinity() { + return ReinterpretBits(kExponentBitMask); + } + + // Non-static methods + + // Returns the bits that represent this number. + const Bits &bits() const { return u_.bits_; } + + // Returns the exponent bits of this number. + Bits exponent_bits() const { return kExponentBitMask & u_.bits_; } + + // Returns the fraction bits of this number. + Bits fraction_bits() const { return kFractionBitMask & u_.bits_; } + + // Returns the sign bit of this number. + Bits sign_bit() const { return kSignBitMask & u_.bits_; } + + // Returns true iff this is NAN (not a number). + bool is_nan() const { + // It's a NAN if the exponent bits are all ones and the fraction + // bits are not entirely zeros. + return (exponent_bits() == kExponentBitMask) && (fraction_bits() != 0); + } + + // Returns true iff this number is at most kMaxUlps ULP's away from + // rhs. In particular, this function: + // + // - returns false if either number is (or both are) NAN. + // - treats really large numbers as almost equal to infinity. + // - thinks +0.0 and -0.0 are 0 ULP's apart.
+ bool AlmostEquals(const FloatingPoint& rhs) const { + // The IEEE standard says that any comparison operation involving + // a NAN must return false. + if (is_nan() || rhs.is_nan()) return false; + + return DistanceBetweenSignAndMagnitudeNumbers(u_.bits_, rhs.u_.bits_) + <= kMaxUlps; + } + + private: + // The data type used to store the actual floating-point number. + union FloatingPointUnion { + RawType value_; // The raw floating-point number. + Bits bits_; // The bits that represent the number. + }; + + // Converts an integer from the sign-and-magnitude representation to + // the biased representation. More precisely, let N be 2 to the + // power of (kBitCount - 1), an integer x is represented by the + // unsigned number x + N. + // + // For instance, + // + // -N + 1 (the most negative number representable using + // sign-and-magnitude) is represented by 1; + // 0 is represented by N; and + // N - 1 (the biggest number representable using + // sign-and-magnitude) is represented by 2N - 1. + // + // Read http://en.wikipedia.org/wiki/Signed_number_representations + // for more details on signed number representations. + static Bits SignAndMagnitudeToBiased(const Bits &sam) { + if (kSignBitMask & sam) { + // sam represents a negative number. + return ~sam + 1; + } else { + // sam represents a positive number. + return kSignBitMask | sam; + } + } + + // Given two numbers in the sign-and-magnitude representation, + // returns the distance between them as an unsigned number. + static Bits DistanceBetweenSignAndMagnitudeNumbers(const Bits &sam1, + const Bits &sam2) { + const Bits biased1 = SignAndMagnitudeToBiased(sam1); + const Bits biased2 = SignAndMagnitudeToBiased(sam2); + return (biased1 >= biased2) ? (biased1 - biased2) : (biased2 - biased1); + } + + FloatingPointUnion u_; +}; + +// Typedefs the instances of the FloatingPoint template class that we +// care to use. 
+typedef FloatingPoint<float> Float; +typedef FloatingPoint<double> Double; + +// In order to catch the mistake of putting tests that use different +// test fixture classes in the same test case, we need to assign +// unique IDs to fixture classes and compare them. The TypeId type is +// used to hold such IDs. The user should treat TypeId as an opaque +// type: the only operation allowed on TypeId values is to compare +// them for equality using the == operator. +typedef const void* TypeId; + +template <typename T> +class TypeIdHelper { + public: + // dummy_ must not have a const type. Otherwise an overly eager + // compiler (e.g. MSVC 7.1 & 8.0) may try to merge + // TypeIdHelper<T>::dummy_ for different Ts as an "optimization". + static bool dummy_; +}; + +template <typename T> +bool TypeIdHelper<T>::dummy_ = false; + +// GetTypeId<T>() returns the ID of type T. Different values will be +// returned for different types. Calling the function twice with the +// same type argument is guaranteed to return the same ID. +template <typename T> +TypeId GetTypeId() { + // The compiler is required to allocate a different + // TypeIdHelper<T>::dummy_ variable for each T used to instantiate + // the template. Therefore, the address of dummy_ is guaranteed to + // be unique. + return &(TypeIdHelper<T>::dummy_); +} + +// Returns the type ID of ::testing::Test. Always call this instead +// of GetTypeId< ::testing::Test>() to get the type ID of +// ::testing::Test, as the latter may give the wrong result due to a +// suspected linker bug when compiling Google Test as a Mac OS X +// framework. +GTEST_API_ TypeId GetTestTypeId(); + +// Defines the abstract factory interface that creates instances +// of a Test object. +class TestFactoryBase { + public: + virtual ~TestFactoryBase() {} + + // Creates a test instance to run.
The instance is both created and destroyed + // within TestInfoImpl::Run() + virtual Test* CreateTest() = 0; + + protected: + TestFactoryBase() {} + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestFactoryBase); +}; + +// This class provides an implementation of the TestFactoryBase interface. +// It is used in TEST and TEST_F macros. +template <class TestClass> +class TestFactoryImpl : public TestFactoryBase { + public: + virtual Test* CreateTest() { return new TestClass; } +}; + +#if GTEST_OS_WINDOWS + +// Predicate-formatters for implementing the HRESULT checking macros +// {ASSERT|EXPECT}_HRESULT_{SUCCEEDED|FAILED} +// We pass a long instead of HRESULT to avoid causing an +// include dependency for the HRESULT type. +GTEST_API_ AssertionResult IsHRESULTSuccess(const char* expr, + long hr); // NOLINT +GTEST_API_ AssertionResult IsHRESULTFailure(const char* expr, + long hr); // NOLINT + +#endif // GTEST_OS_WINDOWS + +// Types of SetUpTestCase() and TearDownTestCase() functions. +typedef void (*SetUpTestCaseFunc)(); +typedef void (*TearDownTestCaseFunc)(); + +// Creates a new TestInfo object and registers it with Google Test; +// returns the created object. +// +// Arguments: +// +// test_case_name: name of the test case +// name: name of the test +// type_param: the name of the test's type parameter, or NULL if +// this is not a typed or a type-parameterized test. +// value_param: text representation of the test's value parameter, +// or NULL if this is not a value-parameterized test. +// fixture_class_id: ID of the test fixture class +// set_up_tc: pointer to the function that sets up the test case +// tear_down_tc: pointer to the function that tears down the test case +// factory: pointer to the factory that creates a test object. +// The newly created TestInfo instance will assume +// ownership of the factory object.
+GTEST_API_ TestInfo* MakeAndRegisterTestInfo( + const char* test_case_name, const char* name, + const char* type_param, + const char* value_param, + TypeId fixture_class_id, + SetUpTestCaseFunc set_up_tc, + TearDownTestCaseFunc tear_down_tc, + TestFactoryBase* factory); + +// If *pstr starts with the given prefix, modifies *pstr to be right +// past the prefix and returns true; otherwise leaves *pstr unchanged +// and returns false. None of pstr, *pstr, and prefix can be NULL. +GTEST_API_ bool SkipPrefix(const char* prefix, const char** pstr); + +#if GTEST_HAS_TYPED_TEST || GTEST_HAS_TYPED_TEST_P + +// State of the definition of a type-parameterized test case. +class GTEST_API_ TypedTestCasePState { + public: + TypedTestCasePState() : registered_(false) {} + + // Adds the given test name to defined_test_names_ and return true + // if the test case hasn't been registered; otherwise aborts the + // program. + bool AddTestName(const char* file, int line, const char* case_name, + const char* test_name) { + if (registered_) { + fprintf(stderr, "%s Test %s must be defined before " + "REGISTER_TYPED_TEST_CASE_P(%s, ...).\n", + FormatFileLocation(file, line).c_str(), test_name, case_name); + fflush(stderr); + posix::Abort(); + } + defined_test_names_.insert(test_name); + return true; + } + + // Verifies that registered_tests match the test names in + // defined_test_names_; returns registered_tests if successful, or + // aborts the program otherwise. + const char* VerifyRegisteredTestNames( + const char* file, int line, const char* registered_tests); + + private: + bool registered_; + ::std::set defined_test_names_; +}; + +// Skips to the first non-space char after the first comma in 'str'; +// returns NULL if no comma is found in 'str'. 
+inline const char* SkipComma(const char* str) { + const char* comma = strchr(str, ','); + if (comma == NULL) { + return NULL; + } + while (IsSpace(*(++comma))) {} + return comma; +} + +// Returns the prefix of 'str' before the first comma in it; returns +// the entire string if it contains no comma. +inline String GetPrefixUntilComma(const char* str) { + const char* comma = strchr(str, ','); + return comma == NULL ? String(str) : String(str, comma - str); +} + +// TypeParameterizedTest::Register() +// registers a list of type-parameterized tests with Google Test. The +// return value is insignificant - we just need to return something +// such that we can call this function in a namespace scope. +// +// Implementation note: The GTEST_TEMPLATE_ macro declares a template +// template parameter. It's defined in gtest-type-util.h. +template +class TypeParameterizedTest { + public: + // 'index' is the index of the test in the type list 'Types' + // specified in INSTANTIATE_TYPED_TEST_CASE_P(Prefix, TestCase, + // Types). Valid values for 'index' are [0, N - 1] where N is the + // length of Types. + static bool Register(const char* prefix, const char* case_name, + const char* test_names, int index) { + typedef typename Types::Head Type; + typedef Fixture FixtureClass; + typedef typename GTEST_BIND_(TestSel, Type) TestClass; + + // First, registers the first type-parameterized test in the type + // list. + MakeAndRegisterTestInfo( + String::Format("%s%s%s/%d", prefix, prefix[0] == '\0' ? "" : "/", + case_name, index).c_str(), + GetPrefixUntilComma(test_names).c_str(), + GetTypeName().c_str(), + NULL, // No value parameter. + GetTypeId(), + TestClass::SetUpTestCase, + TestClass::TearDownTestCase, + new TestFactoryImpl); + + // Next, recurses (at compile time) with the tail of the type list. + return TypeParameterizedTest + ::Register(prefix, case_name, test_names, index + 1); + } +}; + +// The base case for the compile time recursion. 
+template +class TypeParameterizedTest { + public: + static bool Register(const char* /*prefix*/, const char* /*case_name*/, + const char* /*test_names*/, int /*index*/) { + return true; + } +}; + +// TypeParameterizedTestCase::Register() +// registers *all combinations* of 'Tests' and 'Types' with Google +// Test. The return value is insignificant - we just need to return +// something such that we can call this function in a namespace scope. +template +class TypeParameterizedTestCase { + public: + static bool Register(const char* prefix, const char* case_name, + const char* test_names) { + typedef typename Tests::Head Head; + + // First, register the first test in 'Test' for each type in 'Types'. + TypeParameterizedTest::Register( + prefix, case_name, test_names, 0); + + // Next, recurses (at compile time) with the tail of the test list. + return TypeParameterizedTestCase + ::Register(prefix, case_name, SkipComma(test_names)); + } +}; + +// The base case for the compile time recursion. +template +class TypeParameterizedTestCase { + public: + static bool Register(const char* /*prefix*/, const char* /*case_name*/, + const char* /*test_names*/) { + return true; + } +}; + +#endif // GTEST_HAS_TYPED_TEST || GTEST_HAS_TYPED_TEST_P + +// Returns the current OS stack trace as a String. +// +// The maximum number of stack frames to be included is specified by +// the gtest_stack_trace_depth flag. The skip_count parameter +// specifies the number of top frames to be skipped, which doesn't +// count against the number of frames to be included. +// +// For example, if Foo() calls Bar(), which in turn calls +// GetCurrentOsStackTraceExceptTop(..., 1), Foo() will be included in +// the trace but Bar() and GetCurrentOsStackTraceExceptTop() won't. +GTEST_API_ String GetCurrentOsStackTraceExceptTop(UnitTest* unit_test, + int skip_count); + +// Helpers for suppressing warnings on unreachable code or constant +// condition. + +// Always returns true. 
+GTEST_API_ bool AlwaysTrue(); + +// Always returns false. +inline bool AlwaysFalse() { return !AlwaysTrue(); } + +// Helper for suppressing false warning from Clang on a const char* +// variable declared in a conditional expression always being NULL in +// the else branch. +struct GTEST_API_ ConstCharPtr { + ConstCharPtr(const char* str) : value(str) {} + operator bool() const { return true; } + const char* value; +}; + +// A simple Linear Congruential Generator for generating random +// numbers with a uniform distribution. Unlike rand() and srand(), it +// doesn't use global state (and therefore can't interfere with user +// code). Unlike rand_r(), it's portable. An LCG isn't very random, +// but it's good enough for our purposes. +class GTEST_API_ Random { + public: + static const UInt32 kMaxRange = 1u << 31; + + explicit Random(UInt32 seed) : state_(seed) {} + + void Reseed(UInt32 seed) { state_ = seed; } + + // Generates a random number from [0, range). Crashes if 'range' is + // 0 or greater than kMaxRange. + UInt32 Generate(UInt32 range); + + private: + UInt32 state_; + GTEST_DISALLOW_COPY_AND_ASSIGN_(Random); +}; + +// Defining a variable of type CompileAssertTypesEqual will cause a +// compiler error iff T1 and T2 are different types. +template +struct CompileAssertTypesEqual; + +template +struct CompileAssertTypesEqual { +}; + +// Removes the reference from a type if it is a reference type, +// otherwise leaves it unchanged. This is the same as +// tr1::remove_reference, which is not widely available yet. +template +struct RemoveReference { typedef T type; }; // NOLINT +template +struct RemoveReference { typedef T type; }; // NOLINT + +// A handy wrapper around RemoveReference that works when the argument +// T depends on template parameters. +#define GTEST_REMOVE_REFERENCE_(T) \ + typename ::testing::internal::RemoveReference::type + +// Removes const from a type if it is a const type, otherwise leaves +// it unchanged. 
This is the same as tr1::remove_const, which is not +// widely available yet. +template +struct RemoveConst { typedef T type; }; // NOLINT +template +struct RemoveConst { typedef T type; }; // NOLINT + +// MSVC 8.0, Sun C++, and IBM XL C++ have a bug which causes the above +// definition to fail to remove the const in 'const int[3]' and 'const +// char[3][4]'. The following specialization works around the bug. +// However, it causes trouble with GCC and thus needs to be +// conditionally compiled. +#if defined(_MSC_VER) || defined(__SUNPRO_CC) || defined(__IBMCPP__) +template +struct RemoveConst { + typedef typename RemoveConst::type type[N]; +}; +#endif + +// A handy wrapper around RemoveConst that works when the argument +// T depends on template parameters. +#define GTEST_REMOVE_CONST_(T) \ + typename ::testing::internal::RemoveConst::type + +// Turns const U&, U&, const U, and U all into U. +#define GTEST_REMOVE_REFERENCE_AND_CONST_(T) \ + GTEST_REMOVE_CONST_(GTEST_REMOVE_REFERENCE_(T)) + +// Adds reference to a type if it is not a reference type, +// otherwise leaves it unchanged. This is the same as +// tr1::add_reference, which is not widely available yet. +template +struct AddReference { typedef T& type; }; // NOLINT +template +struct AddReference { typedef T& type; }; // NOLINT + +// A handy wrapper around AddReference that works when the argument T +// depends on template parameters. +#define GTEST_ADD_REFERENCE_(T) \ + typename ::testing::internal::AddReference::type + +// Adds a reference to const on top of T as necessary. For example, +// it transforms +// +// char ==> const char& +// const char ==> const char& +// char& ==> const char& +// const char& ==> const char& +// +// The argument T must depend on some template parameters. 
+#define GTEST_REFERENCE_TO_CONST_(T) \ + GTEST_ADD_REFERENCE_(const GTEST_REMOVE_REFERENCE_(T)) + +// ImplicitlyConvertible::value is a compile-time bool +// constant that's true iff type From can be implicitly converted to +// type To. +template +class ImplicitlyConvertible { + private: + // We need the following helper functions only for their types. + // They have no implementations. + + // MakeFrom() is an expression whose type is From. We cannot simply + // use From(), as the type From may not have a public default + // constructor. + static From MakeFrom(); + + // These two functions are overloaded. Given an expression + // Helper(x), the compiler will pick the first version if x can be + // implicitly converted to type To; otherwise it will pick the + // second version. + // + // The first version returns a value of size 1, and the second + // version returns a value of size 2. Therefore, by checking the + // size of Helper(x), which can be done at compile time, we can tell + // which version of Helper() is used, and hence whether x can be + // implicitly converted to type To. + static char Helper(To); + static char (&Helper(...))[2]; // NOLINT + + // We have to put the 'public' section after the 'private' section, + // or MSVC refuses to compile the code. + public: + // MSVC warns about implicitly converting from double to int for + // possible loss of data, so we need to temporarily disable the + // warning. +#ifdef _MSC_VER +# pragma warning(push) // Saves the current warning state. +# pragma warning(disable:4244) // Temporarily disables warning 4244. + + static const bool value = + sizeof(Helper(ImplicitlyConvertible::MakeFrom())) == 1; +# pragma warning(pop) // Restores the warning state. +#elif defined(__BORLANDC__) + // C++Builder cannot use member overload resolution during template + // instantiation. The simplest workaround is to use its C++0x type traits + // functions (C++Builder 2009 and above only). 
+ static const bool value = __is_convertible(From, To); +#else + static const bool value = + sizeof(Helper(ImplicitlyConvertible::MakeFrom())) == 1; +#endif // _MSC_VER +}; +template <typename From, typename To> +const bool ImplicitlyConvertible<From, To>::value; + +// IsAProtocolMessage<T>::value is a compile-time bool constant that's +// true iff T is type ProtocolMessage, proto2::Message, or a subclass +// of those. +template <typename T> +struct IsAProtocolMessage + : public bool_constant< + ImplicitlyConvertible<const T*, const ::ProtocolMessage*>::value || + ImplicitlyConvertible<const T*, const ::proto2::Message*>::value> { +}; + +// When the compiler sees expression IsContainerTest<C>(0), if C is an +// STL-style container class, the first overload of IsContainerTest +// will be viable (since both C::iterator* and C::const_iterator* are +// valid types and NULL can be implicitly converted to them). It will +// be picked over the second overload as 'int' is a perfect match for +// the type of argument 0. If C::iterator or C::const_iterator is not +// a valid type, the first overload is not viable, and the second +// overload will be picked. Therefore, we can determine whether C is +// a container class by checking the type of IsContainerTest<C>(0). +// The value of the expression is insignificant. +// +// Note that we look for both C::iterator and C::const_iterator. The +// reason is that C++ injects the name of a class as a member of the +// class itself (e.g. you can refer to class iterator as either +// 'iterator' or 'iterator::iterator'). If we look for C::iterator +// only, for example, we would mistakenly think that a class named +// iterator is an STL container. +// +// Also note that the simpler approach of overloading +// IsContainerTest(typename C::const_iterator*) and +// IsContainerTest(...) doesn't work with Visual Age C++ and Sun C++.
+typedef int IsContainer; +template +IsContainer IsContainerTest(int /* dummy */, + typename C::iterator* /* it */ = NULL, + typename C::const_iterator* /* const_it */ = NULL) { + return 0; +} + +typedef char IsNotContainer; +template +IsNotContainer IsContainerTest(long /* dummy */) { return '\0'; } + +// EnableIf::type is void when 'Cond' is true, and +// undefined when 'Cond' is false. To use SFINAE to make a function +// overload only apply when a particular expression is true, add +// "typename EnableIf::type* = 0" as the last parameter. +template struct EnableIf; +template<> struct EnableIf { typedef void type; }; // NOLINT + +// Utilities for native arrays. + +// ArrayEq() compares two k-dimensional native arrays using the +// elements' operator==, where k can be any integer >= 0. When k is +// 0, ArrayEq() degenerates into comparing a single pair of values. + +template +bool ArrayEq(const T* lhs, size_t size, const U* rhs); + +// This generic version is used when k is 0. +template +inline bool ArrayEq(const T& lhs, const U& rhs) { return lhs == rhs; } + +// This overload is used when k >= 1. +template +inline bool ArrayEq(const T(&lhs)[N], const U(&rhs)[N]) { + return internal::ArrayEq(lhs, N, rhs); +} + +// This helper reduces code bloat. If we instead put its logic inside +// the previous ArrayEq() function, arrays with different sizes would +// lead to different copies of the template code. +template +bool ArrayEq(const T* lhs, size_t size, const U* rhs) { + for (size_t i = 0; i != size; i++) { + if (!internal::ArrayEq(lhs[i], rhs[i])) + return false; + } + return true; +} + +// Finds the first element in the iterator range [begin, end) that +// equals elem. Element may be a native array type itself. 
+template +Iter ArrayAwareFind(Iter begin, Iter end, const Element& elem) { + for (Iter it = begin; it != end; ++it) { + if (internal::ArrayEq(*it, elem)) + return it; + } + return end; +} + +// CopyArray() copies a k-dimensional native array using the elements' +// operator=, where k can be any integer >= 0. When k is 0, +// CopyArray() degenerates into copying a single value. + +template +void CopyArray(const T* from, size_t size, U* to); + +// This generic version is used when k is 0. +template +inline void CopyArray(const T& from, U* to) { *to = from; } + +// This overload is used when k >= 1. +template +inline void CopyArray(const T(&from)[N], U(*to)[N]) { + internal::CopyArray(from, N, *to); +} + +// This helper reduces code bloat. If we instead put its logic inside +// the previous CopyArray() function, arrays with different sizes +// would lead to different copies of the template code. +template +void CopyArray(const T* from, size_t size, U* to) { + for (size_t i = 0; i != size; i++) { + internal::CopyArray(from[i], to + i); + } +} + +// The relation between an NativeArray object (see below) and the +// native array it represents. +enum RelationToSource { + kReference, // The NativeArray references the native array. + kCopy // The NativeArray makes a copy of the native array and + // owns the copy. +}; + +// Adapts a native array to a read-only STL-style container. Instead +// of the complete STL container concept, this adaptor only implements +// members useful for Google Mock's container matchers. New members +// should be added as needed. To simplify the implementation, we only +// support Element being a raw type (i.e. having no top-level const or +// reference modifier). It's the client's responsibility to satisfy +// this requirement. Element can be an array type itself (hence +// multi-dimensional arrays are supported). +template +class NativeArray { + public: + // STL-style container typedefs. 
+ typedef Element value_type; + typedef Element* iterator; + typedef const Element* const_iterator; + + // Constructs from a native array. + NativeArray(const Element* array, size_t count, RelationToSource relation) { + Init(array, count, relation); + } + + // Copy constructor. + NativeArray(const NativeArray& rhs) { + Init(rhs.array_, rhs.size_, rhs.relation_to_source_); + } + + ~NativeArray() { + // Ensures that the user doesn't instantiate NativeArray with a + // const or reference type. + static_cast(StaticAssertTypeEqHelper()); + if (relation_to_source_ == kCopy) + delete[] array_; + } + + // STL-style container methods. + size_t size() const { return size_; } + const_iterator begin() const { return array_; } + const_iterator end() const { return array_ + size_; } + bool operator==(const NativeArray& rhs) const { + return size() == rhs.size() && + ArrayEq(begin(), size(), rhs.begin()); + } + + private: + // Initializes this object; makes a copy of the input array if + // 'relation' is kCopy. 
+ void Init(const Element* array, size_t a_size, RelationToSource relation) { + if (relation == kReference) { + array_ = array; + } else { + Element* const copy = new Element[a_size]; + CopyArray(array, a_size, copy); + array_ = copy; + } + size_ = a_size; + relation_to_source_ = relation; + } + + const Element* array_; + size_t size_; + RelationToSource relation_to_source_; + + GTEST_DISALLOW_ASSIGN_(NativeArray); +}; + +} // namespace internal +} // namespace testing + +#define GTEST_MESSAGE_AT_(file, line, message, result_type) \ + ::testing::internal::AssertHelper(result_type, file, line, message) \ + = ::testing::Message() + +#define GTEST_MESSAGE_(message, result_type) \ + GTEST_MESSAGE_AT_(__FILE__, __LINE__, message, result_type) + +#define GTEST_FATAL_FAILURE_(message) \ + return GTEST_MESSAGE_(message, ::testing::TestPartResult::kFatalFailure) + +#define GTEST_NONFATAL_FAILURE_(message) \ + GTEST_MESSAGE_(message, ::testing::TestPartResult::kNonFatalFailure) + +#define GTEST_SUCCESS_(message) \ + GTEST_MESSAGE_(message, ::testing::TestPartResult::kSuccess) + +// Suppresses MSVC warning 4702 (unreachable code) for the code following +// statement if it returns or throws (or doesn't return or throw in some +// situations). +#define GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement) \ + if (::testing::internal::AlwaysTrue()) { statement; } + +#define GTEST_TEST_THROW_(statement, expected_exception, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::ConstCharPtr gtest_msg = "") { \ + bool gtest_caught_expected = false; \ + try { \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + } \ + catch (expected_exception const&) { \ + gtest_caught_expected = true; \ + } \ + catch (...)
{ \ + gtest_msg.value = \ + "Expected: " #statement " throws an exception of type " \ + #expected_exception ".\n Actual: it throws a different type."; \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testthrow_, __LINE__); \ + } \ + if (!gtest_caught_expected) { \ + gtest_msg.value = \ + "Expected: " #statement " throws an exception of type " \ + #expected_exception ".\n Actual: it throws nothing."; \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testthrow_, __LINE__); \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_testthrow_, __LINE__): \ + fail(gtest_msg.value) + +#define GTEST_TEST_NO_THROW_(statement, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + try { \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + } \ + catch (...) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testnothrow_, __LINE__); \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_testnothrow_, __LINE__): \ + fail("Expected: " #statement " doesn't throw an exception.\n" \ + " Actual: it throws.") + +#define GTEST_TEST_ANY_THROW_(statement, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + bool gtest_caught_any = false; \ + try { \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + } \ + catch (...) { \ + gtest_caught_any = true; \ + } \ + if (!gtest_caught_any) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testanythrow_, __LINE__); \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_testanythrow_, __LINE__): \ + fail("Expected: " #statement " throws an exception.\n" \ + " Actual: it doesn't.") + + +// Implements Boolean test assertions such as EXPECT_TRUE. expression can be +// either a boolean expression or an AssertionResult. text is a textual +// representation of expression as it was passed into the EXPECT_TRUE.
+#define GTEST_TEST_BOOLEAN_(expression, text, actual, expected, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (const ::testing::AssertionResult gtest_ar_ = \ + ::testing::AssertionResult(expression)) \ + ; \ + else \ + fail(::testing::internal::GetBoolAssertionFailureMessage(\ + gtest_ar_, text, #actual, #expected).c_str()) + +#define GTEST_TEST_NO_FATAL_FAILURE_(statement, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + ::testing::internal::HasNewFatalFailureHelper gtest_fatal_failure_checker; \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + if (gtest_fatal_failure_checker.has_new_fatal_failure()) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testnofatal_, __LINE__); \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_testnofatal_, __LINE__): \ + fail("Expected: " #statement " doesn't generate new fatal " \ + "failures in the current thread.\n" \ + " Actual: it does.") + +// Expands to the name of the class that implements the given test. +#define GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ + test_case_name##_##test_name##_Test + +// Helper macro for defining tests. 
+#define GTEST_TEST_(test_case_name, test_name, parent_class, parent_id)\ +class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) : public parent_class {\ + public:\ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {}\ + private:\ + virtual void TestBody();\ + static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_;\ + GTEST_DISALLOW_COPY_AND_ASSIGN_(\ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name));\ +};\ +\ +::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, test_name)\ + ::test_info_ =\ + ::testing::internal::MakeAndRegisterTestInfo(\ + #test_case_name, #test_name, NULL, NULL, \ + (parent_id), \ + parent_class::SetUpTestCase, \ + parent_class::TearDownTestCase, \ + new ::testing::internal::TestFactoryImpl<\ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)>);\ +void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody() + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_INTERNAL_H_ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) +// +// The Google C++ Testing Framework (Google Test) +// +// This header file defines the public API for death tests. It is +// #included by gtest.h so a user doesn't need to include this +// directly. + +#ifndef GTEST_INCLUDE_GTEST_GTEST_DEATH_TEST_H_ +#define GTEST_INCLUDE_GTEST_GTEST_DEATH_TEST_H_ + +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Authors: wan@google.com (Zhanyong Wan), eefacm@gmail.com (Sean Mcafee) +// +// The Google C++ Testing Framework (Google Test) +// +// This header file defines internal utilities needed for implementing +// death tests. They are subject to change without notice. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_DEATH_TEST_INTERNAL_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_DEATH_TEST_INTERNAL_H_ + + +#include <stdio.h> + +namespace testing { +namespace internal { + +GTEST_DECLARE_string_(internal_run_death_test); + +// Names of the flags (needed for parsing Google Test flags). +const char kDeathTestStyleFlag[] = "death_test_style"; +const char kDeathTestUseFork[] = "death_test_use_fork"; +const char kInternalRunDeathTestFlag[] = "internal_run_death_test"; + +#if GTEST_HAS_DEATH_TEST + +// DeathTest is a class that hides much of the complexity of the +// GTEST_DEATH_TEST_ macro. It is abstract; its static Create method +// returns a concrete class that depends on the prevailing death test +// style, as defined by the --gtest_death_test_style and/or +// --gtest_internal_run_death_test flags. 
+ +// In describing the results of death tests, these terms are used with +// the corresponding definitions: +// +// exit status: The integer exit information in the format specified +// by wait(2) +// exit code: The integer code passed to exit(3), _exit(2), or +// returned from main() +class GTEST_API_ DeathTest { + public: + // Create returns false if there was an error determining the + // appropriate action to take for the current death test; for example, + // if the gtest_death_test_style flag is set to an invalid value. + // The LastMessage method will return a more detailed message in that + // case. Otherwise, the DeathTest pointer pointed to by the "test" + // argument is set. If the death test should be skipped, the pointer + // is set to NULL; otherwise, it is set to the address of a new concrete + // DeathTest object that controls the execution of the current test. + static bool Create(const char* statement, const RE* regex, + const char* file, int line, DeathTest** test); + DeathTest(); + virtual ~DeathTest() { } + + // A helper class that aborts a death test when it's deleted. + class ReturnSentinel { + public: + explicit ReturnSentinel(DeathTest* test) : test_(test) { } + ~ReturnSentinel() { test_->Abort(TEST_ENCOUNTERED_RETURN_STATEMENT); } + private: + DeathTest* const test_; + GTEST_DISALLOW_COPY_AND_ASSIGN_(ReturnSentinel); + } GTEST_ATTRIBUTE_UNUSED_; + + // An enumeration of possible roles that may be taken when a death + // test is encountered. EXECUTE means that the death test logic should + // be executed immediately. OVERSEE means that the program should prepare + // the appropriate environment for a child process to execute the death + // test, then wait for it to complete. + enum TestRole { OVERSEE_TEST, EXECUTE_TEST }; + + // An enumeration of the three reasons that a test might be aborted. + enum AbortReason { + TEST_ENCOUNTERED_RETURN_STATEMENT, + TEST_THREW_EXCEPTION, + TEST_DID_NOT_DIE + }; + + // Assumes one of the above roles. 
+ virtual TestRole AssumeRole() = 0; + + // Waits for the death test to finish and returns its status. + virtual int Wait() = 0; + + // Returns true if the death test passed; that is, the test process + // exited during the test, its exit status matches a user-supplied + // predicate, and its stderr output matches a user-supplied regular + // expression. + // The user-supplied predicate may be a macro expression rather + // than a function pointer or functor, or else Wait and Passed could + // be combined. + virtual bool Passed(bool exit_status_ok) = 0; + + // Signals that the death test did not die as expected. + virtual void Abort(AbortReason reason) = 0; + + // Returns a human-readable outcome message regarding the outcome of + // the last death test. + static const char* LastMessage(); + + static void set_last_death_test_message(const String& message); + + private: + // A string containing a description of the outcome of the last death test. + static String last_death_test_message_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(DeathTest); +}; + +// Factory interface for death tests. May be mocked out for testing. +class DeathTestFactory { + public: + virtual ~DeathTestFactory() { } + virtual bool Create(const char* statement, const RE* regex, + const char* file, int line, DeathTest** test) = 0; +}; + +// A concrete DeathTestFactory implementation for normal use. +class DefaultDeathTestFactory : public DeathTestFactory { + public: + virtual bool Create(const char* statement, const RE* regex, + const char* file, int line, DeathTest** test); +}; + +// Returns true if exit_status describes a process that was terminated +// by a signal, or exited normally with a nonzero exit code. +GTEST_API_ bool ExitedUnsuccessfully(int exit_status); + +// Traps C++ exceptions escaping statement and reports them as test +// failures. Note that trapping SEH exceptions is not implemented here. 
+# if GTEST_HAS_EXCEPTIONS +# define GTEST_EXECUTE_DEATH_TEST_STATEMENT_(statement, death_test) \ + try { \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + } catch (const ::std::exception& gtest_exception) { \ + fprintf(\ + stderr, \ + "\n%s: Caught std::exception-derived exception escaping the " \ + "death test statement. Exception message: %s\n", \ + ::testing::internal::FormatFileLocation(__FILE__, __LINE__).c_str(), \ + gtest_exception.what()); \ + fflush(stderr); \ + death_test->Abort(::testing::internal::DeathTest::TEST_THREW_EXCEPTION); \ + } catch (...) { \ + death_test->Abort(::testing::internal::DeathTest::TEST_THREW_EXCEPTION); \ + } + +# else +# define GTEST_EXECUTE_DEATH_TEST_STATEMENT_(statement, death_test) \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement) + +# endif + +// This macro is for implementing ASSERT_DEATH*, EXPECT_DEATH*, +// ASSERT_EXIT*, and EXPECT_EXIT*. +# define GTEST_DEATH_TEST_(statement, predicate, regex, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + const ::testing::internal::RE& gtest_regex = (regex); \ + ::testing::internal::DeathTest* gtest_dt; \ + if (!::testing::internal::DeathTest::Create(#statement, &gtest_regex, \ + __FILE__, __LINE__, &gtest_dt)) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_, __LINE__); \ + } \ + if (gtest_dt != NULL) { \ + ::testing::internal::scoped_ptr< ::testing::internal::DeathTest> \ + gtest_dt_ptr(gtest_dt); \ + switch (gtest_dt->AssumeRole()) { \ + case ::testing::internal::DeathTest::OVERSEE_TEST: \ + if (!gtest_dt->Passed(predicate(gtest_dt->Wait()))) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_, __LINE__); \ + } \ + break; \ + case ::testing::internal::DeathTest::EXECUTE_TEST: { \ + ::testing::internal::DeathTest::ReturnSentinel \ + gtest_sentinel(gtest_dt); \ + GTEST_EXECUTE_DEATH_TEST_STATEMENT_(statement, gtest_dt); \ + gtest_dt->Abort(::testing::internal::DeathTest::TEST_DID_NOT_DIE); \ + break; \ + } \ + default: \ + break; \ 
+ } \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_, __LINE__): \ + fail(::testing::internal::DeathTest::LastMessage()) +// The symbol "fail" here expands to something into which a message +// can be streamed. + +// A class representing the parsed contents of the +// --gtest_internal_run_death_test flag, as it existed when +// RUN_ALL_TESTS was called. +class InternalRunDeathTestFlag { + public: + InternalRunDeathTestFlag(const String& a_file, + int a_line, + int an_index, + int a_write_fd) + : file_(a_file), line_(a_line), index_(an_index), + write_fd_(a_write_fd) {} + + ~InternalRunDeathTestFlag() { + if (write_fd_ >= 0) + posix::Close(write_fd_); + } + + String file() const { return file_; } + int line() const { return line_; } + int index() const { return index_; } + int write_fd() const { return write_fd_; } + + private: + String file_; + int line_; + int index_; + int write_fd_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(InternalRunDeathTestFlag); +}; + +// Returns a newly created InternalRunDeathTestFlag object with fields +// initialized from the GTEST_FLAG(internal_run_death_test) flag if +// the flag is specified; otherwise returns NULL. +InternalRunDeathTestFlag* ParseInternalRunDeathTestFlag(); + +#else // GTEST_HAS_DEATH_TEST + +// This macro is used for implementing macros such as +// EXPECT_DEATH_IF_SUPPORTED and ASSERT_DEATH_IF_SUPPORTED on systems where +// death tests are not supported. Those macros must compile on such systems +// iff EXPECT_DEATH and ASSERT_DEATH compile with the same parameters on +// systems that support death tests. This allows one to write such a macro +// on a system that does not support death tests and be sure that it will +// compile on a death-test supporting system. +// +// Parameters: +// statement - A statement that a macro such as EXPECT_DEATH would test +// for program termination. 
This macro has to make sure this +// statement is compiled but not executed, to ensure that +// EXPECT_DEATH_IF_SUPPORTED compiles with a certain +// parameter iff EXPECT_DEATH compiles with it. +// regex - A regex that a macro such as EXPECT_DEATH would use to test +// the output of statement. This parameter has to be +// compiled but not evaluated by this macro, to ensure that +// this macro only accepts expressions that a macro such as +// EXPECT_DEATH would accept. +// terminator - Must be an empty statement for EXPECT_DEATH_IF_SUPPORTED +// and a return statement for ASSERT_DEATH_IF_SUPPORTED. +// This ensures that ASSERT_DEATH_IF_SUPPORTED will not +// compile inside functions where ASSERT_DEATH doesn't +// compile. +// +// The branch that has an always false condition is used to ensure that +// statement and regex are compiled (and thus syntactically correct) but +// never executed. The unreachable code macro protects the terminator +// statement from generating an 'unreachable code' warning in case +// statement unconditionally returns or throws. The Message constructor at +// the end allows the syntax of streaming additional messages into the +// macro, for compilational compatibility with EXPECT_DEATH/ASSERT_DEATH. +# define GTEST_UNSUPPORTED_DEATH_TEST_(statement, regex, terminator) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + GTEST_LOG_(WARNING) \ + << "Death tests are not supported on this platform.\n" \ + << "Statement '" #statement "' cannot be verified."; \ + } else if (::testing::internal::AlwaysFalse()) { \ + ::testing::internal::RE::PartialMatch(".*", (regex)); \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + terminator; \ + } else \ + ::testing::Message() + +#endif // GTEST_HAS_DEATH_TEST + +} // namespace internal +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_DEATH_TEST_INTERNAL_H_ + +namespace testing { + +// This flag controls the style of death tests. 
Valid values are "threadsafe", +// meaning that the death test child process will re-execute the test binary +// from the start, running only a single death test, or "fast", +// meaning that the child process will execute the test logic immediately +// after forking. +GTEST_DECLARE_string_(death_test_style); + +#if GTEST_HAS_DEATH_TEST + +// The following macros are useful for writing death tests. + +// Here's what happens when an ASSERT_DEATH* or EXPECT_DEATH* is +// executed: +// +// 1. It generates a warning if there is more than one active +// thread. This is because it's safe to fork() or clone() only +// when there is a single thread. +// +// 2. The parent process clone()s a sub-process and runs the death +// test in it; the sub-process exits with code 0 at the end of the +// death test, if it hasn't exited already. +// +// 3. The parent process waits for the sub-process to terminate. +// +// 4. The parent process checks the exit code and error message of +// the sub-process. +// +// Examples: +// +// ASSERT_DEATH(server.SendMessage(56, "Hello"), "Invalid port number"); +// for (int i = 0; i < 5; i++) { +// EXPECT_DEATH(server.ProcessRequest(i), +// "Invalid request .* in ProcessRequest()") +// << "Failed to die on request " << i; +// } +// +// ASSERT_EXIT(server.ExitNow(), ::testing::ExitedWithCode(0), "Exiting"); +// +// bool KilledBySIGHUP(int exit_code) { +// return WIFSIGNALED(exit_code) && WTERMSIG(exit_code) == SIGHUP; +// } +// +// ASSERT_EXIT(client.HangUpServer(), KilledBySIGHUP, "Hanging up!"); +// +// On the regular expressions used in death tests: +// +// On POSIX-compliant systems (*nix), we use the <regex.h> library, +// which uses the POSIX extended regex syntax. +// +// On other platforms (e.g. Windows), we only support a simple regex +// syntax implemented as part of Google Test. 
This limited +// implementation should be enough most of the time when writing +// death tests; though it lacks many features you can find in PCRE +// or POSIX extended regex syntax. For example, we don't support +// union ("x|y"), grouping ("(xy)"), brackets ("[xy]"), and +// repetition count ("x{5,7}"), among others. +// +// Below is the syntax that we do support. We chose it to be a +// subset of both PCRE and POSIX extended regex, so it's easy to +// learn wherever you come from. In the following: 'A' denotes a +// literal character, period (.), or a single \\ escape sequence; +// 'x' and 'y' denote regular expressions; 'm' and 'n' are for +// natural numbers. +// +// c matches any literal character c +// \\d matches any decimal digit +// \\D matches any character that's not a decimal digit +// \\f matches \f +// \\n matches \n +// \\r matches \r +// \\s matches any ASCII whitespace, including \n +// \\S matches any character that's not a whitespace +// \\t matches \t +// \\v matches \v +// \\w matches any letter, _, or decimal digit +// \\W matches any character that \\w doesn't match +// \\c matches any literal character c, which must be a punctuation +// . matches any single character except \n +// A? matches 0 or 1 occurrences of A +// A* matches 0 or many occurrences of A +// A+ matches 1 or many occurrences of A +// ^ matches the beginning of a string (not that of each line) +// $ matches the end of a string (not that of each line) +// xy matches x followed by y +// +// If you accidentally use PCRE or POSIX extended regex features +// not implemented by us, you will get a run-time failure. In that +// case, please try to rewrite your regular expression within the +// above syntax. +// +// This implementation is *not* meant to be as highly tuned or robust +// as a compiled regex library, but should perform well enough for a +// death test, which already incurs significant overhead by launching +// a child process. 
+// +// Known caveats: +// +// A "threadsafe" style death test obtains the path to the test +// program from argv[0] and re-executes it in the sub-process. For +// simplicity, the current implementation doesn't search the PATH +// when launching the sub-process. This means that the user must +// invoke the test program via a path that contains at least one +// path separator (e.g. path/to/foo_test and +// /absolute/path/to/bar_test are fine, but foo_test is not). This +// is rarely a problem as people usually don't put the test binary +// directory in PATH. +// +// TODO(wan@google.com): make thread-safe death tests search the PATH. + +// Asserts that a given statement causes the program to exit, with an +// integer exit status that satisfies predicate, and emitting error output +// that matches regex. +# define ASSERT_EXIT(statement, predicate, regex) \ + GTEST_DEATH_TEST_(statement, predicate, regex, GTEST_FATAL_FAILURE_) + +// Like ASSERT_EXIT, but continues on to successive tests in the +// test case, if any: +# define EXPECT_EXIT(statement, predicate, regex) \ + GTEST_DEATH_TEST_(statement, predicate, regex, GTEST_NONFATAL_FAILURE_) + +// Asserts that a given statement causes the program to exit, either by +// explicitly exiting with a nonzero exit code or being killed by a +// signal, and emitting error output that matches regex. +# define ASSERT_DEATH(statement, regex) \ + ASSERT_EXIT(statement, ::testing::internal::ExitedUnsuccessfully, regex) + +// Like ASSERT_DEATH, but continues on to successive tests in the +// test case, if any: +# define EXPECT_DEATH(statement, regex) \ + EXPECT_EXIT(statement, ::testing::internal::ExitedUnsuccessfully, regex) + +// Two predicate classes that can be used in {ASSERT,EXPECT}_EXIT*: + +// Tests that an exit code describes a normal exit with a given exit code. 
+class GTEST_API_ ExitedWithCode { + public: + explicit ExitedWithCode(int exit_code); + bool operator()(int exit_status) const; + private: + // No implementation - assignment is unsupported. + void operator=(const ExitedWithCode& other); + + const int exit_code_; +}; + +# if !GTEST_OS_WINDOWS +// Tests that an exit code describes an exit due to termination by a +// given signal. +class GTEST_API_ KilledBySignal { + public: + explicit KilledBySignal(int signum); + bool operator()(int exit_status) const; + private: + const int signum_; +}; +# endif // !GTEST_OS_WINDOWS + +// EXPECT_DEBUG_DEATH asserts that the given statements die in debug mode. +// The death testing framework causes this to have interesting semantics, +// since the side effects of the call are only visible in opt mode, and not +// in debug mode. +// +// In practice, this can be used to test functions that utilize the +// LOG(DFATAL) macro using the following style: +// +// int DieInDebugOr12(int* sideeffect) { +// if (sideeffect) { +// *sideeffect = 12; +// } +// LOG(DFATAL) << "death"; +// return 12; +// } +// +// TEST(TestCase, TestDieOr12WorksInDbgAndOpt) { +// int sideeffect = 0; +// // Only asserts in dbg. +// EXPECT_DEBUG_DEATH(DieInDebugOr12(&sideeffect), "death"); +// +// #ifdef NDEBUG +// // opt-mode has sideeffect visible. +// EXPECT_EQ(12, sideeffect); +// #else +// // dbg-mode no visible sideeffect. +// EXPECT_EQ(0, sideeffect); +// #endif +// } +// +// This will assert that DieInDebugOr12() crashes in debug +// mode, usually due to a DCHECK or LOG(DFATAL), but returns the +// appropriate fallback value (12 in this case) in opt mode. If you +// need to test that a function has appropriate side-effects in opt +// mode, include assertions against the side-effects. A general +// pattern for this is: +// +// EXPECT_DEBUG_DEATH({ +// // Side-effects here will have an effect after this statement in +// // opt mode, but none in debug mode. 
+// EXPECT_EQ(12, DieInDebugOr12(&sideeffect)); +// }, "death"); +// +# ifdef NDEBUG + +# define EXPECT_DEBUG_DEATH(statement, regex) \ + do { statement; } while (::testing::internal::AlwaysFalse()) + +# define ASSERT_DEBUG_DEATH(statement, regex) \ + do { statement; } while (::testing::internal::AlwaysFalse()) + +# else + +# define EXPECT_DEBUG_DEATH(statement, regex) \ + EXPECT_DEATH(statement, regex) + +# define ASSERT_DEBUG_DEATH(statement, regex) \ + ASSERT_DEATH(statement, regex) + +# endif // NDEBUG for EXPECT_DEBUG_DEATH +#endif // GTEST_HAS_DEATH_TEST + +// EXPECT_DEATH_IF_SUPPORTED(statement, regex) and +// ASSERT_DEATH_IF_SUPPORTED(statement, regex) expand to real death tests if +// death tests are supported; otherwise they just issue a warning. This is +// useful when you are combining death test assertions with normal test +// assertions in one test. +#if GTEST_HAS_DEATH_TEST +# define EXPECT_DEATH_IF_SUPPORTED(statement, regex) \ + EXPECT_DEATH(statement, regex) +# define ASSERT_DEATH_IF_SUPPORTED(statement, regex) \ + ASSERT_DEATH(statement, regex) +#else +# define EXPECT_DEATH_IF_SUPPORTED(statement, regex) \ + GTEST_UNSUPPORTED_DEATH_TEST_(statement, regex, ) +# define ASSERT_DEATH_IF_SUPPORTED(statement, regex) \ + GTEST_UNSUPPORTED_DEATH_TEST_(statement, regex, return) +#endif + +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_GTEST_DEATH_TEST_H_ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. 
+// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) +// +// The Google C++ Testing Framework (Google Test) +// +// This header file defines the Message class. +// +// IMPORTANT NOTE: Due to limitation of the C++ language, we have to +// leave some internal implementation details in this header file. +// They are clearly marked by comments like this: +// +// // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +// +// Such code is NOT meant to be used by a user directly, and is subject +// to CHANGE WITHOUT NOTICE. Therefore DO NOT DEPEND ON IT in a user +// program! + +#ifndef GTEST_INCLUDE_GTEST_GTEST_MESSAGE_H_ +#define GTEST_INCLUDE_GTEST_GTEST_MESSAGE_H_ + +#include <limits> + + +namespace testing { + +// The Message class works like an ostream repeater. +// +// Typical usage: +// +// 1. You stream a bunch of values to a Message object. +// It will remember the text in a stringstream. +// 2. Then you stream the Message object to an ostream. 
+// This causes the text in the Message to be streamed +// to the ostream. +// +// For example: +// +// testing::Message foo; +// foo << 1 << " != " << 2; +// std::cout << foo; +// +// will print "1 != 2". +// +// Message is not intended to be inherited from. In particular, its +// destructor is not virtual. +// +// Note that stringstream behaves differently in gcc and in MSVC. You +// can stream a NULL char pointer to it in the former, but not in the +// latter (it causes an access violation if you do). The Message +// class hides this difference by treating a NULL char pointer as +// "(null)". +class GTEST_API_ Message { + private: + // The type of basic IO manipulators (endl, ends, and flush) for + // narrow streams. + typedef std::ostream& (*BasicNarrowIoManip)(std::ostream&); + + public: + // Constructs an empty Message. + // We allocate the stringstream separately because otherwise each use of + // ASSERT/EXPECT in a procedure adds over 200 bytes to the procedure's + // stack frame leading to huge stack frames in some cases; gcc does not reuse + // the stack space. + Message() : ss_(new ::std::stringstream) { + // By default, we want there to be enough precision when printing + // a double to a Message. + *ss_ << std::setprecision(std::numeric_limits<double>::digits10 + 2); + } + + // Copy constructor. + Message(const Message& msg) : ss_(new ::std::stringstream) { // NOLINT + *ss_ << msg.GetString(); + } + + // Constructs a Message from a C-string. + explicit Message(const char* str) : ss_(new ::std::stringstream) { + *ss_ << str; + } + +#if GTEST_OS_SYMBIAN + // Streams a value (either a pointer or not) to this object. + template <typename T> + inline Message& operator <<(const T& value) { + StreamHelper(typename internal::is_pointer<T>::type(), value); + return *this; + } +#else + // Streams a non-pointer value to this object. 
+ template <typename T> + inline Message& operator <<(const T& val) { + ::GTestStreamToHelper(ss_.get(), val); + return *this; + } + + // Streams a pointer value to this object. + // + // This function is an overload of the previous one. When you + // stream a pointer to a Message, this definition will be used as it + // is more specialized. (The C++ Standard, section + // [temp.func.order].) If you stream a non-pointer, then the + // previous definition will be used. + // + // The reason for this overload is that streaming a NULL pointer to + // ostream is undefined behavior. Depending on the compiler, you + // may get "0", "(nil)", "(null)", or an access violation. To + // ensure consistent result across compilers, we always treat NULL + // as "(null)". + template <typename T> + inline Message& operator <<(T* const& pointer) { // NOLINT + if (pointer == NULL) { + *ss_ << "(null)"; + } else { + ::GTestStreamToHelper(ss_.get(), pointer); + } + return *this; + } +#endif // GTEST_OS_SYMBIAN + + // Since the basic IO manipulators are overloaded for both narrow + // and wide streams, we have to provide this specialized definition + // of operator <<, even though its body is the same as the + // templatized version above. Without this definition, streaming + // endl or other basic IO manipulators to Message will confuse the + // compiler. + Message& operator <<(BasicNarrowIoManip val) { + *ss_ << val; + return *this; + } + + // Instead of 1/0, we want to see true/false for bool values. + Message& operator <<(bool b) { + return *this << (b ? "true" : "false"); + } + + // These two overloads allow streaming a wide C string to a Message + // using the UTF-8 encoding. 
+ Message& operator <<(const wchar_t* wide_c_str) { + return *this << internal::String::ShowWideCString(wide_c_str); + } + Message& operator <<(wchar_t* wide_c_str) { + return *this << internal::String::ShowWideCString(wide_c_str); + } + +#if GTEST_HAS_STD_WSTRING + // Converts the given wide string to a narrow string using the UTF-8 + // encoding, and streams the result to this Message object. + Message& operator <<(const ::std::wstring& wstr); +#endif // GTEST_HAS_STD_WSTRING + +#if GTEST_HAS_GLOBAL_WSTRING + // Converts the given wide string to a narrow string using the UTF-8 + // encoding, and streams the result to this Message object. + Message& operator <<(const ::wstring& wstr); +#endif // GTEST_HAS_GLOBAL_WSTRING + + // Gets the text streamed to this object so far as a String. + // Each '\0' character in the buffer is replaced with "\\0". + // + // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + internal::String GetString() const { + return internal::StringStreamToString(ss_.get()); + } + + private: + +#if GTEST_OS_SYMBIAN + // These are needed as the Nokia Symbian Compiler cannot decide between + // const T& and const T* in a function template. The Nokia compiler _can_ + // decide between class template specializations for T and T*, so a + // tr1::type_traits-like is_pointer works, and we can overload on that. + template <typename T> + inline void StreamHelper(internal::true_type /*dummy*/, T* pointer) { + if (pointer == NULL) { + *ss_ << "(null)"; + } else { + ::GTestStreamToHelper(ss_.get(), pointer); + } + } + template <typename T> + inline void StreamHelper(internal::false_type /*dummy*/, const T& value) { + ::GTestStreamToHelper(ss_.get(), value); + } +#endif // GTEST_OS_SYMBIAN + + // We'll hold the text streamed to this object here. + const internal::scoped_ptr< ::std::stringstream> ss_; + + // We declare (but don't implement) this to prevent the compiler + // from implementing the assignment operator. 
+ void operator=(const Message&); +}; + +// Streams a Message to an ostream. +inline std::ostream& operator <<(std::ostream& os, const Message& sb) { + return os << sb.GetString(); +} + +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_GTEST_MESSAGE_H_ +// This file was GENERATED by command: +// pump.py gtest-param-test.h.pump +// DO NOT EDIT BY HAND!!! + +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+// +// Authors: vladl@google.com (Vlad Losev) +// +// Macros and functions for implementing parameterized tests +// in Google C++ Testing Framework (Google Test) +// +// This file is generated by a SCRIPT. DO NOT EDIT BY HAND! +// +#ifndef GTEST_INCLUDE_GTEST_GTEST_PARAM_TEST_H_ +#define GTEST_INCLUDE_GTEST_GTEST_PARAM_TEST_H_ + + +// Value-parameterized tests allow you to test your code with different +// parameters without writing multiple copies of the same test. +// +// Here is how you use value-parameterized tests: + +#if 0 + +// To write value-parameterized tests, first you should define a fixture +// class. It is usually derived from testing::TestWithParam<T> (see below for +// another inheritance scheme that's sometimes useful in more complicated +// class hierarchies), where T is the type of your parameter values. +// TestWithParam<T> is itself derived from testing::Test. T can be any +// copyable type. If it's a raw pointer, you are responsible for managing the +// lifespan of the pointed values. + +class FooTest : public ::testing::TestWithParam<const char*> { + // You can implement all the usual class fixture members here. +}; + +// Then, use the TEST_P macro to define as many parameterized tests +// for this fixture as you want. The _P suffix is for "parameterized" +// or "pattern", whichever you prefer to think. + +TEST_P(FooTest, DoesBlah) { + // Inside a test, access the test parameter with the GetParam() method + // of the TestWithParam<T> class: + EXPECT_TRUE(foo.Blah(GetParam())); + ... +} + +TEST_P(FooTest, HasBlahBlah) { + ... +} + +// Finally, you can use INSTANTIATE_TEST_CASE_P to instantiate the test +// case with any set of parameters you want. Google Test defines a number +// of functions for generating test parameters. They return what we call +// (surprise!) parameter generators. Here is a summary of them, which +// are all in the testing namespace: +// +// +// Range(begin, end [, step]) - Yields values {begin, begin+step, +// begin+step+step, ...}. 
The values do not +// include end. step defaults to 1. +// Values(v1, v2, ..., vN) - Yields values {v1, v2, ..., vN}. +// ValuesIn(container) - Yields values from a C-style array, an STL +// ValuesIn(begin,end) container, or an iterator range [begin, end). +// Bool() - Yields sequence {false, true}. +// Combine(g1, g2, ..., gN) - Yields all combinations (the Cartesian product +// for the math savvy) of the values generated +// by the N generators. +// +// For more details, see comments at the definitions of these functions below +// in this file. +// +// The following statement will instantiate tests from the FooTest test case +// each with parameter values "meeny", "miny", and "moe". + +INSTANTIATE_TEST_CASE_P(InstantiationName, + FooTest, + Values("meeny", "miny", "moe")); + +// To distinguish different instances of the pattern (yes, you +// can instantiate it more than once), the first argument to the +// INSTANTIATE_TEST_CASE_P macro is a prefix that will be added to the +// actual test case name. Remember to pick unique prefixes for different +// instantiations. The tests from the instantiation above will have +// these names: +// +// * InstantiationName/FooTest.DoesBlah/0 for "meeny" +// * InstantiationName/FooTest.DoesBlah/1 for "miny" +// * InstantiationName/FooTest.DoesBlah/2 for "moe" +// * InstantiationName/FooTest.HasBlahBlah/0 for "meeny" +// * InstantiationName/FooTest.HasBlahBlah/1 for "miny" +// * InstantiationName/FooTest.HasBlahBlah/2 for "moe" +// +// You can use these names in --gtest_filter. 
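// Illustration only, not part of Google Test: a minimal sketch of two ideas
// from the comment above, assuming C++11. RangeValues mimics the half-open
// sequence that Range(begin, end [, step]) is described as yielding, and
// InstanceName assembles the Prefix/TestCase.TestName/Index pattern shown in
// the list of names. Both helpers and the sketch namespace are hypothetical;
// the real generators are implemented differently.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

namespace sketch {

// Hypothetical stand-in for Range(begin, end, step): collects
// {begin, begin+step, ...} and stops before reaching end.
template <typename T>
std::vector<T> RangeValues(T begin, T end, T step = T(1)) {
  assert(step > T(0));  // this sketch only handles increasing ranges
  std::vector<T> values;
  for (T v = begin; v < end; v = v + step) values.push_back(v);
  return values;
}

// Hypothetical helper: builds "Prefix/TestCase.TestName/Index", the name
// pattern listed in the comment above.
inline std::string InstanceName(const std::string& prefix,
                                const std::string& test_case,
                                const std::string& test_name,
                                std::size_t index) {
  return prefix + "/" + test_case + "." + test_name + "/" +
         std::to_string(index);
}

}  // namespace sketch
```

// For instance, sketch::InstanceName("InstantiationName", "FooTest",
// "DoesBlah", 1) gives the name listed above for "miny".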
+// +// This statement will instantiate all tests from FooTest again, each +// with parameter values "cat" and "dog": + +const char* pets[] = {"cat", "dog"}; +INSTANTIATE_TEST_CASE_P(AnotherInstantiationName, FooTest, ValuesIn(pets)); + +// The tests from the instantiation above will have these names: +// +// * AnotherInstantiationName/FooTest.DoesBlah/0 for "cat" +// * AnotherInstantiationName/FooTest.DoesBlah/1 for "dog" +// * AnotherInstantiationName/FooTest.HasBlahBlah/0 for "cat" +// * AnotherInstantiationName/FooTest.HasBlahBlah/1 for "dog" +// +// Please note that INSTANTIATE_TEST_CASE_P will instantiate all tests +// in the given test case, whether their definitions come before or +// AFTER the INSTANTIATE_TEST_CASE_P statement. +// +// Please also note that generator expressions (including parameters to the +// generators) are evaluated in InitGoogleTest(), after main() has started. +// This allows the user, on the one hand, to adjust generator parameters in +// order to dynamically determine a set of tests to run and, on the other +// hand, to inspect the generated tests with the Google Test reflection API +// before RUN_ALL_TESTS() is executed. +// +// You can see samples/sample7_unittest.cc and samples/sample8_unittest.cc +// for more examples. +// +// In the future, we plan to publish the API for defining new parameter +// generators. But for now this interface remains part of the internal +// implementation and is subject to change. +// +// +// A parameterized test fixture must be derived from testing::Test and from +// testing::WithParamInterface<T>, where T is the type of the parameter +// values. Inheriting from TestWithParam<T> satisfies that requirement because +// TestWithParam<T> inherits from both Test and WithParamInterface. In more +// complicated hierarchies, however, it is occasionally useful to inherit +// separately from Test and WithParamInterface. 
For example: + +class BaseTest : public ::testing::Test { + // You can inherit all the usual members for a non-parameterized test + // fixture here. +}; + +class DerivedTest : public BaseTest, public ::testing::WithParamInterface<int> { + // The usual test fixture members go here too. +}; + +TEST_F(BaseTest, HasFoo) { + // This is an ordinary non-parameterized test. +} + +TEST_P(DerivedTest, DoesBlah) { + // GetParam works just the same here as if you inherit from TestWithParam. + EXPECT_TRUE(foo.Blah(GetParam())); +} + +#endif // 0 + + +#if !GTEST_OS_SYMBIAN +# include <utility> +#endif + +// scripts/fuse_gtest.py depends on gtest's own header being #included +// *unconditionally*. Therefore these #includes cannot be moved +// inside #if GTEST_HAS_PARAM_TEST. +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: vladl@google.com (Vlad Losev) + +// Type and function utilities for implementing parameterized tests. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_H_ + +#include <iterator> +#include <utility> +#include <vector> + +// scripts/fuse_gtest.py depends on gtest's own header being #included +// *unconditionally*. Therefore these #includes cannot be moved +// inside #if GTEST_HAS_PARAM_TEST. +// Copyright 2003 Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Authors: Dan Egnor (egnor@google.com) +// +// A "smart" pointer type with reference tracking. Every pointer to a +// particular object is kept on a circular linked list. When the last pointer +// to an object is destroyed or reassigned, the object is deleted. +// +// Used properly, this deletes the object when the last reference goes away. +// There are several caveats: +// - Like all reference counting schemes, cycles lead to leaks. +// - Each smart pointer is actually two pointers (8 bytes instead of 4). +// - Every time a pointer is assigned, the entire list of pointers to that +// object is traversed. This class is therefore NOT SUITABLE when there +// will often be more than two or three pointers to a particular object. +// - References are only tracked as long as linked_ptr<> objects are copied. +// If a linked_ptr<> is converted to a raw pointer and back, BAD THINGS +// will happen (double deletion). +// +// A good use of this class is storing object references in STL containers. +// You can safely put linked_ptr<> in a vector<>. +// Other uses may not be as good. 
+// +// Note: If you use an incomplete type with linked_ptr<>, the class +// *containing* linked_ptr<> must have a constructor and destructor (even +// if they do nothing!). +// +// Bill Gibbons suggested we use something like this. +// +// Thread Safety: +// Unlike other linked_ptr implementations, in this implementation +// a linked_ptr object is thread-safe in the sense that: +// - it's safe to copy linked_ptr objects concurrently, +// - it's safe to copy *from* a linked_ptr and read its underlying +// raw pointer (e.g. via get()) concurrently, and +// - it's safe to write to two linked_ptrs that point to the same +// shared object concurrently. +// TODO(wan@google.com): rename this to safe_linked_ptr to avoid +// confusion with normal linked_ptr. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_LINKED_PTR_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_LINKED_PTR_H_ + +#include <stdlib.h> +#include <assert.h> + + +namespace testing { +namespace internal { + +// Protects copying of all linked_ptr objects. +GTEST_API_ GTEST_DECLARE_STATIC_MUTEX_(g_linked_ptr_mutex); + +// This is used internally by all instances of linked_ptr<>. It needs to be +// a non-template class because different types of linked_ptr<> can refer to +// the same object (linked_ptr<Superclass>(obj) vs linked_ptr<Subclass>(obj)). +// So, it needs to be possible for different types of linked_ptr to participate +// in the same circular linked list, so we need a single class type here. +// +// DO NOT USE THIS CLASS DIRECTLY YOURSELF. Use linked_ptr<T>. +class linked_ptr_internal { + public: + // Create a new circle that includes only this instance. + void join_new() { + next_ = this; + } + + // Many linked_ptr operations may change p.link_ for some linked_ptr + // variable p in the same circle as this object. Therefore we need + // to prevent two such operations from occurring concurrently. + // + // Note that different types of linked_ptr objects can coexist in a + // circle (e.g. linked_ptr<Base>, linked_ptr<Derived1>, and + // linked_ptr<Derived2>). 
Therefore we must use a single mutex to + // protect all linked_ptr objects. This can create serious + // contention in production code, but is acceptable in a testing + // framework. + + // Join an existing circle. + // L < g_linked_ptr_mutex + void join(linked_ptr_internal const* ptr) { + MutexLock lock(&g_linked_ptr_mutex); + + linked_ptr_internal const* p = ptr; + while (p->next_ != ptr) p = p->next_; + p->next_ = this; + next_ = ptr; + } + + // Leave whatever circle we're part of. Returns true if we were the + // last member of the circle. Once this is done, you can join() another. + // L < g_linked_ptr_mutex + bool depart() { + MutexLock lock(&g_linked_ptr_mutex); + + if (next_ == this) return true; + linked_ptr_internal const* p = next_; + while (p->next_ != this) p = p->next_; + p->next_ = next_; + return false; + } + + private: + mutable linked_ptr_internal const* next_; +}; + +template <typename T> +class linked_ptr { + public: + typedef T element_type; + + // Take over ownership of a raw pointer. This should happen as soon as + // possible after the object is created. + explicit linked_ptr(T* ptr = NULL) { capture(ptr); } + ~linked_ptr() { depart(); } + + // Copy an existing linked_ptr<>, adding ourselves to the list of references. + template <typename U> linked_ptr(linked_ptr<U> const& ptr) { copy(&ptr); } + linked_ptr(linked_ptr const& ptr) { // NOLINT + assert(&ptr != this); + copy(&ptr); + } + + // Assignment releases the old value and acquires the new. + template <typename U> linked_ptr& operator=(linked_ptr<U> const& ptr) { + depart(); + copy(&ptr); + return *this; + } + + linked_ptr& operator=(linked_ptr const& ptr) { + if (&ptr != this) { + depart(); + copy(&ptr); + } + return *this; + } + + // Smart pointer members. 
+ void reset(T* ptr = NULL) { + depart(); + capture(ptr); + } + T* get() const { return value_; } + T* operator->() const { return value_; } + T& operator*() const { return *value_; } + + bool operator==(T* p) const { return value_ == p; } + bool operator!=(T* p) const { return value_ != p; } + template <typename U> + bool operator==(linked_ptr<U> const& ptr) const { + return value_ == ptr.get(); + } + template <typename U> + bool operator!=(linked_ptr<U> const& ptr) const { + return value_ != ptr.get(); + } + + private: + template <typename U> + friend class linked_ptr; + + T* value_; + linked_ptr_internal link_; + + void depart() { + if (link_.depart()) delete value_; + } + + void capture(T* ptr) { + value_ = ptr; + link_.join_new(); + } + + template <typename U> void copy(linked_ptr<U> const* ptr) { + value_ = ptr->get(); + if (value_) + link_.join(&ptr->link_); + else + link_.join_new(); + } +}; + +template<typename T> inline +bool operator==(T* ptr, const linked_ptr<T>& x) { + return ptr == x.get(); +} + +template<typename T> inline +bool operator!=(T* ptr, const linked_ptr<T>& x) { + return ptr != x.get(); +} + +// A function to convert T* into linked_ptr<T> +// Doing e.g. make_linked_ptr(new FooBarBaz<type>(arg)) is a shorter notation +// for linked_ptr<FooBarBaz<type> >(new FooBarBaz<type>(arg)) +template <typename T> +linked_ptr<T> make_linked_ptr(T* ptr) { + return linked_ptr<T>(ptr); +} + +} // namespace internal +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_LINKED_PTR_H_ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. 
+// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) + +// Google Test - The Google C++ Testing Framework +// +// This file implements a universal value printer that can print a +// value of any type T: +// +// void ::testing::internal::UniversalPrinter<T>::Print(value, ostream_ptr); +// +// A user can teach this function how to print a class type T by +// defining either operator<<() or PrintTo() in the namespace that +// defines T. More specifically, the FIRST defined function in the +// following list will be used (assuming T is defined in namespace +// foo): +// +// 1. foo::PrintTo(const T&, ostream*) +// 2. operator<<(ostream&, const T&) defined in either foo or the +// global namespace. +// +// If none of the above is defined, it will print the debug string of +// the value if it is a protocol buffer, or print the raw bytes in the +// value otherwise. 
+// +// To aid debugging: when T is a reference type, the address of the +// value is also printed; when T is a (const) char pointer, both the +// pointer value and the NUL-terminated string it points to are +// printed. +// +// We also provide some convenient wrappers: +// +// // Prints a value to a string. For a (const or not) char +// // pointer, the NUL-terminated string (but not the pointer) is +// // printed. +// std::string ::testing::PrintToString(const T& value); +// +// // Prints a value tersely: for a reference type, the referenced +// // value (but not the address) is printed; for a (const or not) char +// // pointer, the NUL-terminated string (but not the pointer) is +// // printed. +// void ::testing::internal::UniversalTersePrint(const T& value, ostream*); +// +// // Prints value using the type inferred by the compiler. The difference +// // from UniversalTersePrint() is that this function prints both the +// // pointer and the NUL-terminated string for a (const or not) char pointer. +// void ::testing::internal::UniversalPrint(const T& value, ostream*); +// +// // Prints the fields of a tuple tersely to a string vector, one +// // element for each field. Tuple support must be enabled in +// // gtest-port.h. +// std::vector<string> UniversalTersePrintTupleFieldsToStrings( +// const Tuple& value); +// +// Known limitation: +// +// The print primitives print the elements of an STL-style container +// using the compiler-inferred type of *iter where iter is a +// const_iterator of the container. When const_iterator is an input +// iterator but not a forward iterator, this inferred type may not +// match value_type, and the print output may be incorrect. In +// practice, this is rarely a problem as for most containers +// const_iterator is a forward iterator. We'll fix this if there's an +// actual need for it. Note that this fix cannot rely on value_type +// being defined as many user-defined container types don't have +// value_type. 
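// Illustration only, not part of Google Test: a minimal sketch, assuming
// C++11, of the fallback described in the comment above, where a value is
// printed with operator<< when one exists and as raw bytes otherwise. The
// names IsStreamable, PrintSketch, ToStringSketch and Opaque are
// hypothetical. The sketch uses SFINAE detection, which is a different
// technique from the name look-up approach this header relies on, and it
// does not model PrintTo(), protocol buffers, or enums. The byte format is
// made up for the example.

```cpp
#include <cassert>
#include <cstddef>
#include <iomanip>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>

namespace sketch {

// True when `os << value` compiles for a const T& (expression SFINAE).
template <typename T>
class IsStreamable {
  template <typename U>
  static auto Test(int)
      -> decltype(std::declval<std::ostream&>() << std::declval<const U&>(),
                  std::true_type());
  template <typename U>
  static std::false_type Test(...);

 public:
  static const bool value = decltype(Test<T>(0))::value;
};

// Preferred path: T has a usable operator<<.
template <typename T>
typename std::enable_if<IsStreamable<T>::value>::type PrintSketch(
    const T& value, std::ostream* os) {
  assert(os != nullptr);
  *os << value;
}

// Fallback path: dump the object's bytes in hex.
template <typename T>
typename std::enable_if<!IsStreamable<T>::value>::type PrintSketch(
    const T& value, std::ostream* os) {
  assert(os != nullptr);
  const unsigned char* bytes = reinterpret_cast<const unsigned char*>(&value);
  std::ostringstream hex;  // local stream, so *os keeps its format flags
  hex << std::hex << std::uppercase << std::setfill('0');
  for (std::size_t i = 0; i < sizeof(value); ++i) {
    if (i != 0) hex << ' ';
    hex << std::setw(2) << static_cast<int>(bytes[i]);
  }
  *os << sizeof(value) << "-byte object <" << hex.str() << '>';
}

// Convenience wrapper that prints into a string.
template <typename T>
std::string ToStringSketch(const T& value) {
  std::ostringstream os;
  PrintSketch(value, &os);
  return os.str();
}

// Example type with no operator<< (hypothetical).
struct Opaque {
  unsigned char a;
  unsigned char b;
};

}  // namespace sketch
```

// With these definitions, an int or std::string goes through operator<<,
// while a sketch::Opaque takes the byte-dump path.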
+ +#ifndef GTEST_INCLUDE_GTEST_GTEST_PRINTERS_H_ +#define GTEST_INCLUDE_GTEST_GTEST_PRINTERS_H_ + +#include <ostream> // NOLINT +#include <sstream> +#include <string> +#include <utility> +#include <vector> + +namespace testing { + +// Definitions in the 'internal' and 'internal2' name spaces are +// subject to change without notice. DO NOT USE THEM IN USER CODE! +namespace internal2 { + +// Prints the given number of bytes in the given object to the given +// ostream. +GTEST_API_ void PrintBytesInObjectTo(const unsigned char* obj_bytes, + size_t count, + ::std::ostream* os); + +// For selecting which printer to use when a given type has neither << +// nor PrintTo(). +enum TypeKind { + kProtobuf, // a protobuf type + kConvertibleToInteger, // a type implicitly convertible to BiggestInt + // (e.g. a named or unnamed enum type) + kOtherType // anything else +}; + +// TypeWithoutFormatter<T, kTypeKind>::PrintValue(value, os) is called +// by the universal printer to print a value of type T when neither +// operator<< nor PrintTo() is defined for T, where kTypeKind is the +// "kind" of T as defined by enum TypeKind. +template <typename T, TypeKind kTypeKind> +class TypeWithoutFormatter { + public: + // This default version is called when kTypeKind is kOtherType. + static void PrintValue(const T& value, ::std::ostream* os) { + PrintBytesInObjectTo(reinterpret_cast<const unsigned char*>(&value), + sizeof(value), os); + } +}; + +// We print a protobuf using its ShortDebugString() when the string +// doesn't exceed this many characters; otherwise we print it using +// DebugString() for better readability. +const size_t kProtobufOneLinerMaxLength = 50; + +template <typename T> +class TypeWithoutFormatter<T, kProtobuf> { + public: + static void PrintValue(const T& value, ::std::ostream* os) { + const ::testing::internal::string short_str = value.ShortDebugString(); + const ::testing::internal::string pretty_str = + short_str.length() <= kProtobufOneLinerMaxLength ? 
+ short_str : ("\n" + value.DebugString()); + *os << ("<" + pretty_str + ">"); + } +}; + +template <typename T> +class TypeWithoutFormatter<T, kConvertibleToInteger> { + public: + // Since T has no << operator or PrintTo() but can be implicitly + // converted to BiggestInt, we print it as a BiggestInt. + // + // Most likely T is an enum type (either named or unnamed), in which + // case printing it as an integer is the desired behavior. In case + // T is not an enum, printing it as an integer is the best we can do + // given that it has no user-defined printer. + static void PrintValue(const T& value, ::std::ostream* os) { + const internal::BiggestInt kBigInt = value; + *os << kBigInt; + } +}; + +// Prints the given value to the given ostream. If the value is a +// protocol message, its debug string is printed; if it's an enum or +// of a type implicitly convertible to BiggestInt, it's printed as an +// integer; otherwise the bytes in the value are printed. This is +// what UniversalPrinter<T>::Print() does when it knows nothing about +// type T and T has neither << operator nor PrintTo(). +// +// A user can override this behavior for a class type Foo by defining +// a << operator in the namespace where Foo is defined. +// +// We put this operator in namespace 'internal2' instead of 'internal' +// to simplify the implementation, as much code in 'internal' needs to +// use << in STL, which would conflict with our own << were it defined +// in 'internal'. +// +// Note that this operator<< takes a generic std::basic_ostream<Char, +// CharTraits> type instead of the more restricted std::ostream. If +// we define it to take an std::ostream instead, we'll get an +// "ambiguous overloads" compiler error when trying to print a type +// Foo that supports streaming to std::basic_ostream<Char, +// CharTraits>, as the compiler cannot tell whether +// operator<<(std::ostream&, const T&) or +// operator<<(std::basic_ostream<Char, CharTraits>&, const Foo&) is more +// specific. 
+template <typename Char, typename CharTraits, typename T> +::std::basic_ostream<Char, CharTraits>& operator<<( + ::std::basic_ostream<Char, CharTraits>& os, const T& x) { + TypeWithoutFormatter<T, + (internal::IsAProtocolMessage<T>::value ? kProtobuf : + internal::ImplicitlyConvertible<const T&, internal::BiggestInt>::value ? + kConvertibleToInteger : kOtherType)>::PrintValue(x, &os); + return os; +} + +} // namespace internal2 +} // namespace testing + +// This namespace MUST NOT BE NESTED IN ::testing, or the name look-up +// magic needed for implementing UniversalPrinter won't work. +namespace testing_internal { + +// Used to print a value that is not an STL-style container when the +// user doesn't define PrintTo() for it. +template <typename T> +void DefaultPrintNonContainerTo(const T& value, ::std::ostream* os) { + // With the following statement, during unqualified name lookup, + // testing::internal2::operator<< appears as if it was declared in + // the nearest enclosing namespace that contains both + // ::testing_internal and ::testing::internal2, i.e. the global + // namespace. For more details, refer to the C++ Standard section + // 7.3.4-1 [namespace.udir]. This allows us to fall back onto + // testing::internal2::operator<< in case T doesn't come with a << + // operator. + // + // We cannot write 'using ::testing::internal2::operator<<;', which + // gcc 3.3 fails to compile due to a compiler bug. + using namespace ::testing::internal2; // NOLINT + + // Assuming T is defined in namespace foo, in the next statement, + // the compiler will consider all of: + // + // 1. foo::operator<< (thanks to Koenig look-up), + // 2. ::operator<< (as the current namespace is enclosed in ::), + // 3. testing::internal2::operator<< (thanks to the using statement above). + // + // The operator<< whose type matches T best will be picked. + // + // We deliberately allow #2 to be a candidate, as sometimes it's + // impossible to define #1 (e.g. when foo is ::std, defining + // anything in it is undefined behavior unless you are a compiler + // vendor.). 
+ *os << value; +} + +} // namespace testing_internal + +namespace testing { +namespace internal { + +// UniversalPrinter<T>::Print(value, ostream_ptr) prints the given +// value to the given ostream. The caller must ensure that +// 'ostream_ptr' is not NULL, or the behavior is undefined. +// +// We define UniversalPrinter as a class template (as opposed to a +// function template), as we need to partially specialize it for +// reference types, which cannot be done with function templates. +template <typename T> +class UniversalPrinter; + +template <typename T> +void UniversalPrint(const T& value, ::std::ostream* os); + +// Used to print an STL-style container when the user doesn't define +// a PrintTo() for it. +template <typename C> +void DefaultPrintTo(IsContainer /* dummy */, + false_type /* is not a pointer */, + const C& container, ::std::ostream* os) { + const size_t kMaxCount = 32; // The maximum number of elements to print. + *os << '{'; + size_t count = 0; + for (typename C::const_iterator it = container.begin(); + it != container.end(); ++it, ++count) { + if (count > 0) { + *os << ','; + if (count == kMaxCount) { // Enough has been printed. + *os << " ..."; + break; + } + } + *os << ' '; + // We cannot call PrintTo(*it, os) here as PrintTo() doesn't + // handle *it being a native array. + internal::UniversalPrint(*it, os); + } + + if (count > 0) { + *os << ' '; + } + *os << '}'; +} + +// Used to print a pointer that is neither a char pointer nor a member +// pointer, when the user doesn't define PrintTo() for it. (A member +// variable pointer or member function pointer doesn't really point to +// a location in the address space. Their representation is +// implementation-defined. Therefore they will be printed as raw +// bytes.) +template <typename T> +void DefaultPrintTo(IsNotContainer /* dummy */, + true_type /* is a pointer */, + T* p, ::std::ostream* os) { + if (p == NULL) { + *os << "NULL"; + } else { + // C++ doesn't allow casting from a function pointer to any object + // pointer. 
+ // + // IsTrue() silences warnings: "Condition is always true", + // "unreachable code". + if (IsTrue(ImplicitlyConvertible<T*, const void*>::value)) { + // T is not a function type. We just call << to print p, + // relying on ADL to pick up user-defined << for their pointer + // types, if any. + *os << p; + } else { + // T is a function type, so '*os << p' doesn't do what we want + // (it just prints p as bool). We want to print p as a const + // void*. However, we cannot cast it to const void* directly, + // even using reinterpret_cast, as earlier versions of gcc + // (e.g. 3.4.5) cannot compile the cast when p is a function + // pointer. Casting to UInt64 first solves the problem. + *os << reinterpret_cast<const void*>( + reinterpret_cast<internal::UInt64>(p)); + } + } +} + +// Used to print a non-container, non-pointer value when the user +// doesn't define PrintTo() for it. +template <typename T> +void DefaultPrintTo(IsNotContainer /* dummy */, + false_type /* is not a pointer */, + const T& value, ::std::ostream* os) { + ::testing_internal::DefaultPrintNonContainerTo(value, os); +} + +// Prints the given value using the << operator if it has one; +// otherwise prints the bytes in it. This is what +// UniversalPrinter<T>::Print() does when PrintTo() is not specialized +// or overloaded for type T. +// +// A user can override this behavior for a class type Foo by defining +// an overload of PrintTo() in the namespace where Foo is defined. We +// give the user this option as sometimes defining a << operator for +// Foo is not desirable (e.g. the coding style may prevent doing it, +// or there is already a << operator but it doesn't do what the user +// wants). +template <typename T> +void PrintTo(const T& value, ::std::ostream* os) { + // DefaultPrintTo() is overloaded. The type of its first two + // arguments determine which version will be picked. If T is an + // STL-style container, the version for container will be called; if + // T is a pointer, the pointer version will be called; otherwise the + // generic version will be called. 
+ // + // Note that we check for container types here, before we check + // for protocol message types in our operator<<. The rationale is: + // + // For protocol messages, we want to give people a chance to + // override Google Mock's format by defining a PrintTo() or + // operator<<. For STL containers, other formats can be + // incompatible with Google Mock's format for the container + // elements; therefore we check for container types here to ensure + // that our format is used. + // + // The second argument of DefaultPrintTo() is needed to bypass a bug + // in Symbian's C++ compiler that prevents it from picking the right + // overload between: + // + // PrintTo(const T& x, ...); + // PrintTo(T* x, ...); + DefaultPrintTo(IsContainerTest<T>(0), is_pointer<T>(), value, os); +} + +// The following list of PrintTo() overloads tells +// UniversalPrinter<T>::Print() how to print standard types (built-in +// types, strings, plain arrays, and pointers). + +// Overloads for various char types. +GTEST_API_ void PrintTo(unsigned char c, ::std::ostream* os); +GTEST_API_ void PrintTo(signed char c, ::std::ostream* os); +inline void PrintTo(char c, ::std::ostream* os) { + // When printing a plain char, we always treat it as unsigned. This + // way, the output won't be affected by whether the compiler thinks + // char is signed or not. + PrintTo(static_cast<unsigned char>(c), os); +} + +// Overloads for other simple built-in types. +inline void PrintTo(bool x, ::std::ostream* os) { + *os << (x ? "true" : "false"); +} + +// Overload for wchar_t type. +// Prints a wchar_t as a symbol if it is printable or as its internal +// code otherwise and also as its decimal code (except for L'\0'). +// The L'\0' char is printed as "L'\\0'". The decimal code is printed +// as signed integer when wchar_t is implemented by the compiler +// as a signed type and is printed as an unsigned integer when wchar_t +// is implemented as an unsigned type. 
+GTEST_API_ void PrintTo(wchar_t wc, ::std::ostream* os); + +// Overloads for C strings. +GTEST_API_ void PrintTo(const char* s, ::std::ostream* os); +inline void PrintTo(char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} + +// signed/unsigned char is often used for representing binary data, so +// we print pointers to it as void* to be safe. +inline void PrintTo(const signed char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +inline void PrintTo(signed char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +inline void PrintTo(const unsigned char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +inline void PrintTo(unsigned char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} + +// MSVC can be configured to define wchar_t as a typedef of unsigned +// short. It defines _NATIVE_WCHAR_T_DEFINED when wchar_t is a native +// type. When wchar_t is a typedef, defining an overload for const +// wchar_t* would cause unsigned short* be printed as a wide string, +// possibly causing invalid memory accesses. +#if !defined(_MSC_VER) || defined(_NATIVE_WCHAR_T_DEFINED) +// Overloads for wide C strings +GTEST_API_ void PrintTo(const wchar_t* s, ::std::ostream* os); +inline void PrintTo(wchar_t* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +#endif + +// Overload for C arrays. Multi-dimensional arrays are printed +// properly. + +// Prints the given number of elements in an array, without printing +// the curly braces. +template +void PrintRawArrayTo(const T a[], size_t count, ::std::ostream* os) { + UniversalPrint(a[0], os); + for (size_t i = 1; i != count; i++) { + *os << ", "; + UniversalPrint(a[i], os); + } +} + +// Overloads for ::string and ::std::string. 
+#if GTEST_HAS_GLOBAL_STRING +GTEST_API_ void PrintStringTo(const ::string&s, ::std::ostream* os); +inline void PrintTo(const ::string& s, ::std::ostream* os) { + PrintStringTo(s, os); +} +#endif // GTEST_HAS_GLOBAL_STRING + +GTEST_API_ void PrintStringTo(const ::std::string&s, ::std::ostream* os); +inline void PrintTo(const ::std::string& s, ::std::ostream* os) { + PrintStringTo(s, os); +} + +// Overloads for ::wstring and ::std::wstring. +#if GTEST_HAS_GLOBAL_WSTRING +GTEST_API_ void PrintWideStringTo(const ::wstring&s, ::std::ostream* os); +inline void PrintTo(const ::wstring& s, ::std::ostream* os) { + PrintWideStringTo(s, os); +} +#endif // GTEST_HAS_GLOBAL_WSTRING + +#if GTEST_HAS_STD_WSTRING +GTEST_API_ void PrintWideStringTo(const ::std::wstring&s, ::std::ostream* os); +inline void PrintTo(const ::std::wstring& s, ::std::ostream* os) { + PrintWideStringTo(s, os); +} +#endif // GTEST_HAS_STD_WSTRING + +#if GTEST_HAS_TR1_TUPLE +// Overload for ::std::tr1::tuple. Needed for printing function arguments, +// which are packed as tuples. + +// Helper function for printing a tuple. T must be instantiated with +// a tuple type. +template +void PrintTupleTo(const T& t, ::std::ostream* os); + +// Overloaded PrintTo() for tuples of various arities. We support +// tuples of up-to 10 fields. The following implementation works +// regardless of whether tr1::tuple is implemented using the +// non-standard variadic template feature or not. 
+ +inline void PrintTo(const ::std::tr1::tuple<>& t, ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, + ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, + ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, + ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, + ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, + ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo( + const ::std::tr1::tuple& t, + ::std::ostream* os) { + PrintTupleTo(t, os); +} +#endif // GTEST_HAS_TR1_TUPLE + +// Overload for std::pair. +template +void PrintTo(const ::std::pair& value, ::std::ostream* os) { + *os << '('; + // We cannot use UniversalPrint(value.first, os) here, as T1 may be + // a reference type. The same for printing value.second. + UniversalPrinter::Print(value.first, os); + *os << ", "; + UniversalPrinter::Print(value.second, os); + *os << ')'; +} + +// Implements printing a non-reference type T by letting the compiler +// pick the right overload of PrintTo() for T. +template +class UniversalPrinter { + public: + // MSVC warns about adding const to a function type, so we want to + // disable the warning. +#ifdef _MSC_VER +# pragma warning(push) // Saves the current warning state. +# pragma warning(disable:4180) // Temporarily disables warning 4180. 
+#endif // _MSC_VER + + // Note: we deliberately don't call this PrintTo(), as that name + // conflicts with ::testing::internal::PrintTo in the body of the + // function. + static void Print(const T& value, ::std::ostream* os) { + // By default, ::testing::internal::PrintTo() is used for printing + // the value. + // + // Thanks to Koenig look-up, if T is a class and has its own + // PrintTo() function defined in its namespace, that function will + // be visible here. Since it is more specific than the generic ones + // in ::testing::internal, it will be picked by the compiler in the + // following statement - exactly what we want. + PrintTo(value, os); + } + +#ifdef _MSC_VER +# pragma warning(pop) // Restores the warning state. +#endif // _MSC_VER +}; + +// UniversalPrintArray(begin, len, os) prints an array of 'len' +// elements, starting at address 'begin'. +template +void UniversalPrintArray(const T* begin, size_t len, ::std::ostream* os) { + if (len == 0) { + *os << "{}"; + } else { + *os << "{ "; + const size_t kThreshold = 18; + const size_t kChunkSize = 8; + // If the array has more than kThreshold elements, we'll have to + // omit some details by printing only the first and the last + // kChunkSize elements. + // TODO(wan@google.com): let the user control the threshold using a flag. + if (len <= kThreshold) { + PrintRawArrayTo(begin, len, os); + } else { + PrintRawArrayTo(begin, kChunkSize, os); + *os << ", ..., "; + PrintRawArrayTo(begin + len - kChunkSize, kChunkSize, os); + } + *os << " }"; + } +} +// This overload prints a (const) char array compactly. +GTEST_API_ void UniversalPrintArray(const char* begin, + size_t len, + ::std::ostream* os); + +// Implements printing an array type T[N]. +template +class UniversalPrinter { + public: + // Prints the given array, omitting some elements when there are too + // many. 
+ static void Print(const T (&a)[N], ::std::ostream* os) { + UniversalPrintArray(a, N, os); + } +}; + +// Implements printing a reference type T&. +template +class UniversalPrinter { + public: + // MSVC warns about adding const to a function type, so we want to + // disable the warning. +#ifdef _MSC_VER +# pragma warning(push) // Saves the current warning state. +# pragma warning(disable:4180) // Temporarily disables warning 4180. +#endif // _MSC_VER + + static void Print(const T& value, ::std::ostream* os) { + // Prints the address of the value. We use reinterpret_cast here + // as static_cast doesn't compile when T is a function type. + *os << "@" << reinterpret_cast(&value) << " "; + + // Then prints the value itself. + UniversalPrint(value, os); + } + +#ifdef _MSC_VER +# pragma warning(pop) // Restores the warning state. +#endif // _MSC_VER +}; + +// Prints a value tersely: for a reference type, the referenced value +// (but not the address) is printed; for a (const) char pointer, the +// NUL-terminated string (but not the pointer) is printed. +template +void UniversalTersePrint(const T& value, ::std::ostream* os) { + UniversalPrint(value, os); +} +inline void UniversalTersePrint(const char* str, ::std::ostream* os) { + if (str == NULL) { + *os << "NULL"; + } else { + UniversalPrint(string(str), os); + } +} +inline void UniversalTersePrint(char* str, ::std::ostream* os) { + UniversalTersePrint(static_cast(str), os); +} + +// Prints a value using the type inferred by the compiler. The +// difference between this and UniversalTersePrint() is that for a +// (const) char pointer, this prints both the pointer and the +// NUL-terminated string. 
+template +void UniversalPrint(const T& value, ::std::ostream* os) { + UniversalPrinter::Print(value, os); +} + +#if GTEST_HAS_TR1_TUPLE +typedef ::std::vector Strings; + +// This helper template allows PrintTo() for tuples and +// UniversalTersePrintTupleFieldsToStrings() to be defined by +// induction on the number of tuple fields. The idea is that +// TuplePrefixPrinter::PrintPrefixTo(t, os) prints the first N +// fields in tuple t, and can be defined in terms of +// TuplePrefixPrinter. + +// The inductive case. +template +struct TuplePrefixPrinter { + // Prints the first N fields of a tuple. + template + static void PrintPrefixTo(const Tuple& t, ::std::ostream* os) { + TuplePrefixPrinter::PrintPrefixTo(t, os); + *os << ", "; + UniversalPrinter::type> + ::Print(::std::tr1::get(t), os); + } + + // Tersely prints the first N fields of a tuple to a string vector, + // one element for each field. + template + static void TersePrintPrefixToStrings(const Tuple& t, Strings* strings) { + TuplePrefixPrinter::TersePrintPrefixToStrings(t, strings); + ::std::stringstream ss; + UniversalTersePrint(::std::tr1::get(t), &ss); + strings->push_back(ss.str()); + } +}; + +// Base cases. +template <> +struct TuplePrefixPrinter<0> { + template + static void PrintPrefixTo(const Tuple&, ::std::ostream*) {} + + template + static void TersePrintPrefixToStrings(const Tuple&, Strings*) {} +}; +// We have to specialize the entire TuplePrefixPrinter<> class +// template here, even though the definition of +// TersePrintPrefixToStrings() is the same as the generic version, as +// Embarcadero (formerly CodeGear, formerly Borland) C++ doesn't +// support specializing a method template of a class template. 
+template <> +struct TuplePrefixPrinter<1> { + template + static void PrintPrefixTo(const Tuple& t, ::std::ostream* os) { + UniversalPrinter::type>:: + Print(::std::tr1::get<0>(t), os); + } + + template + static void TersePrintPrefixToStrings(const Tuple& t, Strings* strings) { + ::std::stringstream ss; + UniversalTersePrint(::std::tr1::get<0>(t), &ss); + strings->push_back(ss.str()); + } +}; + +// Helper function for printing a tuple. T must be instantiated with +// a tuple type. +template +void PrintTupleTo(const T& t, ::std::ostream* os) { + *os << "("; + TuplePrefixPrinter< ::std::tr1::tuple_size::value>:: + PrintPrefixTo(t, os); + *os << ")"; +} + +// Prints the fields of a tuple tersely to a string vector, one +// element for each field. See the comment before +// UniversalTersePrint() for how we define "tersely". +template +Strings UniversalTersePrintTupleFieldsToStrings(const Tuple& value) { + Strings result; + TuplePrefixPrinter< ::std::tr1::tuple_size::value>:: + TersePrintPrefixToStrings(value, &result); + return result; +} +#endif // GTEST_HAS_TR1_TUPLE + +} // namespace internal + +template +::std::string PrintToString(const T& value) { + ::std::stringstream ss; + internal::UniversalTersePrint(value, &ss); + return ss.str(); +} + +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_GTEST_PRINTERS_H_ + +#if GTEST_HAS_PARAM_TEST + +namespace testing { +namespace internal { + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Outputs a message explaining invalid registration of different +// fixture class for the same test case. This may happen when +// TEST_P macro is used to define two tests with the same name +// but in different namespaces. +GTEST_API_ void ReportInvalidTestCaseType(const char* test_case_name, + const char* file, int line); + +template class ParamGeneratorInterface; +template class ParamGenerator; + +// Interface for iterating over elements provided by an implementation +// of ParamGeneratorInterface. 
+template +class ParamIteratorInterface { + public: + virtual ~ParamIteratorInterface() {} + // A pointer to the base generator instance. + // Used only for the purposes of iterator comparison + // to make sure that two iterators belong to the same generator. + virtual const ParamGeneratorInterface* BaseGenerator() const = 0; + // Advances iterator to point to the next element + // provided by the generator. The caller is responsible + // for not calling Advance() on an iterator equal to + // BaseGenerator()->End(). + virtual void Advance() = 0; + // Clones the iterator object. Used for implementing copy semantics + // of ParamIterator. + virtual ParamIteratorInterface* Clone() const = 0; + // Dereferences the current iterator and provides (read-only) access + // to the pointed value. It is the caller's responsibility not to call + // Current() on an iterator equal to BaseGenerator()->End(). + // Used for implementing ParamGenerator::operator*(). + virtual const T* Current() const = 0; + // Determines whether the given iterator and other point to the same + // element in the sequence generated by the generator. + // Used for implementing ParamGenerator::operator==(). + virtual bool Equals(const ParamIteratorInterface& other) const = 0; +}; + +// Class iterating over elements provided by an implementation of +// ParamGeneratorInterface. It wraps ParamIteratorInterface +// and implements the const forward iterator concept. +template +class ParamIterator { + public: + typedef T value_type; + typedef const T& reference; + typedef ptrdiff_t difference_type; + + // ParamIterator assumes ownership of the impl_ pointer. 
+ ParamIterator(const ParamIterator& other) : impl_(other.impl_->Clone()) {} + ParamIterator& operator=(const ParamIterator& other) { + if (this != &other) + impl_.reset(other.impl_->Clone()); + return *this; + } + + const T& operator*() const { return *impl_->Current(); } + const T* operator->() const { return impl_->Current(); } + // Prefix version of operator++. + ParamIterator& operator++() { + impl_->Advance(); + return *this; + } + // Postfix version of operator++. + ParamIterator operator++(int /*unused*/) { + ParamIteratorInterface* clone = impl_->Clone(); + impl_->Advance(); + return ParamIterator(clone); + } + bool operator==(const ParamIterator& other) const { + return impl_.get() == other.impl_.get() || impl_->Equals(*other.impl_); + } + bool operator!=(const ParamIterator& other) const { + return !(*this == other); + } + + private: + friend class ParamGenerator; + explicit ParamIterator(ParamIteratorInterface* impl) : impl_(impl) {} + scoped_ptr > impl_; +}; + +// ParamGeneratorInterface is the binary interface to access generators +// defined in other translation units. +template +class ParamGeneratorInterface { + public: + typedef T ParamType; + + virtual ~ParamGeneratorInterface() {} + + // Generator interface definition + virtual ParamIteratorInterface* Begin() const = 0; + virtual ParamIteratorInterface* End() const = 0; +}; + +// Wraps ParamGeneratorInterface and provides general generator syntax +// compatible with the STL Container concept. +// This class implements copy initialization semantics and the contained +// ParamGeneratorInterface instance is shared among all copies +// of the original object. This is possible because that instance is immutable. 
+template +class ParamGenerator { + public: + typedef ParamIterator iterator; + + explicit ParamGenerator(ParamGeneratorInterface* impl) : impl_(impl) {} + ParamGenerator(const ParamGenerator& other) : impl_(other.impl_) {} + + ParamGenerator& operator=(const ParamGenerator& other) { + impl_ = other.impl_; + return *this; + } + + iterator begin() const { return iterator(impl_->Begin()); } + iterator end() const { return iterator(impl_->End()); } + + private: + linked_ptr > impl_; +}; + +// Generates values from a range of two comparable values. Can be used to +// generate sequences of user-defined types that implement operator+() and +// operator<(). +// This class is used in the Range() function. +template +class RangeGenerator : public ParamGeneratorInterface { + public: + RangeGenerator(T begin, T end, IncrementT step) + : begin_(begin), end_(end), + step_(step), end_index_(CalculateEndIndex(begin, end, step)) {} + virtual ~RangeGenerator() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, begin_, 0, step_); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, end_, end_index_, step_); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, T value, int index, + IncrementT step) + : base_(base), value_(value), index_(index), step_(step) {} + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + virtual void Advance() { + value_ = value_ + step_; + index_++; + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const T* Current() const { return &value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. 
+ GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const int other_index = + CheckedDowncastToActualType(&other)->index_; + return index_ == other_index; + } + + private: + Iterator(const Iterator& other) + : ParamIteratorInterface(), + base_(other.base_), value_(other.value_), index_(other.index_), + step_(other.step_) {} + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + T value_; + int index_; + const IncrementT step_; + }; // class RangeGenerator::Iterator + + static int CalculateEndIndex(const T& begin, + const T& end, + const IncrementT& step) { + int end_index = 0; + for (T i = begin; i < end; i = i + step) + end_index++; + return end_index; + } + + // No implementation - assignment is unsupported. + void operator=(const RangeGenerator& other); + + const T begin_; + const T end_; + const IncrementT step_; + // The index for the end() iterator. All the elements in the generated + // sequence are indexed (0-based) to aid iterator comparison. + const int end_index_; +}; // class RangeGenerator + + +// Generates values from a pair of STL-style iterators. Used in the +// ValuesIn() function. The elements are copied from the source range +// since the source can be located on the stack, and the generator +// is likely to persist beyond that stack frame. 
+template <typename T> +class ValuesInIteratorRangeGenerator : public ParamGeneratorInterface<T> { + public: + template <typename ForwardIterator> + ValuesInIteratorRangeGenerator(ForwardIterator begin, ForwardIterator end) + : container_(begin, end) {} + virtual ~ValuesInIteratorRangeGenerator() {} + + virtual ParamIteratorInterface<T>* Begin() const { + return new Iterator(this, container_.begin()); + } + virtual ParamIteratorInterface<T>* End() const { + return new Iterator(this, container_.end()); + } + + private: + typedef typename ::std::vector<T> ContainerType; + + class Iterator : public ParamIteratorInterface<T> { + public: + Iterator(const ParamGeneratorInterface<T>* base, + typename ContainerType::const_iterator iterator) + : base_(base), iterator_(iterator) {} + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface<T>* BaseGenerator() const { + return base_; + } + virtual void Advance() { + ++iterator_; + value_.reset(); + } + virtual ParamIteratorInterface<T>* Clone() const { + return new Iterator(*this); + } + // We need to use cached value referenced by iterator_ because *iterator_ + // can return a temporary object (and of type other than T), so just + // having "return &*iterator_;" doesn't work. + // value_ is updated here and not in Advance() because Advance() + // can advance iterator_ beyond the end of the range, and we cannot + // detect that fact. The client code, on the other hand, is + // responsible for not calling Current() on an out-of-range iterator. + virtual const T* Current() const { + if (value_.get() == NULL) + value_.reset(new T(*iterator_)); + return value_.get(); + } + virtual bool Equals(const ParamIteratorInterface<T>& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators."
<< std::endl; + return iterator_ == + CheckedDowncastToActualType<const Iterator>(&other)->iterator_; + } + + private: + Iterator(const Iterator& other) + // The explicit constructor call suppresses a false warning + // emitted by gcc when supplied with the -Wextra option. + : ParamIteratorInterface<T>(), + base_(other.base_), + iterator_(other.iterator_) {} + + const ParamGeneratorInterface<T>* const base_; + typename ContainerType::const_iterator iterator_; + // A cached value of *iterator_. We keep it here to allow access by + // pointer in the wrapping iterator's operator->(). + // value_ needs to be mutable to be accessed in Current(). + // Use of scoped_ptr helps manage cached value's lifetime, + // which is bound by the lifespan of the iterator itself. + mutable scoped_ptr<const T> value_; + }; // class ValuesInIteratorRangeGenerator::Iterator + + // No implementation - assignment is unsupported. + void operator=(const ValuesInIteratorRangeGenerator& other); + + const ContainerType container_; +}; // class ValuesInIteratorRangeGenerator + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Stores a parameter value and later creates tests parameterized with that +// value. +template <class TestClass> +class ParameterizedTestFactory : public TestFactoryBase { + public: + typedef typename TestClass::ParamType ParamType; + explicit ParameterizedTestFactory(ParamType parameter) : + parameter_(parameter) {} + virtual Test* CreateTest() { + TestClass::SetParam(&parameter_); + return new TestClass(); + } + + private: + const ParamType parameter_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestFactory); +}; + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// TestMetaFactoryBase is a base class for meta-factories that create +// test factories for passing into MakeAndRegisterTestInfo function.
+template +class TestMetaFactoryBase { + public: + virtual ~TestMetaFactoryBase() {} + + virtual TestFactoryBase* CreateTestFactory(ParamType parameter) = 0; +}; + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// TestMetaFactory creates test factories for passing into +// MakeAndRegisterTestInfo function. Since MakeAndRegisterTestInfo receives +// ownership of test factory pointer, same factory object cannot be passed +// into that method twice. But ParameterizedTestCaseInfo is going to call +// it for each Test/Parameter value combination. Thus it needs meta factory +// creator class. +template +class TestMetaFactory + : public TestMetaFactoryBase { + public: + typedef typename TestCase::ParamType ParamType; + + TestMetaFactory() {} + + virtual TestFactoryBase* CreateTestFactory(ParamType parameter) { + return new ParameterizedTestFactory(parameter); + } + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestMetaFactory); +}; + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// ParameterizedTestCaseInfoBase is a generic interface +// to ParameterizedTestCaseInfo classes. ParameterizedTestCaseInfoBase +// accumulates test information provided by TEST_P macro invocations +// and generators provided by INSTANTIATE_TEST_CASE_P macro invocations +// and uses that information to register all resulting test instances +// in RegisterTests method. The ParameterizeTestCaseRegistry class holds +// a collection of pointers to the ParameterizedTestCaseInfo objects +// and calls RegisterTests() on each of them when asked. +class ParameterizedTestCaseInfoBase { + public: + virtual ~ParameterizedTestCaseInfoBase() {} + + // Base part of test case name for display purposes. + virtual const string& GetTestCaseName() const = 0; + // Test case id to verify identity. + virtual TypeId GetTestCaseTypeId() const = 0; + // UnitTest class invokes this method to register tests in this + // test case right before running them in RUN_ALL_TESTS macro. 
+ // This method should not be called more than once on any single + // instance of a ParameterizedTestCaseInfoBase derived class. + virtual void RegisterTests() = 0; + + protected: + ParameterizedTestCaseInfoBase() {} + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestCaseInfoBase); +}; + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// ParameterizedTestCaseInfo accumulates tests obtained from TEST_P +// macro invocations for a particular test case and generators +// obtained from INSTANTIATE_TEST_CASE_P macro invocations for that +// test case. It registers tests with all values generated by all +// generators when asked. +template <class TestCase> +class ParameterizedTestCaseInfo : public ParameterizedTestCaseInfoBase { + public: + // ParamType and GeneratorCreationFunc are private types but are required + // for declarations of public methods AddTestPattern() and + // AddTestCaseInstantiation(). + typedef typename TestCase::ParamType ParamType; + // A function that returns an instance of appropriate generator type. + typedef ParamGenerator<ParamType>(GeneratorCreationFunc)(); + + explicit ParameterizedTestCaseInfo(const char* name) + : test_case_name_(name) {} + + // Test case base name for display purposes. + virtual const string& GetTestCaseName() const { return test_case_name_; } + // Test case id to verify identity. + virtual TypeId GetTestCaseTypeId() const { return GetTypeId<TestCase>(); } + // TEST_P macro uses AddTestPattern() to record information + // about a single test in a LocalTestInfo structure. + // test_case_name is the base name of the test case (without invocation + // prefix). test_base_name is the name of an individual test without + // parameter index. For the test SequenceA/FooTest.DoBar/1 FooTest is + // test case base name and DoBar is test base name.
+ void AddTestPattern(const char* test_case_name, + const char* test_base_name, + TestMetaFactoryBase<ParamType>* meta_factory) { + tests_.push_back(linked_ptr<TestInfo>(new TestInfo(test_case_name, + test_base_name, + meta_factory))); + } + // INSTANTIATE_TEST_CASE_P macro uses AddTestCaseInstantiation() to record + // information about a generator. + int AddTestCaseInstantiation(const string& instantiation_name, + GeneratorCreationFunc* func, + const char* /* file */, + int /* line */) { + instantiations_.push_back(::std::make_pair(instantiation_name, func)); + return 0; // Return value used only to run this method in namespace scope. + } + // UnitTest class invokes this method to register tests in this + // test case right before running tests in RUN_ALL_TESTS macro. + // This method should not be called more than once on any single + // instance of a ParameterizedTestCaseInfoBase derived class. + // UnitTest has a guard to prevent calling this method more than once. + virtual void RegisterTests() { + for (typename TestInfoContainer::iterator test_it = tests_.begin(); + test_it != tests_.end(); ++test_it) { + linked_ptr<TestInfo> test_info = *test_it; + for (typename InstantiationContainer::iterator gen_it = + instantiations_.begin(); gen_it != instantiations_.end(); + ++gen_it) { + const string& instantiation_name = gen_it->first; + ParamGenerator<ParamType> generator((*gen_it->second)()); + + Message test_case_name_stream; + if ( !instantiation_name.empty() ) + test_case_name_stream << instantiation_name << "/"; + test_case_name_stream << test_info->test_case_base_name; + + int i = 0; + for (typename ParamGenerator<ParamType>::iterator param_it = + generator.begin(); + param_it != generator.end(); ++param_it, ++i) { + Message test_name_stream; + test_name_stream << test_info->test_base_name << "/" << i; + MakeAndRegisterTestInfo( + test_case_name_stream.GetString().c_str(), + test_name_stream.GetString().c_str(), + NULL, // No type parameter.
+ PrintToString(*param_it).c_str(), + GetTestCaseTypeId(), + TestCase::SetUpTestCase, + TestCase::TearDownTestCase, + test_info->test_meta_factory->CreateTestFactory(*param_it)); + } // for param_it + } // for gen_it + } // for test_it + } // RegisterTests + + private: + // LocalTestInfo structure keeps information about a single test registered + // with TEST_P macro. + struct TestInfo { + TestInfo(const char* a_test_case_base_name, + const char* a_test_base_name, + TestMetaFactoryBase* a_test_meta_factory) : + test_case_base_name(a_test_case_base_name), + test_base_name(a_test_base_name), + test_meta_factory(a_test_meta_factory) {} + + const string test_case_base_name; + const string test_base_name; + const scoped_ptr > test_meta_factory; + }; + typedef ::std::vector > TestInfoContainer; + // Keeps pairs of + // received from INSTANTIATE_TEST_CASE_P macros. + typedef ::std::vector > + InstantiationContainer; + + const string test_case_name_; + TestInfoContainer tests_; + InstantiationContainer instantiations_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestCaseInfo); +}; // class ParameterizedTestCaseInfo + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// ParameterizedTestCaseRegistry contains a map of ParameterizedTestCaseInfoBase +// classes accessed by test case names. TEST_P and INSTANTIATE_TEST_CASE_P +// macros use it to locate their corresponding ParameterizedTestCaseInfo +// descriptors. +class ParameterizedTestCaseRegistry { + public: + ParameterizedTestCaseRegistry() {} + ~ParameterizedTestCaseRegistry() { + for (TestCaseInfoContainer::iterator it = test_case_infos_.begin(); + it != test_case_infos_.end(); ++it) { + delete *it; + } + } + + // Looks up or creates and returns a structure containing information about + // tests and instantiations of a particular test case. 
+ template <class TestCase> + ParameterizedTestCaseInfo<TestCase>* GetTestCasePatternHolder( + const char* test_case_name, + const char* file, + int line) { + ParameterizedTestCaseInfo<TestCase>* typed_test_info = NULL; + for (TestCaseInfoContainer::iterator it = test_case_infos_.begin(); + it != test_case_infos_.end(); ++it) { + if ((*it)->GetTestCaseName() == test_case_name) { + if ((*it)->GetTestCaseTypeId() != GetTypeId<TestCase>()) { + // Complain about incorrect usage of Google Test facilities + // and terminate the program since we cannot guarantee correct + // test case setup and tear-down in this case. + ReportInvalidTestCaseType(test_case_name, file, line); + posix::Abort(); + } else { + // At this point we are sure that the object we found is of the same + // type we are looking for, so we downcast it to that type + // without further checks. + typed_test_info = CheckedDowncastToActualType< + ParameterizedTestCaseInfo<TestCase> >(*it); + } + break; + } + } + if (typed_test_info == NULL) { + typed_test_info = new ParameterizedTestCaseInfo<TestCase>(test_case_name); + test_case_infos_.push_back(typed_test_info); + } + return typed_test_info; + } + void RegisterTests() { + for (TestCaseInfoContainer::iterator it = test_case_infos_.begin(); + it != test_case_infos_.end(); ++it) { + (*it)->RegisterTests(); + } + } + + private: + typedef ::std::vector<ParameterizedTestCaseInfoBase*> TestCaseInfoContainer; + + TestCaseInfoContainer test_case_infos_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestCaseRegistry); +}; + +} // namespace internal +} // namespace testing + +#endif // GTEST_HAS_PARAM_TEST + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_H_ +// This file was GENERATED by command: +// pump.py gtest-param-util-generated.h.pump +// DO NOT EDIT BY HAND!!! + +// Copyright 2008 Google Inc. +// All Rights Reserved.
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: vladl@google.com (Vlad Losev) + +// Type and function utilities for implementing parameterized tests. +// This file is generated by a SCRIPT. DO NOT EDIT BY HAND! +// +// Currently Google Test supports at most 50 arguments in Values, +// and at most 10 arguments in Combine. Please contact +// googletestframework@googlegroups.com if you need more. 
+// Please note that the number of arguments to Combine is limited +// by the maximum arity of the implementation of tr1::tuple which is +// currently set at 10. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_GENERATED_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_GENERATED_H_ + +// scripts/fuse_gtest.py depends on gtest's own header being #included +// *unconditionally*. Therefore these #includes cannot be moved +// inside #if GTEST_HAS_PARAM_TEST. + +#if GTEST_HAS_PARAM_TEST + +namespace testing { + +// Forward declarations of ValuesIn(), which is implemented in +// include/gtest/gtest-param-test.h. +template <typename ForwardIterator> +internal::ParamGenerator< + typename ::testing::internal::IteratorTraits<ForwardIterator>::value_type> +ValuesIn(ForwardIterator begin, ForwardIterator end); + +template <typename T, size_t N> +internal::ParamGenerator<T> ValuesIn(const T (&array)[N]); + +template <class Container> +internal::ParamGenerator<typename Container::value_type> ValuesIn( + const Container& container); + +namespace internal { + +// Used in the Values() function to provide polymorphic capabilities. +template <typename T1> +class ValueArray1 { + public: + explicit ValueArray1(T1 v1) : v1_(v1) {} + + template <typename T> + operator ParamGenerator<T>() const { return ValuesIn(&v1_, &v1_ + 1); } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray1& other); + + const T1 v1_; +}; + +template <typename T1, typename T2> +class ValueArray2 { + public: + ValueArray2(T1 v1, T2 v2) : v1_(v1), v2_(v2) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray2& other); + + const T1 v1_; + const T2 v2_; +}; + +template <typename T1, typename T2, typename T3> +class ValueArray3 { + public: + ValueArray3(T1 v1, T2 v2, T3 v3) : v1_(v1), v2_(v2), v3_(v3) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray3& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; +}; + +template <typename T1, typename T2, typename T3, typename T4> +class ValueArray4 { + public: + ValueArray4(T1 v1, T2 v2, T3 v3, T4 v4) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray4& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5> +class ValueArray5 { + public: + ValueArray5(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4), v5_(v5) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray5& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6> +class ValueArray6 { + public: + ValueArray6(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6) : v1_(v1), v2_(v2), + v3_(v3), v4_(v4), v5_(v5), v6_(v6) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray6& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7> +class ValueArray7 { + public: + ValueArray7(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7) : v1_(v1), + v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray7& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8> +class ValueArray8 { + public: + ValueArray8(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray8& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9> +class ValueArray9 { + public: + ValueArray9(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, + T9 v9) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray9& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10> +class ValueArray10 { + public: + ValueArray10(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray10& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11> +class ValueArray11 { + public: + ValueArray11(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), + v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray11& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12> +class ValueArray12 { + public: + ValueArray12(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), + v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray12& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13> +class ValueArray13 { + public: + ValueArray13(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), + v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), + v12_(v12), v13_(v13) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray13& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14> +class ValueArray14 { + public: + ValueArray14(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray14& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15> +class ValueArray15 { + public: + ValueArray15(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15) : v1_(v1), v2_(v2), + v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray15& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16> +class ValueArray16 { + public: + ValueArray16(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16) : v1_(v1), + v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), + v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), + v16_(v16) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray16& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17> +class ValueArray17 { + public: + ValueArray17(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, + T17 v17) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray17& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18> +class ValueArray18 { + public: + ValueArray18(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray18& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19> +class ValueArray19 { + public: + ValueArray19(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), + v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), + v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray19& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20> +class ValueArray20 { + public: + ValueArray20(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), + v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), + v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), + v19_(v19), v20_(v20) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray20& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21> +class ValueArray21 { + public: + ValueArray21(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), + v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), + v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), + v18_(v18), v19_(v19), v20_(v20), v21_(v21) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray21& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22> +class ValueArray22 { + public: + ValueArray22(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray22& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23> +class ValueArray23 { + public: + ValueArray23(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23) : v1_(v1), v2_(v2), + v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, + v23_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray23& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24> +class ValueArray24 { + public: + ValueArray24(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24) : v1_(v1), + v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), + v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), + v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), + v22_(v22), v23_(v23), v24_(v24) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray24& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25> +class ValueArray25 { + public: + ValueArray25(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, + T25 v25) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray25& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26> +class ValueArray26 { + public: + ValueArray26(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray26& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27> +class ValueArray27 { + public: + ValueArray27(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), + v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), + v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), + v20_(v20), v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), + v26_(v26), v27_(v27) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray27& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28> +class ValueArray28 { + public: + ValueArray28(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), + v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), + v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), + v19_(v19), v20_(v20), v21_(v21), v22_(v22), v23_(v23), v24_(v24), + v25_(v25), v26_(v26), v27_(v27), v28_(v28) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray28& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29> +class ValueArray29 { + public: + ValueArray29(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), + v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), + v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), + v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), v23_(v23), + v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), v29_(v29) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray29& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30> +class ValueArray30 { + public: + ValueArray30(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), + v29_(v29), v30_(v30) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray30& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31> +class ValueArray31 { + public: + ValueArray31(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31) : v1_(v1), v2_(v2), + v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), + v29_(v29), v30_(v30), v31_(v31) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray31& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32> +class ValueArray32 { + public: + ValueArray32(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32) : v1_(v1), + v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), + v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), + v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), + v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), + v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray32& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33> +class ValueArray33 { + public: + ValueArray33(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, + T33 v33) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), + v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), + v33_(v33) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray33& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34> +class ValueArray34 { + public: + ValueArray34(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), + v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), + v33_(v33), v34_(v34) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray34& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35> +class ValueArray35 { + public: + ValueArray35(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), + v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), + v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), + v20_(v20), v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), + v26_(v26), v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), + v32_(v32), v33_(v33), v34_(v34), v35_(v35) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, + v35_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray35& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36> +class ValueArray36 { + public: + ValueArray36(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), + v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), + v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), + v19_(v19), v20_(v20), v21_(v21), v22_(v22), v23_(v23), v24_(v24), + v25_(v25), v26_(v26), v27_(v27), v28_(v28), v29_(v29), v30_(v30), + v31_(v31), v32_(v32), v33_(v33), v34_(v34), v35_(v35), v36_(v36) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray36& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36, typename T37> +class ValueArray37 { + public: + ValueArray37(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), + v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), + v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), + v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), v23_(v23), + v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), v29_(v29), + v30_(v30), v31_(v31), v32_(v32), v33_(v33), v34_(v34), v35_(v35), + v36_(v36), v37_(v37) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray37& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36, typename T37, typename T38> +class ValueArray38 { + public: + ValueArray38(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), + v29_(v29), v30_(v30), v31_(v31), v32_(v32), v33_(v33), v34_(v34), + v35_(v35), v36_(v36), v37_(v37), v38_(v38) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray38& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36, typename T37, typename T38, typename T39> +class ValueArray39 { + public: + ValueArray39(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39) : v1_(v1), v2_(v2), + v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), + v29_(v29), v30_(v30), v31_(v31), v32_(v32), v33_(v33), v34_(v34), + v35_(v35), v36_(v36), v37_(v37), v38_(v38), v39_(v39) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray39& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36, typename T37, typename T38, typename T39, typename T40> +class ValueArray40 { + public: + ValueArray40(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40) : v1_(v1), + v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), + v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), + v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), + v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), + v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), v33_(v33), + v34_(v34), v35_(v35), v36_(v36), v37_(v37), v38_(v38), v39_(v39), + v40_(v40) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray40& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36, typename T37, typename T38, typename T39, typename T40, + typename T41> +class ValueArray41 { + public: + ValueArray41(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, + T41 v41) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), + v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), + v33_(v33), v34_(v34), v35_(v35), v36_(v36), v37_(v37), v38_(v38), + v39_(v39), v40_(v40), v41_(v41) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_}; + return ValuesIn(array); + } + + private: + // No
implementation - assignment is unsupported. + void operator=(const ValueArray41& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36, typename T37, typename T38, typename T39, typename T40, + typename T41, typename T42> +class ValueArray42 { + public: + ValueArray42(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), + v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), + v33_(v33), v34_(v34), v35_(v35), v36_(v36), v37_(v37), v38_(v38), + v39_(v39), v40_(v40), v41_(v41), v42_(v42) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_,
v39_, v40_, v41_, v42_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray42& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36, typename T37, typename T38, typename T39, typename T40, + typename T41, typename T42, typename T43> +class ValueArray43 { + public: + ValueArray43(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), + v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), + v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), + v20_(v20), v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), + v26_(v26), v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), + v32_(v32), v33_(v33), v34_(v34), v35_(v35), v36_(v36), v37_(v37), + v38_(v38), v39_(v39), v40_(v40), v41_(v41), v42_(v42), v43_(v43) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_,
v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray43& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36, typename T37, typename T38, typename T39, typename T40, + typename T41, typename T42, typename T43, typename T44> +class ValueArray44 { + public: + ValueArray44(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), + v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), + v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), + v19_(v19), v20_(v20), v21_(v21), v22_(v22), v23_(v23), v24_(v24), + v25_(v25), v26_(v26), v27_(v27), v28_(v28), v29_(v29), v30_(v30), + v31_(v31), v32_(v32), v33_(v33), v34_(v34), v35_(v35), v36_(v36), + v37_(v37), v38_(v38), v39_(v39), v40_(v40), v41_(v41), v42_(v42), + v43_(v43), v44_(v44) {} + + template <typename T> +
operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray44& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36, typename T37, typename T38, typename T39, typename T40, + typename T41, typename T42, typename T43, typename T44, typename T45> +class ValueArray45 { + public: + ValueArray45(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44, T45 v45) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), + v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), + v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), + v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), v23_(v23), + v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28),
v29_(v29), + v30_(v30), v31_(v31), v32_(v32), v33_(v33), v34_(v34), v35_(v35), + v36_(v36), v37_(v37), v38_(v38), v39_(v39), v40_(v40), v41_(v41), + v42_(v42), v43_(v43), v44_(v44), v45_(v45) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_, v45_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray45& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; + const T45 v45_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36, typename T37, typename T38, typename T39, typename T40, + typename T41, typename T42, typename T43, typename T44, typename T45, + typename T46> +class ValueArray46 { + public: + ValueArray46(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44, T45 v45, T46 v46) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4), v5_(v5), v6_(v6),
v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), + v29_(v29), v30_(v30), v31_(v31), v32_(v32), v33_(v33), v34_(v34), + v35_(v35), v36_(v36), v37_(v37), v38_(v38), v39_(v39), v40_(v40), + v41_(v41), v42_(v42), v43_(v43), v44_(v44), v45_(v45), v46_(v46) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_, v45_, v46_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray46& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; + const T45 v45_; + const T46 v46_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36, typename T37, typename T38, typename T39, typename T40, + typename T41, typename T42, typename T43, typename T44, typename T45, + typename T46, typename T47> +class ValueArray47 { + public: + ValueArray47(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23,
T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, T47 v47) : v1_(v1), v2_(v2), + v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), + v29_(v29), v30_(v30), v31_(v31), v32_(v32), v33_(v33), v34_(v34), + v35_(v35), v36_(v36), v37_(v37), v38_(v38), v39_(v39), v40_(v40), + v41_(v41), v42_(v42), v43_(v43), v44_(v44), v45_(v45), v46_(v46), + v47_(v47) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_, v45_, v46_, + v47_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported.
+ void operator=(const ValueArray47& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; + const T45 v45_; + const T46 v46_; + const T47 v47_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36, typename T37, typename T38, typename T39, typename T40, + typename T41, typename T42, typename T43, typename T44, typename T45, + typename T46, typename T47, typename T48> +class ValueArray48 { + public: + ValueArray48(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, T47 v47, T48 v48) : v1_(v1), + v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), + v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), + v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), + v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), + v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), v33_(v33), + v34_(v34), v35_(v35), v36_(v36), v37_(v37), v38_(v38), v39_(v39), + v40_(v40), v41_(v41), v42_(v42), v43_(v43), v44_(v44), v45_(v45), + v46_(v46), v47_(v47), v48_(v48) {} + + template <typename T> + operator ParamGenerator<T>() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_,
v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_, v45_, v46_, v47_, + v48_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray48& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; + const T45 v45_; + const T46 v46_; + const T47 v47_; + const T48 v48_; +}; + +template <typename T1, typename T2, typename T3, typename T4, typename T5, + typename T6, typename T7, typename T8, typename T9, typename T10, + typename T11, typename T12, typename T13, typename T14, typename T15, + typename T16, typename T17, typename T18, typename T19, typename T20, + typename T21, typename T22, typename T23, typename T24, typename T25, + typename T26, typename T27, typename T28, typename T29, typename T30, + typename T31, typename T32, typename T33, typename T34, typename T35, + typename T36, typename T37, typename T38, typename T39, typename T40, + typename T41, typename T42, typename T43, typename T44, typename T45, + typename T46, typename T47, typename T48, typename T49> +class ValueArray49 { + public: + ValueArray49(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, T47 v47, T48 v48, + T49 v49) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24),
v25_(v25), v26_(v26), + v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), + v33_(v33), v34_(v34), v35_(v35), v36_(v36), v37_(v37), v38_(v38), + v39_(v39), v40_(v40), v41_(v41), v42_(v42), v43_(v43), v44_(v44), + v45_(v45), v46_(v46), v47_(v47), v48_(v48), v49_(v49) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_, v45_, v46_, v47_, + v48_, v49_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray49& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; + const T45 v45_; + const T46 v46_; + const T47 v47_; + const T48 v48_; + const T49 v49_; +}; + +template +class ValueArray50 { + public: + ValueArray50(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 
v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, T47 v47, T48 v48, T49 v49, + T50 v50) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), + v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), + v33_(v33), v34_(v34), v35_(v35), v36_(v36), v37_(v37), v38_(v38), + v39_(v39), v40_(v40), v41_(v41), v42_(v42), v43_(v43), v44_(v44), + v45_(v45), v46_(v46), v47_(v47), v48_(v48), v49_(v49), v50_(v50) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_, v45_, v46_, v47_, + v48_, v49_, v50_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray50& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; + const T45 v45_; + const T46 v46_; + const T47 v47_; + const T48 v48_; + const T49 v49_; + const T50 v50_; +}; + +# if GTEST_HAS_COMBINE +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Generates values from the Cartesian product of values produced +// by the argument generators. 
+// +template +class CartesianProductGenerator2 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator2(const ParamGenerator& g1, + const ParamGenerator& g2) + : g1_(g1), g2_(g2) {} + virtual ~CartesianProductGenerator2() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. + virtual void Advance() { + assert(!AtEnd()); + ++current2_; + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return &current_value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." 
<< std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). + return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_; + } + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. + const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + ParamType current_value_; + }; // class CartesianProductGenerator2::Iterator + + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductGenerator2& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; +}; // class CartesianProductGenerator2 + + +template +class CartesianProductGenerator3 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator3(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3) + : g1_(g1), g2_(g2), g3_(g3) {} + virtual ~CartesianProductGenerator3() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename ParamGenerator::iterator& current3) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. 
+ virtual void Advance() { + assert(!AtEnd()); + ++current3_; + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return &current_value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). + return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_; + } + + // No implementation - assignment is unsupported. 
+ void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. + const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + ParamType current_value_; + }; // class CartesianProductGenerator3::Iterator + + // No implementation - assignment is unsupported. + void operator=(const CartesianProductGenerator3& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; +}; // class CartesianProductGenerator3 + + +template +class CartesianProductGenerator4 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator4(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4) {} + virtual ~CartesianProductGenerator4() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename 
ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. + virtual void Advance() { + assert(!AtEnd()); + ++current4_; + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return &current_value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). 
+ return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_; + } + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. 
+ const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + ParamType current_value_; + }; // class CartesianProductGenerator4::Iterator + + // No implementation - assignment is unsupported. + void operator=(const CartesianProductGenerator4& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; + const ParamGenerator g4_; +}; // class CartesianProductGenerator4 + + +template +class CartesianProductGenerator5 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator5(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4, const ParamGenerator& g5) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5) {} + virtual ~CartesianProductGenerator5() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin(), g5_, g5_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end(), g5_, g5_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, 
+ const typename ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4, + const ParamGenerator& g5, + const typename ParamGenerator::iterator& current5) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4), + begin5_(g5.begin()), end5_(g5.end()), current5_(current5) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. + virtual void Advance() { + assert(!AtEnd()); + ++current5_; + if (current5_ == end5_) { + current5_ = begin5_; + ++current4_; + } + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return &current_value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). 
+ return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_ && + current5_ == typed_other->current5_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_), + begin5_(other.begin5_), + end5_(other.end5_), + current5_(other.current5_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_, *current5_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_ || + current5_ == end5_; + } + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. 
+ const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + const typename ParamGenerator::iterator begin5_; + const typename ParamGenerator::iterator end5_; + typename ParamGenerator::iterator current5_; + ParamType current_value_; + }; // class CartesianProductGenerator5::Iterator + + // No implementation - assignment is unsupported. + void operator=(const CartesianProductGenerator5& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; + const ParamGenerator g4_; + const ParamGenerator g5_; +}; // class CartesianProductGenerator5 + + +template +class CartesianProductGenerator6 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator6(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4, const ParamGenerator& g5, + const ParamGenerator& g6) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6) {} + virtual ~CartesianProductGenerator6() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin(), g5_, g5_.begin(), g6_, g6_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end(), g5_, g5_.end(), g6_, g6_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + 
public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4, + const ParamGenerator& g5, + const typename ParamGenerator::iterator& current5, + const ParamGenerator& g6, + const typename ParamGenerator::iterator& current6) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4), + begin5_(g5.begin()), end5_(g5.end()), current5_(current5), + begin6_(g6.begin()), end6_(g6.end()), current6_(current6) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. + virtual void Advance() { + assert(!AtEnd()); + ++current6_; + if (current6_ == end6_) { + current6_ = begin6_; + ++current5_; + } + if (current5_ == end5_) { + current5_ = begin5_; + ++current4_; + } + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return &current_value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. 
+ GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). + return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_ && + current5_ == typed_other->current5_ && + current6_ == typed_other->current6_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_), + begin5_(other.begin5_), + end5_(other.end5_), + current5_(other.current5_), + begin6_(other.begin6_), + end6_(other.end6_), + current6_(other.current6_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_, *current5_, *current6_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_ || + current5_ == end5_ || + current6_ == end6_; + } + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. 
+ // current[i]_ is the actual traversing iterator. + const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + const typename ParamGenerator::iterator begin5_; + const typename ParamGenerator::iterator end5_; + typename ParamGenerator::iterator current5_; + const typename ParamGenerator::iterator begin6_; + const typename ParamGenerator::iterator end6_; + typename ParamGenerator::iterator current6_; + ParamType current_value_; + }; // class CartesianProductGenerator6::Iterator + + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductGenerator6& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; + const ParamGenerator g4_; + const ParamGenerator g5_; + const ParamGenerator g6_; +}; // class CartesianProductGenerator6 + + +template +class CartesianProductGenerator7 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator7(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4, const ParamGenerator& g5, + const ParamGenerator& g6, const ParamGenerator& g7) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7) {} + virtual ~CartesianProductGenerator7() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin(), g5_, g5_.begin(), g6_, g6_.begin(), g7_, + g7_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end(), g5_, g5_.end(), g6_, g6_.end(), g7_, g7_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4, + const ParamGenerator& g5, + const typename ParamGenerator::iterator& current5, + const ParamGenerator& g6, + const typename ParamGenerator::iterator& current6, + const ParamGenerator& g7, + const typename ParamGenerator::iterator& current7) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + 
begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4), + begin5_(g5.begin()), end5_(g5.end()), current5_(current5), + begin6_(g6.begin()), end6_(g6.end()), current6_(current6), + begin7_(g7.begin()), end7_(g7.end()), current7_(current7) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. + virtual void Advance() { + assert(!AtEnd()); + ++current7_; + if (current7_ == end7_) { + current7_ = begin7_; + ++current6_; + } + if (current6_ == end6_) { + current6_ = begin6_; + ++current5_; + } + if (current5_ == end5_) { + current5_ = begin5_; + ++current4_; + } + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return &current_value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). 
+ return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_ && + current5_ == typed_other->current5_ && + current6_ == typed_other->current6_ && + current7_ == typed_other->current7_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_), + begin5_(other.begin5_), + end5_(other.end5_), + current5_(other.current5_), + begin6_(other.begin6_), + end6_(other.end6_), + current6_(other.current6_), + begin7_(other.begin7_), + end7_(other.end7_), + current7_(other.current7_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_, *current5_, *current6_, *current7_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_ || + current5_ == end5_ || + current6_ == end6_ || + current7_ == end7_; + } + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. 
+ const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + const typename ParamGenerator::iterator begin5_; + const typename ParamGenerator::iterator end5_; + typename ParamGenerator::iterator current5_; + const typename ParamGenerator::iterator begin6_; + const typename ParamGenerator::iterator end6_; + typename ParamGenerator::iterator current6_; + const typename ParamGenerator::iterator begin7_; + const typename ParamGenerator::iterator end7_; + typename ParamGenerator::iterator current7_; + ParamType current_value_; + }; // class CartesianProductGenerator7::Iterator + + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductGenerator7& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; + const ParamGenerator g4_; + const ParamGenerator g5_; + const ParamGenerator g6_; + const ParamGenerator g7_; +}; // class CartesianProductGenerator7 + + +template +class CartesianProductGenerator8 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator8(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4, const ParamGenerator& g5, + const ParamGenerator& g6, const ParamGenerator& g7, + const ParamGenerator& g8) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7), + g8_(g8) {} + virtual ~CartesianProductGenerator8() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin(), g5_, g5_.begin(), g6_, g6_.begin(), g7_, + g7_.begin(), g8_, g8_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end(), g5_, g5_.end(), g6_, g6_.end(), g7_, g7_.end(), g8_, + g8_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4, + const ParamGenerator& g5, + const typename ParamGenerator::iterator& current5, + const ParamGenerator& g6, + const typename ParamGenerator::iterator& current6, + const ParamGenerator& g7, + const typename ParamGenerator::iterator& current7, + const ParamGenerator& g8, + const 
typename ParamGenerator<T8>::iterator& current8) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4), + begin5_(g5.begin()), end5_(g5.end()), current5_(current5), + begin6_(g6.begin()), end6_(g6.end()), current6_(current6), + begin7_(g7.begin()), end7_(g7.end()), current7_(current7), + begin8_(g8.begin()), end8_(g8.end()), current8_(current8) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface<ParamType>* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. + virtual void Advance() { + assert(!AtEnd()); + ++current8_; + if (current8_ == end8_) { + current8_ = begin8_; + ++current7_; + } + if (current7_ == end7_) { + current7_ = begin7_; + ++current6_; + } + if (current6_ == end6_) { + current6_ = begin6_; + ++current5_; + } + if (current5_ == end5_) { + current5_ = begin5_; + ++current4_; + } + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface<ParamType>* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return &current_value_; } + virtual bool Equals(const ParamIteratorInterface<ParamType>& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." 
<< std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). + return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_ && + current5_ == typed_other->current5_ && + current6_ == typed_other->current6_ && + current7_ == typed_other->current7_ && + current8_ == typed_other->current8_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_), + begin5_(other.begin5_), + end5_(other.end5_), + current5_(other.current5_), + begin6_(other.begin6_), + end6_(other.end6_), + current6_(other.current6_), + begin7_(other.begin7_), + end7_(other.end7_), + current7_(other.current7_), + begin8_(other.begin8_), + end8_(other.end8_), + current8_(other.current8_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_, *current5_, *current6_, *current7_, *current8_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_ || + current5_ == end5_ || + current6_ == end6_ || + current7_ == end7_ || + current8_ == end8_; + } + + // No implementation - assignment is unsupported. 
+ void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. + const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + const typename ParamGenerator::iterator begin5_; + const typename ParamGenerator::iterator end5_; + typename ParamGenerator::iterator current5_; + const typename ParamGenerator::iterator begin6_; + const typename ParamGenerator::iterator end6_; + typename ParamGenerator::iterator current6_; + const typename ParamGenerator::iterator begin7_; + const typename ParamGenerator::iterator end7_; + typename ParamGenerator::iterator current7_; + const typename ParamGenerator::iterator begin8_; + const typename ParamGenerator::iterator end8_; + typename ParamGenerator::iterator current8_; + ParamType current_value_; + }; // class CartesianProductGenerator8::Iterator + + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductGenerator8& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; + const ParamGenerator g4_; + const ParamGenerator g5_; + const ParamGenerator g6_; + const ParamGenerator g7_; + const ParamGenerator g8_; +}; // class CartesianProductGenerator8 + + +template +class CartesianProductGenerator9 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator9(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4, const ParamGenerator& g5, + const ParamGenerator& g6, const ParamGenerator& g7, + const ParamGenerator& g8, const ParamGenerator& g9) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7), g8_(g8), + g9_(g9) {} + virtual ~CartesianProductGenerator9() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin(), g5_, g5_.begin(), g6_, g6_.begin(), g7_, + g7_.begin(), g8_, g8_.begin(), g9_, g9_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end(), g5_, g5_.end(), g6_, g6_.end(), g7_, g7_.end(), g8_, + g8_.end(), g9_, g9_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4, + const ParamGenerator& g5, + const typename ParamGenerator::iterator& current5, + const ParamGenerator& g6, + const typename ParamGenerator::iterator& current6, + const ParamGenerator& 
g7, + const typename ParamGenerator::iterator& current7, + const ParamGenerator& g8, + const typename ParamGenerator::iterator& current8, + const ParamGenerator& g9, + const typename ParamGenerator::iterator& current9) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4), + begin5_(g5.begin()), end5_(g5.end()), current5_(current5), + begin6_(g6.begin()), end6_(g6.end()), current6_(current6), + begin7_(g7.begin()), end7_(g7.end()), current7_(current7), + begin8_(g8.begin()), end8_(g8.end()), current8_(current8), + begin9_(g9.begin()), end9_(g9.end()), current9_(current9) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. 
+ virtual void Advance() { + assert(!AtEnd()); + ++current9_; + if (current9_ == end9_) { + current9_ = begin9_; + ++current8_; + } + if (current8_ == end8_) { + current8_ = begin8_; + ++current7_; + } + if (current7_ == end7_) { + current7_ = begin7_; + ++current6_; + } + if (current6_ == end6_) { + current6_ = begin6_; + ++current5_; + } + if (current5_ == end5_) { + current5_ = begin5_; + ++current4_; + } + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface<ParamType>* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return &current_value_; } + virtual bool Equals(const ParamIteratorInterface<ParamType>& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType<const Iterator>(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). 
+ return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_ && + current5_ == typed_other->current5_ && + current6_ == typed_other->current6_ && + current7_ == typed_other->current7_ && + current8_ == typed_other->current8_ && + current9_ == typed_other->current9_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_), + begin5_(other.begin5_), + end5_(other.end5_), + current5_(other.current5_), + begin6_(other.begin6_), + end6_(other.end6_), + current6_(other.current6_), + begin7_(other.begin7_), + end7_(other.end7_), + current7_(other.current7_), + begin8_(other.begin8_), + end8_(other.end8_), + current8_(other.current8_), + begin9_(other.begin9_), + end9_(other.end9_), + current9_(other.current9_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_, *current5_, *current6_, *current7_, *current8_, + *current9_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_ || + current5_ == end5_ || + current6_ == end6_ || + current7_ == end7_ || + current8_ == end8_ || + current9_ == end9_; + } + + // No implementation - assignment is unsupported. 
+ void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. + const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + const typename ParamGenerator::iterator begin5_; + const typename ParamGenerator::iterator end5_; + typename ParamGenerator::iterator current5_; + const typename ParamGenerator::iterator begin6_; + const typename ParamGenerator::iterator end6_; + typename ParamGenerator::iterator current6_; + const typename ParamGenerator::iterator begin7_; + const typename ParamGenerator::iterator end7_; + typename ParamGenerator::iterator current7_; + const typename ParamGenerator::iterator begin8_; + const typename ParamGenerator::iterator end8_; + typename ParamGenerator::iterator current8_; + const typename ParamGenerator::iterator begin9_; + const typename ParamGenerator::iterator end9_; + typename ParamGenerator::iterator current9_; + ParamType current_value_; + }; // class CartesianProductGenerator9::Iterator + + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductGenerator9& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; + const ParamGenerator g4_; + const ParamGenerator g5_; + const ParamGenerator g6_; + const ParamGenerator g7_; + const ParamGenerator g8_; + const ParamGenerator g9_; +}; // class CartesianProductGenerator9 + + +template +class CartesianProductGenerator10 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator10(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4, const ParamGenerator& g5, + const ParamGenerator& g6, const ParamGenerator& g7, + const ParamGenerator& g8, const ParamGenerator& g9, + const ParamGenerator& g10) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7), g8_(g8), + g9_(g9), g10_(g10) {} + virtual ~CartesianProductGenerator10() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin(), g5_, g5_.begin(), g6_, g6_.begin(), g7_, + g7_.begin(), g8_, g8_.begin(), g9_, g9_.begin(), g10_, g10_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end(), g5_, g5_.end(), g6_, g6_.end(), g7_, g7_.end(), g8_, + g8_.end(), g9_, g9_.end(), g10_, g10_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4, + const ParamGenerator& g5, + const typename ParamGenerator::iterator& 
current5, + const ParamGenerator& g6, + const typename ParamGenerator::iterator& current6, + const ParamGenerator& g7, + const typename ParamGenerator::iterator& current7, + const ParamGenerator& g8, + const typename ParamGenerator::iterator& current8, + const ParamGenerator& g9, + const typename ParamGenerator::iterator& current9, + const ParamGenerator& g10, + const typename ParamGenerator::iterator& current10) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4), + begin5_(g5.begin()), end5_(g5.end()), current5_(current5), + begin6_(g6.begin()), end6_(g6.end()), current6_(current6), + begin7_(g7.begin()), end7_(g7.end()), current7_(current7), + begin8_(g8.begin()), end8_(g8.end()), current8_(current8), + begin9_(g9.begin()), end9_(g9.end()), current9_(current9), + begin10_(g10.begin()), end10_(g10.end()), current10_(current10) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. 
+ virtual void Advance() { + assert(!AtEnd()); + ++current10_; + if (current10_ == end10_) { + current10_ = begin10_; + ++current9_; + } + if (current9_ == end9_) { + current9_ = begin9_; + ++current8_; + } + if (current8_ == end8_) { + current8_ = begin8_; + ++current7_; + } + if (current7_ == end7_) { + current7_ = begin7_; + ++current6_; + } + if (current6_ == end6_) { + current6_ = begin6_; + ++current5_; + } + if (current5_ == end5_) { + current5_ = begin5_; + ++current4_; + } + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface<ParamType>* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return &current_value_; } + virtual bool Equals(const ParamIteratorInterface<ParamType>& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType<const Iterator>(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). 
+ return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_ && + current5_ == typed_other->current5_ && + current6_ == typed_other->current6_ && + current7_ == typed_other->current7_ && + current8_ == typed_other->current8_ && + current9_ == typed_other->current9_ && + current10_ == typed_other->current10_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_), + begin5_(other.begin5_), + end5_(other.end5_), + current5_(other.current5_), + begin6_(other.begin6_), + end6_(other.end6_), + current6_(other.current6_), + begin7_(other.begin7_), + end7_(other.end7_), + current7_(other.current7_), + begin8_(other.begin8_), + end8_(other.end8_), + current8_(other.current8_), + begin9_(other.begin9_), + end9_(other.end9_), + current9_(other.current9_), + begin10_(other.begin10_), + end10_(other.end10_), + current10_(other.current10_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_, *current5_, *current6_, *current7_, *current8_, + *current9_, *current10_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. 
+ return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_ || + current5_ == end5_ || + current6_ == end6_ || + current7_ == end7_ || + current8_ == end8_ || + current9_ == end9_ || + current10_ == end10_; + } + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. + const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + const typename ParamGenerator::iterator begin5_; + const typename ParamGenerator::iterator end5_; + typename ParamGenerator::iterator current5_; + const typename ParamGenerator::iterator begin6_; + const typename ParamGenerator::iterator end6_; + typename ParamGenerator::iterator current6_; + const typename ParamGenerator::iterator begin7_; + const typename ParamGenerator::iterator end7_; + typename ParamGenerator::iterator current7_; + const typename ParamGenerator::iterator begin8_; + const typename ParamGenerator::iterator end8_; + typename ParamGenerator::iterator current8_; + const typename ParamGenerator::iterator begin9_; + const typename ParamGenerator::iterator end9_; + typename ParamGenerator::iterator current9_; + const typename ParamGenerator::iterator begin10_; + const typename ParamGenerator::iterator end10_; + typename ParamGenerator::iterator current10_; + 
ParamType current_value_; + }; // class CartesianProductGenerator10::Iterator + + // No implementation - assignment is unsupported. + void operator=(const CartesianProductGenerator10& other); + + const ParamGenerator<T1> g1_; + const ParamGenerator<T2> g2_; + const ParamGenerator<T3> g3_; + const ParamGenerator<T4> g4_; + const ParamGenerator<T5> g5_; + const ParamGenerator<T6> g6_; + const ParamGenerator<T7> g7_; + const ParamGenerator<T8> g8_; + const ParamGenerator<T9> g9_; + const ParamGenerator<T10> g10_; +}; // class CartesianProductGenerator10 + + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Helper classes providing Combine() with polymorphic features. They allow +// casting CartesianProductGeneratorN<T> to ParamGenerator<U> if T is +// convertible to U. +// +template <class Generator1, class Generator2> +class CartesianProductHolder2 { + public: +CartesianProductHolder2(const Generator1& g1, const Generator2& g2) + : g1_(g1), g2_(g2) {} + template <typename T1, typename T2> + operator ParamGenerator< ::std::tr1::tuple<T1, T2> >() const { + return ParamGenerator< ::std::tr1::tuple<T1, T2> >( + new CartesianProductGenerator2<T1, T2>( + static_cast<ParamGenerator<T1> >(g1_), + static_cast<ParamGenerator<T2> >(g2_))); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const CartesianProductHolder2& other); + + const Generator1 g1_; + const Generator2 g2_; +}; // class CartesianProductHolder2 + +template <class Generator1, class Generator2, class Generator3> +class CartesianProductHolder3 { + public: +CartesianProductHolder3(const Generator1& g1, const Generator2& g2, + const Generator3& g3) + : g1_(g1), g2_(g2), g3_(g3) {} + template <typename T1, typename T2, typename T3> + operator ParamGenerator< ::std::tr1::tuple<T1, T2, T3> >() const { + return ParamGenerator< ::std::tr1::tuple<T1, T2, T3> >( + new CartesianProductGenerator3<T1, T2, T3>( + static_cast<ParamGenerator<T1> >(g1_), + static_cast<ParamGenerator<T2> >(g2_), + static_cast<ParamGenerator<T3> >(g3_))); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductHolder3& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; +}; // class CartesianProductHolder3 + +template <class Generator1, class Generator2, class Generator3, class Generator4> +class CartesianProductHolder4 { + public: +CartesianProductHolder4(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4) {} + template <typename T1, typename T2, typename T3, typename T4> + operator ParamGenerator< ::std::tr1::tuple<T1, T2, T3, T4> >() const { + return ParamGenerator< ::std::tr1::tuple<T1, T2, T3, T4> >( + new CartesianProductGenerator4<T1, T2, T3, T4>( + static_cast<ParamGenerator<T1> >(g1_), + static_cast<ParamGenerator<T2> >(g2_), + static_cast<ParamGenerator<T3> >(g3_), + static_cast<ParamGenerator<T4> >(g4_))); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const CartesianProductHolder4& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; + const Generator4 g4_; +}; // class CartesianProductHolder4 + +template <class Generator1, class Generator2, class Generator3, class Generator4, class Generator5> +class CartesianProductHolder5 { + public: +CartesianProductHolder5(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4, const Generator5& g5) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5) {} + template <typename T1, typename T2, typename T3, typename T4, typename T5> + operator ParamGenerator< ::std::tr1::tuple<T1, T2, T3, T4, T5> >() const { + return ParamGenerator< ::std::tr1::tuple<T1, T2, T3, T4, T5> >( + new CartesianProductGenerator5<T1, T2, T3, T4, T5>( + static_cast<ParamGenerator<T1> >(g1_), + static_cast<ParamGenerator<T2> >(g2_), + static_cast<ParamGenerator<T3> >(g3_), + static_cast<ParamGenerator<T4> >(g4_), + static_cast<ParamGenerator<T5> >(g5_))); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductHolder5& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; + const Generator4 g4_; + const Generator5 g5_; +}; // class CartesianProductHolder5 + +template <class Generator1, class Generator2, class Generator3, class Generator4, class Generator5, class Generator6> +class CartesianProductHolder6 { + public: +CartesianProductHolder6(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4, const Generator5& g5, + const Generator6& g6) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6) {} + template <typename T1, typename T2, typename T3, typename T4, typename T5, typename T6> + operator ParamGenerator< ::std::tr1::tuple<T1, T2, T3, T4, T5, T6> >() const { + return ParamGenerator< ::std::tr1::tuple<T1, T2, T3, T4, T5, T6> >( + new CartesianProductGenerator6<T1, T2, T3, T4, T5, T6>( + static_cast<ParamGenerator<T1> >(g1_), + static_cast<ParamGenerator<T2> >(g2_), + static_cast<ParamGenerator<T3> >(g3_), + static_cast<ParamGenerator<T4> >(g4_), + static_cast<ParamGenerator<T5> >(g5_), + static_cast<ParamGenerator<T6> >(g6_))); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const CartesianProductHolder6& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; + const Generator4 g4_; + const Generator5 g5_; + const Generator6 g6_; +}; // class CartesianProductHolder6 + +template <class Generator1, class Generator2, class Generator3, class Generator4, class Generator5, class Generator6, class Generator7> +class CartesianProductHolder7 { + public: +CartesianProductHolder7(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4, const Generator5& g5, + const Generator6& g6, const Generator7& g7) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7) {} + template <typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename T7> + operator ParamGenerator< ::std::tr1::tuple<T1, T2, T3, T4, T5, T6, T7> >() const { + return ParamGenerator< ::std::tr1::tuple<T1, T2, T3, T4, T5, T6, T7> >( + new CartesianProductGenerator7<T1, T2, T3, T4, T5, T6, T7>( + static_cast<ParamGenerator<T1> >(g1_), + static_cast<ParamGenerator<T2> >(g2_), + static_cast<ParamGenerator<T3> >(g3_), + static_cast<ParamGenerator<T4> >(g4_), + static_cast<ParamGenerator<T5> >(g5_), + static_cast<ParamGenerator<T6> >(g6_), + static_cast<ParamGenerator<T7> >(g7_))); + } + + private: + // No implementation - assignment is unsupported. 
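+ void operator=(const CartesianProductHolder7& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; + const Generator4 g4_; + const Generator5 g5_; + const Generator6 g6_; + const Generator7 g7_; +}; // class CartesianProductHolder7 + +template <class Generator1, class Generator2, class Generator3, class Generator4, class Generator5, class Generator6, class Generator7, class Generator8> +class CartesianProductHolder8 { + public: +CartesianProductHolder8(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4, const Generator5& g5, + const Generator6& g6, const Generator7& g7, const Generator8& g8) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7), + g8_(g8) {} + template <typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename T7, typename T8> + operator ParamGenerator< ::std::tr1::tuple<T1, T2, T3, T4, T5, T6, T7, T8> >() const { + return ParamGenerator< ::std::tr1::tuple<T1, T2, T3, T4, T5, T6, T7, T8> >( + new CartesianProductGenerator8<T1, T2, T3, T4, T5, T6, T7, T8>( + static_cast<ParamGenerator<T1> >(g1_), + static_cast<ParamGenerator<T2> >(g2_), + static_cast<ParamGenerator<T3> >(g3_), + static_cast<ParamGenerator<T4> >(g4_), + static_cast<ParamGenerator<T5> >(g5_), + static_cast<ParamGenerator<T6> >(g6_), + static_cast<ParamGenerator<T7> >(g7_), + static_cast<ParamGenerator<T8> >(g8_))); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const CartesianProductHolder8& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; + const Generator4 g4_; + const Generator5 g5_; + const Generator6 g6_; + const Generator7 g7_; + const Generator8 g8_; +}; // class CartesianProductHolder8 + +template <class Generator1, class Generator2, class Generator3, class Generator4, class Generator5, class Generator6, class Generator7, class Generator8, class Generator9> +class CartesianProductHolder9 { + public: +CartesianProductHolder9(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4, const Generator5& g5, + const Generator6& g6, const Generator7& g7, const Generator8& g8, + const Generator9& g9) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7), g8_(g8), + g9_(g9) {} + template <typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename T7, typename T8, typename T9> + operator ParamGenerator< ::std::tr1::tuple<T1, T2, T3, T4, T5, T6, T7, T8, T9> >() const { + return ParamGenerator< ::std::tr1::tuple<T1, T2, T3, T4, T5, T6, T7, T8, T9> >( + new CartesianProductGenerator9<T1, T2, T3, T4, T5, T6, T7, T8, T9>( + static_cast<ParamGenerator<T1> >(g1_), + static_cast<ParamGenerator<T2> >(g2_), + static_cast<ParamGenerator<T3> >(g3_), + static_cast<ParamGenerator<T4> >(g4_), + static_cast<ParamGenerator<T5> >(g5_), + static_cast<ParamGenerator<T6> >(g6_), + static_cast<ParamGenerator<T7> >(g7_), + static_cast<ParamGenerator<T8> >(g8_), + static_cast<ParamGenerator<T9> 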
>(g9_))); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const CartesianProductHolder9& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; + const Generator4 g4_; + const Generator5 g5_; + const Generator6 g6_; + const Generator7 g7_; + const Generator8 g8_; + const Generator9 g9_; +}; // class CartesianProductHolder9 + +template +class CartesianProductHolder10 { + public: +CartesianProductHolder10(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4, const Generator5& g5, + const Generator6& g6, const Generator7& g7, const Generator8& g8, + const Generator9& g9, const Generator10& g10) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7), g8_(g8), + g9_(g9), g10_(g10) {} + template + operator ParamGenerator< ::std::tr1::tuple >() const { + return ParamGenerator< ::std::tr1::tuple >( + new CartesianProductGenerator10( + static_cast >(g1_), + static_cast >(g2_), + static_cast >(g3_), + static_cast >(g4_), + static_cast >(g5_), + static_cast >(g6_), + static_cast >(g7_), + static_cast >(g8_), + static_cast >(g9_), + static_cast >(g10_))); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const CartesianProductHolder10& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; + const Generator4 g4_; + const Generator5 g5_; + const Generator6 g6_; + const Generator7 g7_; + const Generator8 g8_; + const Generator9 g9_; + const Generator10 g10_; +}; // class CartesianProductHolder10 + +# endif // GTEST_HAS_COMBINE + +} // namespace internal +} // namespace testing + +#endif // GTEST_HAS_PARAM_TEST + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_GENERATED_H_ + +#if GTEST_HAS_PARAM_TEST + +namespace testing { + +// Functions producing parameter generators. +// +// Google Test uses these generators to produce parameters for value- +// parameterized tests. 
When a parameterized test case is instantiated +// with a particular generator, Google Test creates and runs tests +// for each element in the sequence produced by the generator. +// +// In the following sample, tests from test case FooTest are each +// instantiated three times with parameter values 3, 5, and 8: +// +// class FooTest : public TestWithParam<int> { ... }; +// +// TEST_P(FooTest, TestThis) { +// } +// TEST_P(FooTest, TestThat) { +// } +// INSTANTIATE_TEST_CASE_P(TestSequence, FooTest, Values(3, 5, 8)); +// + +// Range() returns generators providing sequences of values in a range. +// +// Synopsis: +// Range(start, end) +// - returns a generator producing a sequence of values {start, start+1, +// start+2, ..., }. +// Range(start, end, step) +// - returns a generator producing a sequence of values {start, start+step, +// start+step+step, ..., }. +// Notes: +// * The generated sequences never include end. For example, Range(1, 5) +// returns a generator producing a sequence {1, 2, 3, 4}. Range(1, 9, 2) +// returns a generator producing {1, 3, 5, 7}. +// * start and end must have the same type. That type may be any integral or +// floating-point type or a user-defined type satisfying these conditions: +// * It must be assignable (have operator=() defined). +// * It must have operator+() (operator+(int-compatible type) for +// two-operand version). +// * It must have operator<() defined. +// Elements in the resulting sequences will also have that type. +// * Condition start < end must be satisfied in order for resulting sequences +// to contain any elements. +// +template <typename T, typename IncrementT> +internal::ParamGenerator<T> Range(T start, T end, IncrementT step) { + return internal::ParamGenerator<T>( + new internal::RangeGenerator<T, IncrementT>(start, end, step)); +} + +template <typename T> +internal::ParamGenerator<T> Range(T start, T end) { + return Range(start, end, 1); +} + +// ValuesIn() function allows generation of tests with parameters coming from +// a container. 
+// +// Synopsis: +// ValuesIn(const T (&array)[N]) +// - returns a generator producing sequences with elements from +// a C-style array. +// ValuesIn(const Container& container) +// - returns a generator producing sequences with elements from +// an STL-style container. +// ValuesIn(Iterator begin, Iterator end) +// - returns a generator producing sequences with elements from +// a range [begin, end) defined by a pair of STL-style iterators. These +// iterators can also be plain C pointers. +// +// Please note that ValuesIn copies the values from the containers +// passed in and keeps them to generate tests in RUN_ALL_TESTS(). +// +// Examples: +// +// This instantiates tests from test case StringTest +// each with C-string values of "foo", "bar", and "baz": +// +// const char* strings[] = {"foo", "bar", "baz"}; +// INSTANTIATE_TEST_CASE_P(StringSequence, StringTest, ValuesIn(strings)); +// +// This instantiates tests from test case StlStringTest +// each with STL strings with values "a" and "b": +// +// ::std::vector< ::std::string> GetParameterStrings() { +// ::std::vector< ::std::string> v; +// v.push_back("a"); +// v.push_back("b"); +// return v; +// } +// +// INSTANTIATE_TEST_CASE_P(CharSequence, +// StlStringTest, +// ValuesIn(GetParameterStrings())); +// +// +// This will also instantiate tests from CharTest +// each with parameter values 'a' and 'b': +// +// ::std::list<char> GetParameterChars() { +// ::std::list<char> list; +// list.push_back('a'); +// list.push_back('b'); +// return list; +// } +// ::std::list<char> l = GetParameterChars(); +// INSTANTIATE_TEST_CASE_P(CharSequence2, +// CharTest, +// ValuesIn(l.begin(), l.end())); +// +template <typename ForwardIterator> +internal::ParamGenerator< + typename ::testing::internal::IteratorTraits<ForwardIterator>::value_type> +ValuesIn(ForwardIterator begin, ForwardIterator end) { + typedef typename ::testing::internal::IteratorTraits<ForwardIterator> + ::value_type ParamType; + return internal::ParamGenerator<ParamType>( + new internal::ValuesInIteratorRangeGenerator<ParamType>(begin, end)); +} + 
+template +internal::ParamGenerator ValuesIn(const T (&array)[N]) { + return ValuesIn(array, array + N); +} + +template +internal::ParamGenerator ValuesIn( + const Container& container) { + return ValuesIn(container.begin(), container.end()); +} + +// Values() allows generating tests from explicitly specified list of +// parameters. +// +// Synopsis: +// Values(T v1, T v2, ..., T vN) +// - returns a generator producing sequences with elements v1, v2, ..., vN. +// +// For example, this instantiates tests from test case BarTest each +// with values "one", "two", and "three": +// +// INSTANTIATE_TEST_CASE_P(NumSequence, BarTest, Values("one", "two", "three")); +// +// This instantiates tests from test case BazTest each with values 1, 2, 3.5. +// The exact type of values will depend on the type of parameter in BazTest. +// +// INSTANTIATE_TEST_CASE_P(FloatingNumbers, BazTest, Values(1, 2, 3.5)); +// +// Currently, Values() supports from 1 to 50 parameters. +// +template +internal::ValueArray1 Values(T1 v1) { + return internal::ValueArray1(v1); +} + +template +internal::ValueArray2 Values(T1 v1, T2 v2) { + return internal::ValueArray2(v1, v2); +} + +template +internal::ValueArray3 Values(T1 v1, T2 v2, T3 v3) { + return internal::ValueArray3(v1, v2, v3); +} + +template +internal::ValueArray4 Values(T1 v1, T2 v2, T3 v3, T4 v4) { + return internal::ValueArray4(v1, v2, v3, v4); +} + +template +internal::ValueArray5 Values(T1 v1, T2 v2, T3 v3, T4 v4, + T5 v5) { + return internal::ValueArray5(v1, v2, v3, v4, v5); +} + +template +internal::ValueArray6 Values(T1 v1, T2 v2, T3 v3, + T4 v4, T5 v5, T6 v6) { + return internal::ValueArray6(v1, v2, v3, v4, v5, v6); +} + +template +internal::ValueArray7 Values(T1 v1, T2 v2, T3 v3, + T4 v4, T5 v5, T6 v6, T7 v7) { + return internal::ValueArray7(v1, v2, v3, v4, v5, + v6, v7); +} + +template +internal::ValueArray8 Values(T1 v1, T2 v2, + T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8) { + return internal::ValueArray8(v1, v2, v3, v4, + v5, v6, 
v7, v8); +} + +template +internal::ValueArray9 Values(T1 v1, T2 v2, + T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9) { + return internal::ValueArray9(v1, v2, v3, + v4, v5, v6, v7, v8, v9); +} + +template +internal::ValueArray10 Values(T1 v1, + T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10) { + return internal::ValueArray10(v1, + v2, v3, v4, v5, v6, v7, v8, v9, v10); +} + +template +internal::ValueArray11 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11) { + return internal::ValueArray11(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11); +} + +template +internal::ValueArray12 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12) { + return internal::ValueArray12(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12); +} + +template +internal::ValueArray13 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13) { + return internal::ValueArray13(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13); +} + +template +internal::ValueArray14 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14) { + return internal::ValueArray14(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, + v14); +} + +template +internal::ValueArray15 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, + T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15) { + return internal::ValueArray15(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, + v13, v14, v15); +} + +template +internal::ValueArray16 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16) { + return internal::ValueArray16(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, + v12, v13, v14, v15, v16); +} + +template +internal::ValueArray17 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8, T9 v9, T10 v10, T11 v11, T12 
v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17) { + return internal::ValueArray17(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, + v11, v12, v13, v14, v15, v16, v17); +} + +template +internal::ValueArray18 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, + T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17, T18 v18) { + return internal::ValueArray18(v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18); +} + +template +internal::ValueArray19 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, + T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, + T15 v15, T16 v16, T17 v17, T18 v18, T19 v19) { + return internal::ValueArray19(v1, v2, v3, v4, v5, v6, v7, v8, + v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19); +} + +template +internal::ValueArray20 Values(T1 v1, T2 v2, T3 v3, T4 v4, + T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, + T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20) { + return internal::ValueArray20(v1, v2, v3, v4, v5, v6, v7, + v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20); +} + +template +internal::ValueArray21 Values(T1 v1, T2 v2, T3 v3, T4 v4, + T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, + T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21) { + return internal::ValueArray21(v1, v2, v3, v4, v5, v6, + v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21); +} + +template +internal::ValueArray22 Values(T1 v1, T2 v2, T3 v3, + T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, + T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, + T21 v21, T22 v22) { + return internal::ValueArray22(v1, v2, v3, v4, + v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22); +} + +template +internal::ValueArray23 Values(T1 v1, T2 v2, + T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, 
T12 v12, + T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, + T21 v21, T22 v22, T23 v23) { + return internal::ValueArray23(v1, v2, v3, + v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23); +} + +template +internal::ValueArray24 Values(T1 v1, T2 v2, + T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, + T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, + T21 v21, T22 v22, T23 v23, T24 v24) { + return internal::ValueArray24(v1, v2, + v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, + v19, v20, v21, v22, v23, v24); +} + +template +internal::ValueArray25 Values(T1 v1, + T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, + T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, + T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25) { + return internal::ValueArray25(v1, + v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, + v18, v19, v20, v21, v22, v23, v24, v25); +} + +template +internal::ValueArray26 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26) { + return internal::ValueArray26(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26); +} + +template +internal::ValueArray27 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27) { + return internal::ValueArray27(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, + v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27); +} + +template +internal::ValueArray28 Values(T1 v1, T2 v2, T3 
v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28) { + return internal::ValueArray28(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, + v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, + v28); +} + +template +internal::ValueArray29 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29) { + return internal::ValueArray29(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, + v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, + v27, v28, v29); +} + +template +internal::ValueArray30 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, + T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, + T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, + T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30) { + return internal::ValueArray30(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, + v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, + v26, v27, v28, v29, v30); +} + +template +internal::ValueArray31 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, + T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31) { + return internal::ValueArray31(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, + v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, + v25, v26, v27, v28, v29, v30, v31); +} + +template +internal::ValueArray32 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 
v14, T15 v15, + T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, + T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, + T32 v32) { + return internal::ValueArray32(v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, + v24, v25, v26, v27, v28, v29, v30, v31, v32); +} + +template +internal::ValueArray33 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, + T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, + T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, + T32 v32, T33 v33) { + return internal::ValueArray33(v1, v2, v3, v4, v5, v6, v7, v8, + v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, + v24, v25, v26, v27, v28, v29, v30, v31, v32, v33); +} + +template +internal::ValueArray34 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, + T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, + T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, + T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, + T31 v31, T32 v32, T33 v33, T34 v34) { + return internal::ValueArray34(v1, v2, v3, v4, v5, v6, v7, + v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, + v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34); +} + +template +internal::ValueArray35 Values(T1 v1, T2 v2, T3 v3, T4 v4, + T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, + T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, + T22 v22, T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, + T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35) { + return internal::ValueArray35(v1, v2, v3, v4, v5, v6, + v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, + v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35); +} + 
+template +internal::ValueArray36 Values(T1 v1, T2 v2, T3 v3, T4 v4, + T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, + T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, + T22 v22, T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, + T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, T36 v36) { + return internal::ValueArray36(v1, v2, v3, v4, + v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, + v34, v35, v36); +} + +template +internal::ValueArray37 Values(T1 v1, T2 v2, T3 v3, + T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, + T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, + T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, + T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, + T37 v37) { + return internal::ValueArray37(v1, v2, v3, + v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, + v34, v35, v36, v37); +} + +template +internal::ValueArray38 Values(T1 v1, T2 v2, + T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, + T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, + T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, + T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, + T37 v37, T38 v38) { + return internal::ValueArray38(v1, v2, + v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, + v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, + v33, v34, v35, v36, v37, v38); +} + +template +internal::ValueArray39 Values(T1 v1, T2 v2, + T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, + T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, + T21 v21, T22 v22, T23 v23, 
T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, + T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, + T37 v37, T38 v38, T39 v39) { + return internal::ValueArray39(v1, + v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, + v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, + v32, v33, v34, v35, v36, v37, v38, v39); +} + +template +internal::ValueArray40 Values(T1 v1, + T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, + T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, + T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, + T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, + T36 v36, T37 v37, T38 v38, T39 v39, T40 v40) { + return internal::ValueArray40(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40); +} + +template +internal::ValueArray41 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41) { + return internal::ValueArray41(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, + v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, + v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41); +} + +template +internal::ValueArray42 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 
v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42) { + return internal::ValueArray42(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, + v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, + v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, + v42); +} + +template +internal::ValueArray43 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43) { + return internal::ValueArray43(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, + v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, + v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, + v41, v42, v43); +} + +template +internal::ValueArray44 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44) { + return internal::ValueArray44(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, + v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, + v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, + v40, v41, v42, v43, v44); +} + +template +internal::ValueArray45 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, + T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, + T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, + T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, + T33 v33, T34 
v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, + T41 v41, T42 v42, T43 v43, T44 v44, T45 v45) { + return internal::ValueArray45(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, + v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, + v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, + v39, v40, v41, v42, v43, v44, v45); +} + +template +internal::ValueArray46 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, + T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, + T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, + T40 v40, T41 v41, T42 v42, T43 v43, T44 v44, T45 v45, T46 v46) { + return internal::ValueArray46(v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, + v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, + v38, v39, v40, v41, v42, v43, v44, v45, v46); +} + +template +internal::ValueArray47 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, + T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, + T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, + T40 v40, T41 v41, T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, T47 v47) { + return internal::ValueArray47(v1, v2, v3, v4, v5, v6, v7, v8, + v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, + v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, + v38, v39, v40, v41, v42, v43, v44, v45, v46, v47); +} + +template +internal::ValueArray48 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, + T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17, T18 v18, 
T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, + T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, + T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, + T40 v40, T41 v41, T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, T47 v47, + T48 v48) { + return internal::ValueArray48(v1, v2, v3, v4, v5, v6, v7, + v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, + v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, + v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48); +} + +template +internal::ValueArray49 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, + T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, + T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, + T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, + T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, + T39 v39, T40 v40, T41 v41, T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, + T47 v47, T48 v48, T49 v49) { + return internal::ValueArray49(v1, v2, v3, v4, v5, v6, + v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, + v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, + v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49); +} + +template +internal::ValueArray50 Values(T1 v1, T2 v2, T3 v3, T4 v4, + T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, + T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, + T22 v22, T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, + T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, T37 v37, + T38 v38, T39 v39, T40 v40, T41 v41, T42 v42, T43 v43, T44 v44, T45 v45, + T46 v46, T47 v47, T48 v48, T49 v49, T50 v50) { + return internal::ValueArray50(v1, v2, v3, v4, + v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, + v34, 
v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, + v48, v49, v50); +} + +// Bool() allows generating tests with parameters in a set of (false, true). +// +// Synopsis: +// Bool() +// - returns a generator producing sequences with elements {false, true}. +// +// It is useful when testing code that depends on Boolean flags. Combinations +// of multiple flags can be tested when several Bool()'s are combined using +// the Combine() function. +// +// In the following example, all tests in the test case FlagDependentTest +// will be instantiated twice with parameters false and true. +// +// class FlagDependentTest : public testing::TestWithParam<bool> { +// virtual void SetUp() { +// external_flag = GetParam(); +// } +// }; +// INSTANTIATE_TEST_CASE_P(BoolSequence, FlagDependentTest, Bool()); +// +inline internal::ParamGenerator<bool> Bool() { + return Values(false, true); +} + +# if GTEST_HAS_COMBINE +// Combine() allows the user to combine two or more sequences to produce +// values of a Cartesian product of those sequences' elements. +// +// Synopsis: +// Combine(gen1, gen2, ..., genN) +// - returns a generator producing sequences with elements coming from +// the Cartesian product of elements from the sequences generated by +// gen1, gen2, ..., genN. The sequence elements will have a type of +// tuple<T1, T2, ..., TN> where T1, T2, ..., TN are the types +// of elements from sequences produced by gen1, gen2, ..., genN. +// +// Combine can have up to 10 arguments. This number is currently limited +// by the maximum number of elements in the tuple implementation used by Google +// Test. 
+// +// Example: +// +// This will instantiate tests in test case AnimalTest each one with +// the parameter values tuple("cat", BLACK), tuple("cat", WHITE), +// tuple("dog", BLACK), and tuple("dog", WHITE): +// +// enum Color { BLACK, GRAY, WHITE }; +// class AnimalTest +// : public testing::TestWithParam<tuple<const char*, Color> > {...}; +// +// TEST_P(AnimalTest, AnimalLooksNice) {...} +// +// INSTANTIATE_TEST_CASE_P(AnimalVariations, AnimalTest, +// Combine(Values("cat", "dog"), +// Values(BLACK, WHITE))); +// +// This will instantiate tests in FlagDependentTest with all variations of two +// Boolean flags: +// +// class FlagDependentTest +// : public testing::TestWithParam<tuple<bool, bool> > { +// virtual void SetUp() { +// // Assigns external_flag_1 and external_flag_2 values from the tuple. +// tie(external_flag_1, external_flag_2) = GetParam(); +// } +// }; +// +// TEST_P(FlagDependentTest, TestFeature1) { +// // Test your code using external_flag_1 and external_flag_2 here. +// } +// INSTANTIATE_TEST_CASE_P(TwoBoolSequence, FlagDependentTest, +// Combine(Bool(), Bool())); +// +template <typename Generator1, typename Generator2> +internal::CartesianProductHolder2<Generator1, Generator2> Combine( + const Generator1& g1, const Generator2& g2) { + return internal::CartesianProductHolder2<Generator1, Generator2>( + g1, g2); +} + +template <typename Generator1, typename Generator2, typename Generator3> +internal::CartesianProductHolder3<Generator1, Generator2, Generator3> Combine( + const Generator1& g1, const Generator2& g2, const Generator3& g3) { + return internal::CartesianProductHolder3<Generator1, Generator2, Generator3>( + g1, g2, g3); +} + +template <typename Generator1, typename Generator2, typename Generator3, + typename Generator4> +internal::CartesianProductHolder4<Generator1, Generator2, Generator3, + Generator4> Combine( + const Generator1& g1, const Generator2& g2, const Generator3& g3, + const Generator4& g4) { + return internal::CartesianProductHolder4<Generator1, Generator2, Generator3, + Generator4>( + g1, g2, g3, g4); +} + +template <typename Generator1, typename Generator2, typename Generator3, + typename Generator4, typename Generator5> +internal::CartesianProductHolder5<Generator1, Generator2, Generator3, + Generator4, Generator5> Combine( + const Generator1& g1, const Generator2& g2, const Generator3& g3, + const Generator4& g4, const Generator5& g5) { + return internal::CartesianProductHolder5<Generator1, Generator2, Generator3, + Generator4, Generator5>( + g1, g2, g3, g4, g5); +} + +template <typename Generator1, typename Generator2, typename Generator3, + typename Generator4, typename Generator5, typename Generator6> +internal::CartesianProductHolder6<Generator1, Generator2, Generator3, + Generator4, Generator5, Generator6> Combine( + const Generator1& g1, const Generator2& g2, const Generator3& g3, + const 
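Generator4& g4, const Generator5& g5, const Generator6& g6) { + return internal::CartesianProductHolder6<Generator1, Generator2, Generator3, + Generator4, Generator5, Generator6>( + g1, g2, g3, g4, g5, g6); +} + +template <typename Generator1, typename Generator2, typename Generator3, + typename Generator4, typename Generator5, typename Generator6, + typename Generator7> +internal::CartesianProductHolder7<Generator1, Generator2, Generator3, + Generator4, Generator5, Generator6, Generator7> Combine( + const Generator1& g1, const Generator2& g2, const Generator3& g3, + const Generator4& g4, const Generator5& g5, const Generator6& g6, + const Generator7& g7) { + return internal::CartesianProductHolder7<Generator1, Generator2, Generator3, + Generator4, Generator5, Generator6, Generator7>( + g1, g2, g3, g4, g5, g6, g7); +} + +template <typename Generator1, typename Generator2, typename Generator3, + typename Generator4, typename Generator5, typename Generator6, + typename Generator7, typename Generator8> +internal::CartesianProductHolder8<Generator1, Generator2, Generator3, + Generator4, Generator5, Generator6, Generator7, Generator8> Combine( + const Generator1& g1, const Generator2& g2, const Generator3& g3, + const Generator4& g4, const Generator5& g5, const Generator6& g6, + const Generator7& g7, const Generator8& g8) { + return internal::CartesianProductHolder8<Generator1, Generator2, Generator3, + Generator4, Generator5, Generator6, Generator7, Generator8>( + g1, g2, g3, g4, g5, g6, g7, g8); +} + +template <typename Generator1, typename Generator2, typename Generator3, + typename Generator4, typename Generator5, typename Generator6, + typename Generator7, typename Generator8, typename Generator9> +internal::CartesianProductHolder9<Generator1, Generator2, Generator3, + Generator4, Generator5, Generator6, Generator7, Generator8, + Generator9> Combine( + const Generator1& g1, const Generator2& g2, const Generator3& g3, + const Generator4& g4, const Generator5& g5, const Generator6& g6, + const Generator7& g7, const Generator8& g8, const Generator9& g9) { + return internal::CartesianProductHolder9<Generator1, Generator2, Generator3, + Generator4, Generator5, Generator6, Generator7, Generator8, Generator9>( + g1, g2, g3, g4, g5, g6, g7, g8, g9); +} + +template <typename Generator1, typename Generator2, typename Generator3, + typename Generator4, typename Generator5, typename Generator6, + typename Generator7, typename Generator8, typename Generator9, + typename Generator10> +internal::CartesianProductHolder10<Generator1, Generator2, Generator3, + Generator4, Generator5, Generator6, Generator7, Generator8, Generator9, + Generator10> Combine( + const Generator1& g1, const Generator2& g2, const Generator3& g3, + const Generator4& g4, const Generator5& g5, const Generator6& g6, + const Generator7& g7, const Generator8& g8, const Generator9& g9, + const Generator10& g10) { + return internal::CartesianProductHolder10<Generator1, Generator2, Generator3, + Generator4, Generator5, Generator6, Generator7, Generator8, Generator9, + Generator10>( + g1, g2, g3, g4, g5, g6, g7, g8, g9, g10); +} +# endif // GTEST_HAS_COMBINE + + + +# define TEST_P(test_case_name, test_name) \ + class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ + : public test_case_name { \ + public: \ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ + virtual void TestBody(); \ + private: \ + static int AddToRegistry() { \ + ::testing::UnitTest::GetInstance()->parameterized_test_registry(). 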
\ + GetTestCasePatternHolder<test_case_name>(\ + #test_case_name, __FILE__, __LINE__)->AddTestPattern(\ + #test_case_name, \ + #test_name, \ + new ::testing::internal::TestMetaFactory< \ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)>()); \ + return 0; \ + } \ + static int gtest_registering_dummy_; \ + GTEST_DISALLOW_COPY_AND_ASSIGN_(\ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)); \ + }; \ + int GTEST_TEST_CLASS_NAME_(test_case_name, \ + test_name)::gtest_registering_dummy_ = \ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::AddToRegistry(); \ + void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody() + +# define INSTANTIATE_TEST_CASE_P(prefix, test_case_name, generator) \ + ::testing::internal::ParamGenerator<test_case_name::ParamType> \ + gtest_##prefix##test_case_name##_EvalGenerator_() { return generator; } \ + int gtest_##prefix##test_case_name##_dummy_ = \ + ::testing::UnitTest::GetInstance()->parameterized_test_registry(). \ + GetTestCasePatternHolder<test_case_name>(\ + #test_case_name, __FILE__, __LINE__)->AddTestCaseInstantiation(\ + #prefix, \ + &gtest_##prefix##test_case_name##_EvalGenerator_, \ + __FILE__, __LINE__) + +} // namespace testing + +#endif // GTEST_HAS_PARAM_TEST + +#endif // GTEST_INCLUDE_GTEST_GTEST_PARAM_TEST_H_ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) +// +// Google C++ Testing Framework definitions useful in production code. + +#ifndef GTEST_INCLUDE_GTEST_GTEST_PROD_H_ +#define GTEST_INCLUDE_GTEST_GTEST_PROD_H_ + +// When you need to test the private or protected members of a class, +// use the FRIEND_TEST macro to declare your tests as friends of the +// class. For example: +// +// class MyClass { +// private: +// void MyMethod(); +// FRIEND_TEST(MyClassTest, MyMethod); +// }; +// +// class MyClassTest : public testing::Test { +// // ... +// }; +// +// TEST_F(MyClassTest, MyMethod) { +// // Can call MyClass::MyMethod() here. +// } + +#define FRIEND_TEST(test_case_name, test_name)\ +friend class test_case_name##_##test_name##_Test + +#endif // GTEST_INCLUDE_GTEST_GTEST_PROD_H_ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: mheule@google.com (Markus Heule) +// + +#ifndef GTEST_INCLUDE_GTEST_GTEST_TEST_PART_H_ +#define GTEST_INCLUDE_GTEST_GTEST_TEST_PART_H_ + +#include <iosfwd> +#include <vector> + +namespace testing { + +// A copyable object representing the result of a test part (i.e. an +// assertion or an explicit FAIL(), ADD_FAILURE(), or SUCCEED()). +// +// Don't inherit from TestPartResult as its destructor is not virtual. +class GTEST_API_ TestPartResult { + public: + // The possible outcomes of a test part (i.e. an assertion or an + // explicit SUCCEED(), FAIL(), or ADD_FAILURE()). + enum Type { + kSuccess, // Succeeded. + kNonFatalFailure, // Failed but the test can continue. + kFatalFailure // Failed and the test should be terminated. + }; + + // C'tor. 
TestPartResult does NOT have a default constructor. + // Always use this constructor (with parameters) to create a + // TestPartResult object. + TestPartResult(Type a_type, + const char* a_file_name, + int a_line_number, + const char* a_message) + : type_(a_type), + file_name_(a_file_name), + line_number_(a_line_number), + summary_(ExtractSummary(a_message)), + message_(a_message) { + } + + // Gets the outcome of the test part. + Type type() const { return type_; } + + // Gets the name of the source file where the test part took place, or + // NULL if it's unknown. + const char* file_name() const { return file_name_.c_str(); } + + // Gets the line in the source file where the test part took place, + // or -1 if it's unknown. + int line_number() const { return line_number_; } + + // Gets the summary of the failure message. + const char* summary() const { return summary_.c_str(); } + + // Gets the message associated with the test part. + const char* message() const { return message_.c_str(); } + + // Returns true iff the test part passed. + bool passed() const { return type_ == kSuccess; } + + // Returns true iff the test part failed. + bool failed() const { return type_ != kSuccess; } + + // Returns true iff the test part non-fatally failed. + bool nonfatally_failed() const { return type_ == kNonFatalFailure; } + + // Returns true iff the test part fatally failed. + bool fatally_failed() const { return type_ == kFatalFailure; } + private: + Type type_; + + // Gets the summary of the failure message by omitting the stack + // trace in it. + static internal::String ExtractSummary(const char* message); + + // The name of the source file where the test part took place, or + // NULL if the source file is unknown. + internal::String file_name_; + // The line in the source file where the test part took place, or -1 + // if the line number is unknown. + int line_number_; + internal::String summary_; // The test failure summary. 
+ internal::String message_; // The test failure message. +}; + +// Prints a TestPartResult object. +std::ostream& operator<<(std::ostream& os, const TestPartResult& result); + +// An array of TestPartResult objects. +// +// Don't inherit from TestPartResultArray as its destructor is not +// virtual. +class GTEST_API_ TestPartResultArray { + public: + TestPartResultArray() {} + + // Appends the given TestPartResult to the array. + void Append(const TestPartResult& result); + + // Returns the TestPartResult at the given index (0-based). + const TestPartResult& GetTestPartResult(int index) const; + + // Returns the number of TestPartResult objects in the array. + int size() const; + + private: + std::vector<TestPartResult> array_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestPartResultArray); +}; + +// This interface knows how to report a test part result. +class TestPartResultReporterInterface { + public: + virtual ~TestPartResultReporterInterface() {} + + virtual void ReportTestPartResult(const TestPartResult& result) = 0; +}; + +namespace internal { + +// This helper class is used by {ASSERT|EXPECT}_NO_FATAL_FAILURE to check if a +// statement generates new fatal failures. To do so it registers itself as the +// current test part result reporter. Besides checking if fatal failures were +// reported, it only delegates the reporting to the former result reporter. +// The original result reporter is restored in the destructor. +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. 
+class GTEST_API_ HasNewFatalFailureHelper + : public TestPartResultReporterInterface { + public: + HasNewFatalFailureHelper(); + virtual ~HasNewFatalFailureHelper(); + virtual void ReportTestPartResult(const TestPartResult& result); + bool has_new_fatal_failure() const { return has_new_fatal_failure_; } + private: + bool has_new_fatal_failure_; + TestPartResultReporterInterface* original_reporter_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(HasNewFatalFailureHelper); +}; + +} // namespace internal + +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_GTEST_TEST_PART_H_ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) + +#ifndef GTEST_INCLUDE_GTEST_GTEST_TYPED_TEST_H_ +#define GTEST_INCLUDE_GTEST_GTEST_TYPED_TEST_H_ + +// This header implements typed tests and type-parameterized tests. + +// Typed (aka type-driven) tests repeat the same test for types in a +// list. You must know which types you want to test with when writing +// typed tests. Here's how you do it: + +#if 0 + +// First, define a fixture class template. It should be parameterized +// by a type. Remember to derive it from testing::Test. +template <typename T> +class FooTest : public testing::Test { + public: + ... + typedef std::list<T> List; + static T shared_; + T value_; +}; + +// Next, associate a list of types with the test case, which will be +// repeated for each type in the list. The typedef is necessary for +// the macro to parse correctly. +typedef testing::Types<char, int, unsigned int> MyTypes; +TYPED_TEST_CASE(FooTest, MyTypes); + +// If the type list contains only one type, you can write that type +// directly without Types<...>: +// TYPED_TEST_CASE(FooTest, int); + +// Then, use TYPED_TEST() instead of TEST_F() to define as many typed +// tests for this test case as you want. +TYPED_TEST(FooTest, DoesBlah) { + // Inside a test, refer to TypeParam to get the type parameter. + // Since we are inside a derived class template, C++ requires us to + // visit the members of FooTest via 'this'. 
+ TypeParam n = this->value_; + + // To visit static members of the fixture, add the TestFixture:: + // prefix. + n += TestFixture::shared_; + + // To refer to typedefs in the fixture, add the "typename + // TestFixture::" prefix. + typename TestFixture::List values; + values.push_back(n); + ... +} + +TYPED_TEST(FooTest, HasPropertyA) { ... } + +#endif // 0 + +// Type-parameterized tests are abstract test patterns parameterized +// by a type. Compared with typed tests, type-parameterized tests +// allow you to define the test pattern without knowing what the type +// parameters are. The defined pattern can be instantiated with +// different types any number of times, in any number of translation +// units. +// +// If you are designing an interface or concept, you can define a +// suite of type-parameterized tests to verify properties that any +// valid implementation of the interface/concept should have. Then, +// each implementation can easily instantiate the test suite to verify +// that it conforms to the requirements, without having to write +// similar tests repeatedly. Here's an example: + +#if 0 + +// First, define a fixture class template. It should be parameterized +// by a type. Remember to derive it from testing::Test. +template <typename T> +class FooTest : public testing::Test { + ... +}; + +// Next, declare that you will define a type-parameterized test case +// (the _P suffix is for "parameterized" or "pattern", whichever you +// prefer): +TYPED_TEST_CASE_P(FooTest); + +// Then, use TYPED_TEST_P() to define as many type-parameterized tests +// for this type-parameterized test case as you want. +TYPED_TEST_P(FooTest, DoesBlah) { + // Inside a test, refer to TypeParam to get the type parameter. + TypeParam n = 0; + ... +} + +TYPED_TEST_P(FooTest, HasPropertyA) { ... } + +// Now the tricky part: you need to register all test patterns before +// you can instantiate them. 
The first argument of the macro is the +// test case name; the rest are the names of the tests in this test +// case. +REGISTER_TYPED_TEST_CASE_P(FooTest, + DoesBlah, HasPropertyA); + +// Finally, you are free to instantiate the pattern with the types you +// want. If you put the above code in a header file, you can #include +// it in multiple C++ source files and instantiate it multiple times. +// +// To distinguish different instances of the pattern, the first +// argument to the INSTANTIATE_* macro is a prefix that will be added +// to the actual test case name. Remember to pick unique prefixes for +// different instances. +typedef testing::Types<char, int, unsigned int> MyTypes; +INSTANTIATE_TYPED_TEST_CASE_P(My, FooTest, MyTypes); + +// If the type list contains only one type, you can write that type +// directly without Types<...>: +// INSTANTIATE_TYPED_TEST_CASE_P(My, FooTest, int); + +#endif // 0 + + +// Implements typed tests. + +#if GTEST_HAS_TYPED_TEST + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Expands to the name of the typedef for the type parameters of the +// given test case. +# define GTEST_TYPE_PARAMS_(TestCaseName) gtest_type_params_##TestCaseName##_ + +// The 'Types' template argument below must have spaces around it +// since some compilers may choke on '>>' when passing a template +// instance (e.g. 
Types<int>) +# define TYPED_TEST_CASE(CaseName, Types) \ + typedef ::testing::internal::TypeList< Types >::type \ + GTEST_TYPE_PARAMS_(CaseName) + +# define TYPED_TEST(CaseName, TestName) \ + template <typename gtest_TypeParam_> \ + class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \ + : public CaseName<gtest_TypeParam_> { \ + private: \ + typedef CaseName<gtest_TypeParam_> TestFixture; \ + typedef gtest_TypeParam_ TypeParam; \ + virtual void TestBody(); \ + }; \ + bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \ + ::testing::internal::TypeParameterizedTest< \ + CaseName, \ + ::testing::internal::TemplateSel< \ + GTEST_TEST_CLASS_NAME_(CaseName, TestName)>, \ + GTEST_TYPE_PARAMS_(CaseName)>::Register(\ + "", #CaseName, #TestName, 0); \ + template <typename gtest_TypeParam_> \ + void GTEST_TEST_CLASS_NAME_(CaseName, TestName)<gtest_TypeParam_>::TestBody() + +#endif // GTEST_HAS_TYPED_TEST + +// Implements type-parameterized tests. + +#if GTEST_HAS_TYPED_TEST_P + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Expands to the namespace name that the type-parameterized tests for +// the given type-parameterized test case are defined in. The exact +// name of the namespace is subject to change without notice. +# define GTEST_CASE_NAMESPACE_(TestCaseName) \ + gtest_case_##TestCaseName##_ + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Expands to the name of the variable used to remember the names of +// the defined tests in the given test case. +# define GTEST_TYPED_TEST_CASE_P_STATE_(TestCaseName) \ + gtest_typed_test_case_p_state_##TestCaseName##_ + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE DIRECTLY. +// +// Expands to the name of the variable used to remember the names of +// the registered tests in the given test case. 
+# define GTEST_REGISTERED_TEST_NAMES_(TestCaseName) \ + gtest_registered_test_names_##TestCaseName##_ + +// The variables defined in the type-parameterized test macros are +// static as typically these macros are used in a .h file that can be +// #included in multiple translation units linked together. +# define TYPED_TEST_CASE_P(CaseName) \ + static ::testing::internal::TypedTestCasePState \ + GTEST_TYPED_TEST_CASE_P_STATE_(CaseName) + +# define TYPED_TEST_P(CaseName, TestName) \ + namespace GTEST_CASE_NAMESPACE_(CaseName) { \ + template <typename gtest_TypeParam_> \ + class TestName : public CaseName<gtest_TypeParam_> { \ + private: \ + typedef CaseName<gtest_TypeParam_> TestFixture; \ + typedef gtest_TypeParam_ TypeParam; \ + virtual void TestBody(); \ + }; \ + static bool gtest_##TestName##_defined_ GTEST_ATTRIBUTE_UNUSED_ = \ + GTEST_TYPED_TEST_CASE_P_STATE_(CaseName).AddTestName(\ + __FILE__, __LINE__, #CaseName, #TestName); \ + } \ + template <typename gtest_TypeParam_> \ + void GTEST_CASE_NAMESPACE_(CaseName)::TestName<gtest_TypeParam_>::TestBody() + +# define REGISTER_TYPED_TEST_CASE_P(CaseName, ...) \ + namespace GTEST_CASE_NAMESPACE_(CaseName) { \ + typedef ::testing::internal::Templates<__VA_ARGS__>::type gtest_AllTests_; \ + } \ + static const char* const GTEST_REGISTERED_TEST_NAMES_(CaseName) = \ + GTEST_TYPED_TEST_CASE_P_STATE_(CaseName).VerifyRegisteredTestNames(\ + __FILE__, __LINE__, #__VA_ARGS__) + +// The 'Types' template argument below must have spaces around it +// since some compilers may choke on '>>' when passing a template +// instance (e.g. Types<int>) +# define INSTANTIATE_TYPED_TEST_CASE_P(Prefix, CaseName, Types) \ + bool gtest_##Prefix##_##CaseName GTEST_ATTRIBUTE_UNUSED_ = \ + ::testing::internal::TypeParameterizedTestCase<CaseName, \ + GTEST_CASE_NAMESPACE_(CaseName)::gtest_AllTests_, \ + ::testing::internal::TypeList< Types >::type>::Register(\ + #Prefix, #CaseName, GTEST_REGISTERED_TEST_NAMES_(CaseName)) + +#endif // GTEST_HAS_TYPED_TEST_P + +#endif // GTEST_INCLUDE_GTEST_GTEST_TYPED_TEST_H_ + +// Depending on the platform, different string classes are available. 
+// On Linux, in addition to ::std::string, Google also makes use of +// class ::string, which has the same interface as ::std::string, but +// has a different implementation. +// +// The user can define GTEST_HAS_GLOBAL_STRING to 1 to indicate that +// ::string is available AND is a distinct type to ::std::string, or +// define it to 0 to indicate otherwise. +// +// If the user's ::std::string and ::string are the same class due to +// aliasing, they should define GTEST_HAS_GLOBAL_STRING to 0. +// +// If the user doesn't define GTEST_HAS_GLOBAL_STRING, it is defined +// heuristically. + +namespace testing { + +// Declares the flags. + +// This flag temporarily enables the disabled tests. +GTEST_DECLARE_bool_(also_run_disabled_tests); + +// This flag brings up the debugger on an assertion failure. +GTEST_DECLARE_bool_(break_on_failure); + +// This flag controls whether Google Test catches all test-thrown exceptions +// and logs them as failures. +GTEST_DECLARE_bool_(catch_exceptions); + +// This flag enables using colors in terminal output. Available values are +// "yes" to enable colors, "no" (disable colors), or "auto" (the default) +// to let Google Test decide. +GTEST_DECLARE_string_(color); + +// This flag sets up the filter to select by name using a glob pattern +// the tests to run. If the filter is not given all tests are executed. +GTEST_DECLARE_string_(filter); + +// This flag causes Google Test to list tests. None of the tests listed +// are actually run if the flag is provided. +GTEST_DECLARE_bool_(list_tests); + +// This flag controls whether Google Test emits a detailed XML report to a file +// in addition to its normal textual output. +GTEST_DECLARE_string_(output); + +// This flag controls whether Google Test prints the elapsed time for each +// test. +GTEST_DECLARE_bool_(print_time); + +// This flag specifies the random number seed. +GTEST_DECLARE_int32_(random_seed); + +// This flag sets how many times the tests are repeated. 
The default value +// is 1. If the value is -1 the tests are repeating forever. +GTEST_DECLARE_int32_(repeat); + +// This flag controls whether Google Test includes Google Test internal +// stack frames in failure stack traces. +GTEST_DECLARE_bool_(show_internal_stack_frames); + +// When this flag is specified, tests' order is randomized on every iteration. +GTEST_DECLARE_bool_(shuffle); + +// This flag specifies the maximum number of stack frames to be +// printed in a failure message. +GTEST_DECLARE_int32_(stack_trace_depth); + +// When this flag is specified, a failed assertion will throw an +// exception if exceptions are enabled, or exit the program with a +// non-zero code otherwise. +GTEST_DECLARE_bool_(throw_on_failure); + +// When this flag is set with a "host:port" string, on supported +// platforms test results are streamed to the specified port on +// the specified host machine. +GTEST_DECLARE_string_(stream_result_to); + +// The upper limit for valid stack trace depths. +const int kMaxStackTraceDepth = 100; + +namespace internal { + +class AssertHelper; +class DefaultGlobalTestPartResultReporter; +class ExecDeathTest; +class NoExecDeathTest; +class FinalSuccessChecker; +class GTestFlagSaver; +class TestResultAccessor; +class TestEventListenersAccessor; +class TestEventRepeater; +class WindowsDeathTest; +class UnitTestImpl* GetUnitTestImpl(); +void ReportFailureInUnknownLocation(TestPartResult::Type result_type, + const String& message); + +// Converts a streamable value to a String. A NULL pointer is +// converted to "(null)". When the input value is a ::string, +// ::std::string, ::wstring, or ::std::wstring object, each NUL +// character in it is replaced with "\\0". +// Declared in gtest-internal.h but defined here, so that it has access +// to the definition of the Message class, required by the ARM +// compiler. 
+template <typename T> +String StreamableToString(const T& streamable) { + return (Message() << streamable).GetString(); +} + +} // namespace internal + +// The friend relationship of some of these classes is cyclic. +// If we don't forward declare them the compiler might confuse the classes +// in friendship clauses with same named classes on the scope. +class Test; +class TestCase; +class TestInfo; +class UnitTest; + +// A class for indicating whether an assertion was successful. When +// the assertion wasn't successful, the AssertionResult object +// remembers a non-empty message that describes how it failed. +// +// To create an instance of this class, use one of the factory functions +// (AssertionSuccess() and AssertionFailure()). +// +// This class is useful for two purposes: +// 1. Defining predicate functions to be used with Boolean test assertions +// EXPECT_TRUE/EXPECT_FALSE and their ASSERT_ counterparts +// 2. Defining predicate-format functions to be +// used with predicate assertions (ASSERT_PRED_FORMAT*, etc). +// +// For example, if you define IsEven predicate: +// +// testing::AssertionResult IsEven(int n) { +// if ((n % 2) == 0) +// return testing::AssertionSuccess(); +// else +// return testing::AssertionFailure() << n << " is odd"; +// } +// +// Then the failed expectation EXPECT_TRUE(IsEven(Fib(5))) +// will print the message +// +// Value of: IsEven(Fib(5)) +// Actual: false (5 is odd) +// Expected: true +// +// instead of a more opaque +// +// Value of: IsEven(Fib(5)) +// Actual: false +// Expected: true +// +// in case IsEven is a simple Boolean predicate. 
+// +// If you expect your predicate to be reused and want to support informative +// messages in EXPECT_FALSE and ASSERT_FALSE (negative assertions show up +// about half as often as positive ones in our tests), supply messages for +// both success and failure cases: +// +// testing::AssertionResult IsEven(int n) { +// if ((n % 2) == 0) +// return testing::AssertionSuccess() << n << " is even"; +// else +// return testing::AssertionFailure() << n << " is odd"; +// } +// +// Then a statement EXPECT_FALSE(IsEven(Fib(6))) will print +// +// Value of: IsEven(Fib(6)) +// Actual: true (8 is even) +// Expected: false +// +// NB: Predicates that support negative Boolean assertions have reduced +// performance in positive ones so be careful not to use them in tests +// that have lots (tens of thousands) of positive Boolean assertions. +// +// To use this class with EXPECT_PRED_FORMAT assertions such as: +// +// // Verifies that Foo() returns an even number. +// EXPECT_PRED_FORMAT1(IsEven, Foo()); +// +// you need to define: +// +// testing::AssertionResult IsEven(const char* expr, int n) { +// if ((n % 2) == 0) +// return testing::AssertionSuccess(); +// else +// return testing::AssertionFailure() +// << "Expected: " << expr << " is even\n Actual: it's " << n; +// } +// +// If Foo() returns 5, you will see the following message: +// +// Expected: Foo() is even +// Actual: it's 5 +// +class GTEST_API_ AssertionResult { + public: + // Copy constructor. + // Used in EXPECT_TRUE/FALSE(assertion_result). + AssertionResult(const AssertionResult& other); + // Used in the EXPECT_TRUE/FALSE(bool_expression). + explicit AssertionResult(bool success) : success_(success) {} + + // Returns true iff the assertion succeeded. + operator bool() const { return success_; } // NOLINT + + // Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. + AssertionResult operator!() const; + + // Returns the text streamed into this AssertionResult. 
Test assertions + // use it when they fail (i.e., the predicate's outcome doesn't match the + // assertion's expectation). When nothing has been streamed into the + // object, returns an empty string. + const char* message() const { + return message_.get() != NULL ? message_->c_str() : ""; + } + // TODO(vladl@google.com): Remove this after making sure no clients use it. + // Deprecated; please use message() instead. + const char* failure_message() const { return message(); } + + // Streams a custom failure message into this object. + template <typename T> AssertionResult& operator<<(const T& value) { + AppendMessage(Message() << value); + return *this; + } + + // Allows streaming basic output manipulators such as endl or flush into + // this object. + AssertionResult& operator<<( + ::std::ostream& (*basic_manipulator)(::std::ostream& stream)) { + AppendMessage(Message() << basic_manipulator); + return *this; + } + + private: + // Appends the contents of message to message_. + void AppendMessage(const Message& a_message) { + if (message_.get() == NULL) + message_.reset(new ::std::string); + message_->append(a_message.GetString().c_str()); + } + + // Stores result of the assertion predicate. + bool success_; + // Stores the message describing the condition in case the expectation + // construct is not satisfied with the predicate's outcome. + // Referenced via a pointer to avoid taking too much stack frame space + // with test assertions. + internal::scoped_ptr< ::std::string> message_; + + GTEST_DISALLOW_ASSIGN_(AssertionResult); +}; + +// Makes a successful assertion result. +GTEST_API_ AssertionResult AssertionSuccess(); + +// Makes a failed assertion result. +GTEST_API_ AssertionResult AssertionFailure(); + +// Makes a failed assertion result with the given failure message. +// Deprecated; use AssertionFailure() << msg. +GTEST_API_ AssertionResult AssertionFailure(const Message& msg); + +// The abstract class that all tests inherit from. 
+// +// In Google Test, a unit test program contains one or many TestCases, and +// each TestCase contains one or many Tests. +// +// When you define a test using the TEST macro, you don't need to +// explicitly derive from Test - the TEST macro automatically does +// this for you. +// +// The only time you derive from Test is when defining a test fixture +// to be used in a TEST_F. For example: +// +// class FooTest : public testing::Test { +// protected: +// virtual void SetUp() { ... } +// virtual void TearDown() { ... } +// ... +// }; +// +// TEST_F(FooTest, Bar) { ... } +// TEST_F(FooTest, Baz) { ... } +// +// Test is not copyable. +class GTEST_API_ Test { + public: + friend class TestInfo; + + // Defines types for pointers to functions that set up and tear down + // a test case. + typedef internal::SetUpTestCaseFunc SetUpTestCaseFunc; + typedef internal::TearDownTestCaseFunc TearDownTestCaseFunc; + + // The d'tor is virtual as we intend to inherit from Test. + virtual ~Test(); + + // Sets up the stuff shared by all tests in this test case. + // + // Google Test will call Foo::SetUpTestCase() before running the first + // test in test case Foo. Hence a sub-class can define its own + // SetUpTestCase() method to shadow the one defined in the super + // class. + static void SetUpTestCase() {} + + // Tears down the stuff shared by all tests in this test case. + // + // Google Test will call Foo::TearDownTestCase() after running the last + // test in test case Foo. Hence a sub-class can define its own + // TearDownTestCase() method to shadow the one defined in the super + // class. + static void TearDownTestCase() {} + + // Returns true iff the current test has a fatal failure. + static bool HasFatalFailure(); + + // Returns true iff the current test has a non-fatal failure. + static bool HasNonfatalFailure(); + + // Returns true iff the current test has a (either fatal or + // non-fatal) failure. 
+ static bool HasFailure() { return HasFatalFailure() || HasNonfatalFailure(); } + + // Logs a property for the current test. Only the last value for a given + // key is remembered. + // These are public static so they can be called from utility functions + // that are not members of the test fixture. + // The arguments are const char* instead of strings, as Google Test is used + // on platforms where string doesn't compile. + // + // Note that a driving consideration for these RecordProperty methods + // was to produce xml output suited to the Greenspan charting utility, + // which at present will only chart values that fit in a 32-bit int. It + // is the user's responsibility to restrict their values to 32-bit ints + // if they intend them to be used with Greenspan. + static void RecordProperty(const char* key, const char* value); + static void RecordProperty(const char* key, int value); + + protected: + // Creates a Test object. + Test(); + + // Sets up the test fixture. + virtual void SetUp(); + + // Tears down the test fixture. + virtual void TearDown(); + + private: + // Returns true iff the current test has the same fixture class as + // the first test in the current test case. + static bool HasSameFixtureClass(); + + // Runs the test after the test fixture has been set up. + // + // A sub-class must implement this to define the test logic. + // + // DO NOT OVERRIDE THIS FUNCTION DIRECTLY IN A USER PROGRAM. + // Instead, use the TEST or TEST_F macro. + virtual void TestBody() = 0; + + // Sets up, executes, and tears down the test. + void Run(); + + // Deletes self. We deliberately pick an unusual name for this + // internal method to avoid clashing with names used in user TESTs. + void DeleteSelf_() { delete this; } + + // Uses a GTestFlagSaver to save and restore all Google Test flags. 
+ const internal::GTestFlagSaver* const gtest_flag_saver_; + + // Often a user mis-spells SetUp() as Setup() and spends a long time + // wondering why it is never called by Google Test. The declaration of + // the following method is solely for catching such an error at + // compile time: + // + // - The return type is deliberately chosen to be not void, so it + // will be a conflict if a user declares void Setup() in his test + // fixture. + // + // - This method is private, so it will be another compiler error + // if a user calls it from his test fixture. + // + // DO NOT OVERRIDE THIS FUNCTION. + // + // If you see an error about overriding the following function or + // about it being private, you have mis-spelled SetUp() as Setup(). + struct Setup_should_be_spelled_SetUp {}; + virtual Setup_should_be_spelled_SetUp* Setup() { return NULL; } + + // We disallow copying Tests. + GTEST_DISALLOW_COPY_AND_ASSIGN_(Test); +}; + +typedef internal::TimeInMillis TimeInMillis; + +// A copyable object representing a user specified test property which can be +// output as a key/value string pair. +// +// Don't inherit from TestProperty as its destructor is not virtual. +class TestProperty { + public: + // C'tor. TestProperty does NOT have a default constructor. + // Always use this constructor (with parameters) to create a + // TestProperty object. + TestProperty(const char* a_key, const char* a_value) : + key_(a_key), value_(a_value) { + } + + // Gets the user supplied key. + const char* key() const { + return key_.c_str(); + } + + // Gets the user supplied value. + const char* value() const { + return value_.c_str(); + } + + // Sets a new value, overriding the one supplied in the constructor. + void SetValue(const char* new_value) { + value_ = new_value; + } + + private: + // The key supplied by the user. + internal::String key_; + // The value supplied by the user. + internal::String value_; +}; + +// The result of a single Test. 
This includes a list of +// TestPartResults, a list of TestProperties, a count of how many +// death tests there are in the Test, and how much time it took to run +// the Test. +// +// TestResult is not copyable. +class GTEST_API_ TestResult { + public: + // Creates an empty TestResult. + TestResult(); + + // D'tor. Do not inherit from TestResult. + ~TestResult(); + + // Gets the number of all test parts. This is the sum of the number + // of successful test parts and the number of failed test parts. + int total_part_count() const; + + // Returns the number of the test properties. + int test_property_count() const; + + // Returns true iff the test passed (i.e. no test part failed). + bool Passed() const { return !Failed(); } + + // Returns true iff the test failed. + bool Failed() const; + + // Returns true iff the test fatally failed. + bool HasFatalFailure() const; + + // Returns true iff the test has a non-fatal failure. + bool HasNonfatalFailure() const; + + // Returns the elapsed time, in milliseconds. + TimeInMillis elapsed_time() const { return elapsed_time_; } + + // Returns the i-th test part result among all the results. i can range + // from 0 to total_part_count() - 1. If i is not in that range, aborts + // the program. + const TestPartResult& GetTestPartResult(int i) const; + + // Returns the i-th test property. i can range from 0 to + // test_property_count() - 1. If i is not in that range, aborts the + // program. + const TestProperty& GetTestProperty(int i) const; + + private: + friend class TestInfo; + friend class UnitTest; + friend class internal::DefaultGlobalTestPartResultReporter; + friend class internal::ExecDeathTest; + friend class internal::TestResultAccessor; + friend class internal::UnitTestImpl; + friend class internal::WindowsDeathTest; + + // Gets the vector of TestPartResults. + const std::vector<TestPartResult>& test_part_results() const { + return test_part_results_; + } + + // Gets the vector of TestProperties.
+ const std::vector<TestProperty>& test_properties() const { + return test_properties_; + } + + // Sets the elapsed time. + void set_elapsed_time(TimeInMillis elapsed) { elapsed_time_ = elapsed; } + + // Adds a test property to the list. The property is validated and may add + // a non-fatal failure if invalid (e.g., if it conflicts with reserved + // key names). If a property is already recorded for the same key, the + // value will be updated, rather than storing multiple values for the same + // key. + void RecordProperty(const TestProperty& test_property); + + // Adds a failure if the key is a reserved attribute of Google Test + // testcase tags. Returns true if the property is valid. + // TODO(russr): Validate attribute names are legal and human readable. + static bool ValidateTestProperty(const TestProperty& test_property); + + // Adds a test part result to the list. + void AddTestPartResult(const TestPartResult& test_part_result); + + // Returns the death test count. + int death_test_count() const { return death_test_count_; } + + // Increments the death test count, returning the new count. + int increment_death_test_count() { return ++death_test_count_; } + + // Clears the test part results. + void ClearTestPartResults(); + + // Clears the object. + void Clear(); + + // Protects mutable state of the property vector and of owned + // properties, whose values may be updated. + internal::Mutex test_properites_mutex_; + + // The vector of TestPartResults + std::vector<TestPartResult> test_part_results_; + // The vector of TestProperties + std::vector<TestProperty> test_properties_; + // Running count of death tests. + int death_test_count_; + // The elapsed time, in milliseconds. + TimeInMillis elapsed_time_; + + // We disallow copying TestResult.
+ GTEST_DISALLOW_COPY_AND_ASSIGN_(TestResult); +}; // class TestResult + +// A TestInfo object stores the following information about a test: +// +// Test case name +// Test name +// Whether the test should be run +// A function pointer that creates the test object when invoked +// Test result +// +// The constructor of TestInfo registers itself with the UnitTest +// singleton such that the RUN_ALL_TESTS() macro knows which tests to +// run. +class GTEST_API_ TestInfo { + public: + // Destructs a TestInfo object. This function is not virtual, so + // don't inherit from TestInfo. + ~TestInfo(); + + // Returns the test case name. + const char* test_case_name() const { return test_case_name_.c_str(); } + + // Returns the test name. + const char* name() const { return name_.c_str(); } + + // Returns the name of the parameter type, or NULL if this is not a typed + // or a type-parameterized test. + const char* type_param() const { + if (type_param_.get() != NULL) + return type_param_->c_str(); + return NULL; + } + + // Returns the text representation of the value parameter, or NULL if this + // is not a value-parameterized test. + const char* value_param() const { + if (value_param_.get() != NULL) + return value_param_->c_str(); + return NULL; + } + + // Returns true if this test should run, that is if the test is not disabled + // (or it is disabled but the also_run_disabled_tests flag has been specified) + // and its full name matches the user-specified filter. + // + // Google Test allows the user to filter the tests by their full names. + // The full name of a test Bar in test case Foo is defined as + // "Foo.Bar". Only the tests that match the filter will run. + // + // A filter is a colon-separated list of glob (not regex) patterns, + // optionally followed by a '-' and a colon-separated list of + // negative patterns (tests to exclude). A test is run if it + // matches one of the positive patterns and does not match any of + // the negative patterns. 
+ // + // For example, *A*:Foo.* is a filter that matches any string that + // contains the character 'A' or starts with "Foo.". + bool should_run() const { return should_run_; } + + // Returns the result of the test. + const TestResult* result() const { return &result_; } + + private: + +#if GTEST_HAS_DEATH_TEST + friend class internal::DefaultDeathTestFactory; +#endif // GTEST_HAS_DEATH_TEST + friend class Test; + friend class TestCase; + friend class internal::UnitTestImpl; + friend TestInfo* internal::MakeAndRegisterTestInfo( + const char* test_case_name, const char* name, + const char* type_param, + const char* value_param, + internal::TypeId fixture_class_id, + Test::SetUpTestCaseFunc set_up_tc, + Test::TearDownTestCaseFunc tear_down_tc, + internal::TestFactoryBase* factory); + + // Constructs a TestInfo object. The newly constructed instance assumes + // ownership of the factory object. + TestInfo(const char* test_case_name, const char* name, + const char* a_type_param, + const char* a_value_param, + internal::TypeId fixture_class_id, + internal::TestFactoryBase* factory); + + // Increments the number of death tests encountered in this test so + // far. + int increment_death_test_count() { + return result_.increment_death_test_count(); + } + + // Creates the test object, runs it, records its result, and then + // deletes it. + void Run(); + + static void ClearTestResult(TestInfo* test_info) { + test_info->result_.Clear(); + } + + // These fields are immutable properties of the test. + const std::string test_case_name_; // Test case name + const std::string name_; // Test name + // Name of the parameter type, or NULL if this is not a typed or a + // type-parameterized test. + const internal::scoped_ptr<const ::std::string> type_param_; + // Text representation of the value parameter, or NULL if this is not a + // value-parameterized test.
+ const internal::scoped_ptr<const ::std::string> value_param_; + const internal::TypeId fixture_class_id_; // ID of the test fixture class + bool should_run_; // True iff this test should run + bool is_disabled_; // True iff this test is disabled + bool matches_filter_; // True if this test matches the + // user-specified filter. + internal::TestFactoryBase* const factory_; // The factory that creates + // the test object + + // This field is mutable and needs to be reset before running the + // test for the second time. + TestResult result_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestInfo); +}; + +// A test case, which consists of a vector of TestInfos. +// +// TestCase is not copyable. +class GTEST_API_ TestCase { + public: + // Creates a TestCase with the given name. + // + // TestCase does NOT have a default constructor. Always use this + // constructor to create a TestCase object. + // + // Arguments: + // + // name: name of the test case + // a_type_param: the name of the test's type parameter, or NULL if + // this is not a type-parameterized test. + // set_up_tc: pointer to the function that sets up the test case + // tear_down_tc: pointer to the function that tears down the test case + TestCase(const char* name, const char* a_type_param, + Test::SetUpTestCaseFunc set_up_tc, + Test::TearDownTestCaseFunc tear_down_tc); + + // Destructor of TestCase. + virtual ~TestCase(); + + // Gets the name of the TestCase. + const char* name() const { return name_.c_str(); } + + // Returns the name of the parameter type, or NULL if this is not a + // type-parameterized test case. + const char* type_param() const { + if (type_param_.get() != NULL) + return type_param_->c_str(); + return NULL; + } + + // Returns true if any test in this test case should run. + bool should_run() const { return should_run_; } + + // Gets the number of successful tests in this test case. + int successful_test_count() const; + + // Gets the number of failed tests in this test case.
+ int failed_test_count() const; + + // Gets the number of disabled tests in this test case. + int disabled_test_count() const; + + // Get the number of tests in this test case that should run. + int test_to_run_count() const; + + // Gets the number of all tests in this test case. + int total_test_count() const; + + // Returns true iff the test case passed. + bool Passed() const { return !Failed(); } + + // Returns true iff the test case failed. + bool Failed() const { return failed_test_count() > 0; } + + // Returns the elapsed time, in milliseconds. + TimeInMillis elapsed_time() const { return elapsed_time_; } + + // Returns the i-th test among all the tests. i can range from 0 to + // total_test_count() - 1. If i is not in that range, returns NULL. + const TestInfo* GetTestInfo(int i) const; + + private: + friend class Test; + friend class internal::UnitTestImpl; + + // Gets the (mutable) vector of TestInfos in this TestCase. + std::vector<TestInfo*>& test_info_list() { return test_info_list_; } + + // Gets the (immutable) vector of TestInfos in this TestCase. + const std::vector<TestInfo*>& test_info_list() const { + return test_info_list_; + } + + // Returns the i-th test among all the tests. i can range from 0 to + // total_test_count() - 1. If i is not in that range, returns NULL. + TestInfo* GetMutableTestInfo(int i); + + // Sets the should_run member. + void set_should_run(bool should) { should_run_ = should; } + + // Adds a TestInfo to this test case. Will delete the TestInfo upon + // destruction of the TestCase object. + void AddTestInfo(TestInfo * test_info); + + // Clears the results of all tests in this test case. + void ClearResult(); + + // Clears the results of all tests in the given test case. + static void ClearTestCaseResult(TestCase* test_case) { + test_case->ClearResult(); + } + + // Runs every test in this TestCase. + void Run(); + + // Runs SetUpTestCase() for this TestCase. This wrapper is needed + // for catching exceptions thrown from SetUpTestCase().
+ void RunSetUpTestCase() { (*set_up_tc_)(); } + + // Runs TearDownTestCase() for this TestCase. This wrapper is + // needed for catching exceptions thrown from TearDownTestCase(). + void RunTearDownTestCase() { (*tear_down_tc_)(); } + + // Returns true iff test passed. + static bool TestPassed(const TestInfo* test_info) { + return test_info->should_run() && test_info->result()->Passed(); + } + + // Returns true iff test failed. + static bool TestFailed(const TestInfo* test_info) { + return test_info->should_run() && test_info->result()->Failed(); + } + + // Returns true iff test is disabled. + static bool TestDisabled(const TestInfo* test_info) { + return test_info->is_disabled_; + } + + // Returns true if the given test should run. + static bool ShouldRunTest(const TestInfo* test_info) { + return test_info->should_run(); + } + + // Shuffles the tests in this test case. + void ShuffleTests(internal::Random* random); + + // Restores the test order to before the first shuffle. + void UnshuffleTests(); + + // Name of the test case. + internal::String name_; + // Name of the parameter type, or NULL if this is not a typed or a + // type-parameterized test. + const internal::scoped_ptr<const ::std::string> type_param_; + // The vector of TestInfos in their original order. It owns the + // elements in the vector. + std::vector<TestInfo*> test_info_list_; + // Provides a level of indirection for the test list to allow easy + // shuffling and restoring the test order. The i-th element in this + // vector is the index of the i-th test in the shuffled test list. + std::vector<int> test_indices_; + // Pointer to the function that sets up the test case. + Test::SetUpTestCaseFunc set_up_tc_; + // Pointer to the function that tears down the test case. + Test::TearDownTestCaseFunc tear_down_tc_; + // True iff any test in this test case should run. + bool should_run_; + // Elapsed time, in milliseconds. + TimeInMillis elapsed_time_; + + // We disallow copying TestCases.
+ GTEST_DISALLOW_COPY_AND_ASSIGN_(TestCase); +}; + +// An Environment object is capable of setting up and tearing down an +// environment. The user should subclass this to define his own +// environment(s). +// +// An Environment object does the set-up and tear-down in virtual +// methods SetUp() and TearDown() instead of the constructor and the +// destructor, as: +// +// 1. You cannot safely throw from a destructor. This is a problem +// as in some cases Google Test is used where exceptions are enabled, and +// we may want to implement ASSERT_* using exceptions where they are +// available. +// 2. You cannot use ASSERT_* directly in a constructor or +// destructor. +class Environment { + public: + // The d'tor is virtual as we need to subclass Environment. + virtual ~Environment() {} + + // Override this to define how to set up the environment. + virtual void SetUp() {} + + // Override this to define how to tear down the environment. + virtual void TearDown() {} + private: + // If you see an error about overriding the following function or + // about it being private, you have mis-spelled SetUp() as Setup(). + struct Setup_should_be_spelled_SetUp {}; + virtual Setup_should_be_spelled_SetUp* Setup() { return NULL; } +}; + +// The interface for tracing execution of tests. The methods are organized in +// the order the corresponding events are fired. +class TestEventListener { + public: + virtual ~TestEventListener() {} + + // Fired before any test activity starts. + virtual void OnTestProgramStart(const UnitTest& unit_test) = 0; + + // Fired before each iteration of tests starts. There may be more than + // one iteration if GTEST_FLAG(repeat) is set. iteration is the iteration + // index, starting from 0. + virtual void OnTestIterationStart(const UnitTest& unit_test, + int iteration) = 0; + + // Fired before environment set-up for each iteration of tests starts. 
+ virtual void OnEnvironmentsSetUpStart(const UnitTest& unit_test) = 0; + + // Fired after environment set-up for each iteration of tests ends. + virtual void OnEnvironmentsSetUpEnd(const UnitTest& unit_test) = 0; + + // Fired before the test case starts. + virtual void OnTestCaseStart(const TestCase& test_case) = 0; + + // Fired before the test starts. + virtual void OnTestStart(const TestInfo& test_info) = 0; + + // Fired after a failed assertion or a SUCCEED() invocation. + virtual void OnTestPartResult(const TestPartResult& test_part_result) = 0; + + // Fired after the test ends. + virtual void OnTestEnd(const TestInfo& test_info) = 0; + + // Fired after the test case ends. + virtual void OnTestCaseEnd(const TestCase& test_case) = 0; + + // Fired before environment tear-down for each iteration of tests starts. + virtual void OnEnvironmentsTearDownStart(const UnitTest& unit_test) = 0; + + // Fired after environment tear-down for each iteration of tests ends. + virtual void OnEnvironmentsTearDownEnd(const UnitTest& unit_test) = 0; + + // Fired after each iteration of tests finishes. + virtual void OnTestIterationEnd(const UnitTest& unit_test, + int iteration) = 0; + + // Fired after all test activities have ended. + virtual void OnTestProgramEnd(const UnitTest& unit_test) = 0; +}; + +// The convenience class for users who need to override just one or two +// methods and are not concerned that a possible change to a signature of +// the methods they override will not be caught during the build. For +// comments about each method please see the definition of TestEventListener +// above. 
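+// +// A minimal sketch of such a listener, for illustration only. The class +// name TersePrinter is hypothetical, and the snippet assumes <cstdio> has +// been included: +// +// class TersePrinter : public testing::EmptyTestEventListener { +// public: +// // Prints one line per finished test; every other event is ignored. +// virtual void OnTestEnd(const testing::TestInfo& test_info) { +// printf("%s.%s %s\n", test_info.test_case_name(), test_info.name(), +// test_info.result()->Passed() ? "passed" : "failed"); +// } +// }; +// +// // In main(), before RUN_ALL_TESTS() is called. Google Test takes +// // ownership of the appended listener: +// testing::UnitTest::GetInstance()->listeners().Append(new TersePrinter);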
+class EmptyTestEventListener : public TestEventListener { + public: + virtual void OnTestProgramStart(const UnitTest& /*unit_test*/) {} + virtual void OnTestIterationStart(const UnitTest& /*unit_test*/, + int /*iteration*/) {} + virtual void OnEnvironmentsSetUpStart(const UnitTest& /*unit_test*/) {} + virtual void OnEnvironmentsSetUpEnd(const UnitTest& /*unit_test*/) {} + virtual void OnTestCaseStart(const TestCase& /*test_case*/) {} + virtual void OnTestStart(const TestInfo& /*test_info*/) {} + virtual void OnTestPartResult(const TestPartResult& /*test_part_result*/) {} + virtual void OnTestEnd(const TestInfo& /*test_info*/) {} + virtual void OnTestCaseEnd(const TestCase& /*test_case*/) {} + virtual void OnEnvironmentsTearDownStart(const UnitTest& /*unit_test*/) {} + virtual void OnEnvironmentsTearDownEnd(const UnitTest& /*unit_test*/) {} + virtual void OnTestIterationEnd(const UnitTest& /*unit_test*/, + int /*iteration*/) {} + virtual void OnTestProgramEnd(const UnitTest& /*unit_test*/) {} +}; + +// TestEventListeners lets users add listeners to track events in Google Test. +class GTEST_API_ TestEventListeners { + public: + TestEventListeners(); + ~TestEventListeners(); + + // Appends an event listener to the end of the list. Google Test assumes + // the ownership of the listener (i.e. it will delete the listener when + // the test program finishes). + void Append(TestEventListener* listener); + + // Removes the given event listener from the list and returns it. It then + // becomes the caller's responsibility to delete the listener. Returns + // NULL if the listener is not found in the list. + TestEventListener* Release(TestEventListener* listener); + + // Returns the standard listener responsible for the default console + // output. Can be removed from the listeners list to shut down default + // console output. 
Note that removing this object from the listener list + // with Release transfers its ownership to the caller and makes this + // function return NULL the next time. + TestEventListener* default_result_printer() const { + return default_result_printer_; + } + + // Returns the standard listener responsible for the default XML output + // controlled by the --gtest_output=xml flag. Can be removed from the + // listeners list by users who want to shut down the default XML output + // controlled by this flag and substitute it with custom one. Note that + // removing this object from the listener list with Release transfers its + // ownership to the caller and makes this function return NULL the next + // time. + TestEventListener* default_xml_generator() const { + return default_xml_generator_; + } + + private: + friend class TestCase; + friend class TestInfo; + friend class internal::DefaultGlobalTestPartResultReporter; + friend class internal::NoExecDeathTest; + friend class internal::TestEventListenersAccessor; + friend class internal::UnitTestImpl; + + // Returns repeater that broadcasts the TestEventListener events to all + // subscribers. + TestEventListener* repeater(); + + // Sets the default_result_printer attribute to the provided listener. + // The listener is also added to the listener list and previous + // default_result_printer is removed from it and deleted. The listener can + // also be NULL in which case it will not be added to the list. Does + // nothing if the previous and the current listener objects are the same. + void SetDefaultResultPrinter(TestEventListener* listener); + + // Sets the default_xml_generator attribute to the provided listener. The + // listener is also added to the listener list and previous + // default_xml_generator is removed from it and deleted. The listener can + // also be NULL in which case it will not be added to the list. Does + // nothing if the previous and the current listener objects are the same. 
+ void SetDefaultXmlGenerator(TestEventListener* listener); + + // Controls whether events will be forwarded by the repeater to the + // listeners in the list. + bool EventForwardingEnabled() const; + void SuppressEventForwarding(); + + // The actual list of listeners. + internal::TestEventRepeater* repeater_; + // Listener responsible for the standard result output. + TestEventListener* default_result_printer_; + // Listener responsible for the creation of the XML output file. + TestEventListener* default_xml_generator_; + + // We disallow copying TestEventListeners. + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestEventListeners); +}; + +// A UnitTest consists of a vector of TestCases. +// +// This is a singleton class. The only instance of UnitTest is +// created when UnitTest::GetInstance() is first called. This +// instance is never deleted. +// +// UnitTest is not copyable. +// +// This class is thread-safe as long as the methods are called +// according to their specification. +class GTEST_API_ UnitTest { + public: + // Gets the singleton UnitTest object. The first time this method + // is called, a UnitTest object is constructed and returned. + // Consecutive calls will return the same object. + static UnitTest* GetInstance(); + + // Runs all tests in this UnitTest object and prints the result. + // Returns 0 if successful, or 1 otherwise. + // + // This method can only be called from the main thread. + // + // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + int Run() GTEST_MUST_USE_RESULT_; + + // Returns the working directory when the first TEST() or TEST_F() + // was executed. The UnitTest object owns the string. + const char* original_working_dir() const; + + // Returns the TestCase object for the test that's currently running, + // or NULL if no test is running. + const TestCase* current_test_case() const; + + // Returns the TestInfo object for the test that's currently running, + // or NULL if no test is running. 
+ const TestInfo* current_test_info() const; + + // Returns the random seed used at the start of the current test run. + int random_seed() const; + +#if GTEST_HAS_PARAM_TEST + // Returns the ParameterizedTestCaseRegistry object used to keep track of + // value-parameterized tests and instantiate and register them. + // + // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + internal::ParameterizedTestCaseRegistry& parameterized_test_registry(); +#endif // GTEST_HAS_PARAM_TEST + + // Gets the number of successful test cases. + int successful_test_case_count() const; + + // Gets the number of failed test cases. + int failed_test_case_count() const; + + // Gets the number of all test cases. + int total_test_case_count() const; + + // Gets the number of all test cases that contain at least one test + // that should run. + int test_case_to_run_count() const; + + // Gets the number of successful tests. + int successful_test_count() const; + + // Gets the number of failed tests. + int failed_test_count() const; + + // Gets the number of disabled tests. + int disabled_test_count() const; + + // Gets the number of all tests. + int total_test_count() const; + + // Gets the number of tests that should run. + int test_to_run_count() const; + + // Gets the elapsed time, in milliseconds. + TimeInMillis elapsed_time() const; + + // Returns true iff the unit test passed (i.e. all test cases passed). + bool Passed() const; + + // Returns true iff the unit test failed (i.e. some test case failed + // or something outside of all tests failed). + bool Failed() const; + + // Gets the i-th test case among all the test cases. i can range from 0 to + // total_test_case_count() - 1. If i is not in that range, returns NULL. + const TestCase* GetTestCase(int i) const; + + // Returns the list of event listeners that can be used to track events + // inside Google Test. + TestEventListeners& listeners(); + + private: + // Registers and returns a global test environment. 
When a test + // program is run, all global test environments will be set-up in + // the order they were registered. After all tests in the program + // have finished, all global test environments will be torn-down in + // the *reverse* order they were registered. + // + // The UnitTest object takes ownership of the given environment. + // + // This method can only be called from the main thread. + Environment* AddEnvironment(Environment* env); + + // Adds a TestPartResult to the current TestResult object. All + // Google Test assertion macros (e.g. ASSERT_TRUE, EXPECT_EQ, etc) + // eventually call this to report their results. The user code + // should use the assertion macros instead of calling this directly. + void AddTestPartResult(TestPartResult::Type result_type, + const char* file_name, + int line_number, + const internal::String& message, + const internal::String& os_stack_trace); + + // Adds a TestProperty to the current TestResult object. If the result already + // contains a property with the same key, the value will be updated. + void RecordPropertyForCurrentTest(const char* key, const char* value); + + // Gets the i-th test case among all the test cases. i can range from 0 to + // total_test_case_count() - 1. If i is not in that range, returns NULL. + TestCase* GetMutableTestCase(int i); + + // Accessors for the implementation object. + internal::UnitTestImpl* impl() { return impl_; } + const internal::UnitTestImpl* impl() const { return impl_; } + + // These classes and functions are friends as they need to access private + // members of UnitTest. + friend class Test; + friend class internal::AssertHelper; + friend class internal::ScopedTrace; + friend Environment* AddGlobalTestEnvironment(Environment* env); + friend internal::UnitTestImpl* internal::GetUnitTestImpl(); + friend void internal::ReportFailureInUnknownLocation( + TestPartResult::Type result_type, + const internal::String& message); + + // Creates an empty UnitTest.
+ UnitTest(); + + // D'tor + virtual ~UnitTest(); + + // Pushes a trace defined by SCOPED_TRACE() on to the per-thread + // Google Test trace stack. + void PushGTestTrace(const internal::TraceInfo& trace); + + // Pops a trace from the per-thread Google Test trace stack. + void PopGTestTrace(); + + // Protects mutable state in *impl_. This is mutable as some const + // methods need to lock it too. + mutable internal::Mutex mutex_; + + // Opaque implementation object. This field is never changed once + // the object is constructed. We don't mark it as const here, as + // doing so will cause a warning in the constructor of UnitTest. + // Mutable state in *impl_ is protected by mutex_. + internal::UnitTestImpl* impl_; + + // We disallow copying UnitTest. + GTEST_DISALLOW_COPY_AND_ASSIGN_(UnitTest); +}; + +// A convenient wrapper for adding an environment for the test +// program. +// +// You should call this before RUN_ALL_TESTS() is called, probably in +// main(). If you use gtest_main, you need to call this before main() +// starts for it to take effect. For example, you can define a global +// variable like this: +// +// testing::Environment* const foo_env = +// testing::AddGlobalTestEnvironment(new FooEnvironment); +// +// However, we strongly recommend you to write your own main() and +// call AddGlobalTestEnvironment() there, as relying on initialization +// of global variables makes the code harder to read and may cause +// problems when you register multiple environments from different +// translation units and the environments have dependencies among them +// (remember that the compiler doesn't guarantee the order in which +// global variables from different translation units are initialized). +inline Environment* AddGlobalTestEnvironment(Environment* env) { + return UnitTest::GetInstance()->AddEnvironment(env); +} + +// Initializes Google Test. This must be called before calling +// RUN_ALL_TESTS(). 
In particular, it parses a command line for the +// flags that Google Test recognizes. Whenever a Google Test flag is +// seen, it is removed from argv, and *argc is decremented. +// +// No value is returned. Instead, the Google Test flag variables are +// updated. +// +// Calling the function for the second time has no user-visible effect. +GTEST_API_ void InitGoogleTest(int* argc, char** argv); + +// This overloaded version can be used in Windows programs compiled in +// UNICODE mode. +GTEST_API_ void InitGoogleTest(int* argc, wchar_t** argv); + +namespace internal { + +// Formats a comparison assertion (e.g. ASSERT_EQ, EXPECT_LT, etc.) +// operand to be used in a failure message. The type (but not value) +// of the other operand may affect the format. This allows us to +// print a char* as a raw pointer when it is compared against another +// char*, and print it as a C string when it is compared against an +// std::string object, for example. +// +// The default implementation ignores the type of the other operand. +// Some specialized versions are used to handle formatting wide or +// narrow C strings. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +template <typename T1, typename T2> +String FormatForComparisonFailureMessage(const T1& value, + const T2& /* other_operand */) { + // C++Builder compiles this incorrectly if the namespace isn't explicitly + // given. + return ::testing::PrintToString(value); +} + +// The helper function for {ASSERT|EXPECT}_EQ. +template <typename T1, typename T2> +AssertionResult CmpHelperEQ(const char* expected_expression, + const char* actual_expression, + const T1& expected, + const T2& actual) { +#ifdef _MSC_VER +# pragma warning(push) // Saves the current warning state. +# pragma warning(disable:4389) // Temporarily disables warning on + // signed/unsigned mismatch. +#endif + + if (expected == actual) { + return AssertionSuccess(); + } + +#ifdef _MSC_VER +# pragma warning(pop) // Restores the warning state.
+#endif + + return EqFailure(expected_expression, + actual_expression, + FormatForComparisonFailureMessage(expected, actual), + FormatForComparisonFailureMessage(actual, expected), + false); +} + +// With this overloaded version, we allow anonymous enums to be used +// in {ASSERT|EXPECT}_EQ when compiled with gcc 4, as anonymous enums +// can be implicitly cast to BiggestInt. +GTEST_API_ AssertionResult CmpHelperEQ(const char* expected_expression, + const char* actual_expression, + BiggestInt expected, + BiggestInt actual); + +// The helper class for {ASSERT|EXPECT}_EQ. The template argument +// lhs_is_null_literal is true iff the first argument to ASSERT_EQ() +// is a null pointer literal. The following default implementation is +// for lhs_is_null_literal being false. +template <bool lhs_is_null_literal> +class EqHelper { + public: + // This templatized version is for the general case. + template <typename T1, typename T2> + static AssertionResult Compare(const char* expected_expression, + const char* actual_expression, + const T1& expected, + const T2& actual) { + return CmpHelperEQ(expected_expression, actual_expression, expected, + actual); + } + + // With this overloaded version, we allow anonymous enums to be used + // in {ASSERT|EXPECT}_EQ when compiled with gcc 4, as anonymous + // enums can be implicitly cast to BiggestInt. + // + // Even though its body looks the same as the above version, we + // cannot merge the two, as it will make anonymous enums unhappy. + static AssertionResult Compare(const char* expected_expression, + const char* actual_expression, + BiggestInt expected, + BiggestInt actual) { + return CmpHelperEQ(expected_expression, actual_expression, expected, + actual); + } +}; + +// This specialization is used when the first argument to ASSERT_EQ() +// is a null pointer literal, like NULL, false, or 0. +template <> +class EqHelper<true> { + public: + // We define two overloaded versions of Compare().
The first + // version will be picked when the second argument to ASSERT_EQ() is + // NOT a pointer, e.g. ASSERT_EQ(0, AnIntFunction()) or + // EXPECT_EQ(false, a_bool). + template <typename T1, typename T2> + static AssertionResult Compare( + const char* expected_expression, + const char* actual_expression, + const T1& expected, + const T2& actual, + // The following line prevents this overload from being considered if T2 + // is a pointer type. We need this because ASSERT_EQ(NULL, my_ptr) + // expands to Compare("", "", NULL, my_ptr), which requires a conversion + // to match the Secret* in the other overload, which would otherwise make + // this template match better. + typename EnableIf<!is_pointer<T2>::value>::type* = 0) { + return CmpHelperEQ(expected_expression, actual_expression, expected, + actual); + } + + // This version will be picked when the second argument to ASSERT_EQ() is a + // pointer, e.g. ASSERT_EQ(NULL, a_pointer). + template <typename T> + static AssertionResult Compare( + const char* expected_expression, + const char* actual_expression, + // We used to have a second template parameter instead of Secret*. That + // template parameter would deduce to 'long', making this a better match + // than the first overload even without the first overload's EnableIf. + // Unfortunately, gcc with -Wconversion-null warns when "passing NULL to + // non-pointer argument" (even a deduced integral argument), so the old + // implementation caused warnings in user code. + Secret* /* expected (NULL) */, + T* actual) { + // We already know that 'expected' is a null pointer. + return CmpHelperEQ(expected_expression, actual_expression, + static_cast<T*>(NULL), actual); + } +}; + +// A macro for implementing the helper functions needed to implement +// ASSERT_?? and EXPECT_??. It is here just to avoid copy-and-paste +// of similar code.
+// +// For each templatized helper function, we also define an overloaded +// version for BiggestInt in order to reduce code bloat and allow +// anonymous enums to be used with {ASSERT|EXPECT}_?? when compiled +// with gcc 4. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +#define GTEST_IMPL_CMP_HELPER_(op_name, op)\ +template <typename T1, typename T2>\ +AssertionResult CmpHelper##op_name(const char* expr1, const char* expr2, \ + const T1& val1, const T2& val2) {\ + if (val1 op val2) {\ + return AssertionSuccess();\ + } else {\ + return AssertionFailure() \ + << "Expected: (" << expr1 << ") " #op " (" << expr2\ + << "), actual: " << FormatForComparisonFailureMessage(val1, val2)\ + << " vs " << FormatForComparisonFailureMessage(val2, val1);\ + }\ +}\ +GTEST_API_ AssertionResult CmpHelper##op_name(\ + const char* expr1, const char* expr2, BiggestInt val1, BiggestInt val2) + +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + +// Implements the helper function for {ASSERT|EXPECT}_NE +GTEST_IMPL_CMP_HELPER_(NE, !=); +// Implements the helper function for {ASSERT|EXPECT}_LE +GTEST_IMPL_CMP_HELPER_(LE, <=); +// Implements the helper function for {ASSERT|EXPECT}_LT +GTEST_IMPL_CMP_HELPER_(LT, < ); +// Implements the helper function for {ASSERT|EXPECT}_GE +GTEST_IMPL_CMP_HELPER_(GE, >=); +// Implements the helper function for {ASSERT|EXPECT}_GT +GTEST_IMPL_CMP_HELPER_(GT, > ); + +#undef GTEST_IMPL_CMP_HELPER_ + +// The helper function for {ASSERT|EXPECT}_STREQ. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTREQ(const char* expected_expression, + const char* actual_expression, + const char* expected, + const char* actual); + +// The helper function for {ASSERT|EXPECT}_STRCASEEQ. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM.
+GTEST_API_ AssertionResult CmpHelperSTRCASEEQ(const char* expected_expression, + const char* actual_expression, + const char* expected, + const char* actual); + +// The helper function for {ASSERT|EXPECT}_STRNE. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTRNE(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2); + +// The helper function for {ASSERT|EXPECT}_STRCASENE. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTRCASENE(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2); + + +// Helper function for *_STREQ on wide strings. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTREQ(const char* expected_expression, + const char* actual_expression, + const wchar_t* expected, + const wchar_t* actual); + +// Helper function for *_STRNE on wide strings. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTRNE(const char* s1_expression, + const char* s2_expression, + const wchar_t* s1, + const wchar_t* s2); + +} // namespace internal + +// IsSubstring() and IsNotSubstring() are intended to be used as the +// first argument to {EXPECT,ASSERT}_PRED_FORMAT2(), not by +// themselves. They check whether needle is a substring of haystack +// (NULL is considered a substring of itself only), and return an +// appropriate error message when they fail. +// +// The {needle,haystack}_expr arguments are the stringified +// expressions that generated the two real arguments. 
+GTEST_API_ AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const char* needle, const char* haystack); +GTEST_API_ AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const wchar_t* needle, const wchar_t* haystack); +GTEST_API_ AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const char* needle, const char* haystack); +GTEST_API_ AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const wchar_t* needle, const wchar_t* haystack); +GTEST_API_ AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::string& needle, const ::std::string& haystack); +GTEST_API_ AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::string& needle, const ::std::string& haystack); + +#if GTEST_HAS_STD_WSTRING +GTEST_API_ AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::wstring& needle, const ::std::wstring& haystack); +GTEST_API_ AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::wstring& needle, const ::std::wstring& haystack); +#endif // GTEST_HAS_STD_WSTRING + +namespace internal { + +// Helper template function for comparing floating-points. +// +// Template parameter: +// +// RawType: the raw floating-point type (either float or double) +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. 
+template <typename RawType> +AssertionResult CmpHelperFloatingPointEQ(const char* expected_expression, + const char* actual_expression, + RawType expected, + RawType actual) { + const FloatingPoint<RawType> lhs(expected), rhs(actual); + + if (lhs.AlmostEquals(rhs)) { + return AssertionSuccess(); + } + + ::std::stringstream expected_ss; + expected_ss << std::setprecision(std::numeric_limits<RawType>::digits10 + 2) + << expected; + + ::std::stringstream actual_ss; + actual_ss << std::setprecision(std::numeric_limits<RawType>::digits10 + 2) + << actual; + + return EqFailure(expected_expression, + actual_expression, + StringStreamToString(&expected_ss), + StringStreamToString(&actual_ss), + false); +} + +// Helper function for implementing ASSERT_NEAR. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult DoubleNearPredFormat(const char* expr1, + const char* expr2, + const char* abs_error_expr, + double val1, + double val2, + double abs_error); + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// A class that enables one to stream messages to assertion macros +class GTEST_API_ AssertHelper { + public: + // Constructor. + AssertHelper(TestPartResult::Type type, + const char* file, + int line, + const char* message); + ~AssertHelper(); + + // Message assignment is a semantic trick to enable assertion + // streaming; see the GTEST_MESSAGE_ macro below. + void operator=(const Message& message) const; + + private: + // We put our data in a struct so that the size of the AssertHelper class can + // be as small as possible. This is important because gcc is incapable of + // re-using stack space even for temporary variables, so every EXPECT_EQ + // reserves stack space for another AssertHelper.
+ struct AssertHelperData { + AssertHelperData(TestPartResult::Type t, + const char* srcfile, + int line_num, + const char* msg) + : type(t), file(srcfile), line(line_num), message(msg) { } + + TestPartResult::Type const type; + const char* const file; + int const line; + String const message; + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(AssertHelperData); + }; + + AssertHelperData* const data_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(AssertHelper); +}; + +} // namespace internal + +#if GTEST_HAS_PARAM_TEST +// The pure interface class that all value-parameterized tests inherit from. +// A value-parameterized class must inherit from both ::testing::Test and +// ::testing::WithParamInterface<T>. In most cases that just means inheriting +// from ::testing::TestWithParam<T>, but more complicated test hierarchies +// may need to inherit from Test and WithParamInterface at different levels. +// +// This interface has support for accessing the test parameter value via +// the GetParam() method. +// +// Use it with one of the parameter generator defining functions, like Range(), +// Values(), ValuesIn(), Bool(), and Combine(). +// +// class FooTest : public ::testing::TestWithParam<int> { +// protected: +// FooTest() { +// // Can use GetParam() here. +// } +// virtual ~FooTest() { +// // Can use GetParam() here. +// } +// virtual void SetUp() { +// // Can use GetParam() here. +// } +// virtual void TearDown() { +// // Can use GetParam() here. +// } +// }; +// TEST_P(FooTest, DoesBar) { +// // Can use GetParam() method here. +// Foo foo; +// ASSERT_TRUE(foo.DoesBar(GetParam())); +// } +// INSTANTIATE_TEST_CASE_P(OneToTenRange, FooTest, ::testing::Range(1, 10)); + +template <typename T> +class WithParamInterface { + public: + typedef T ParamType; + virtual ~WithParamInterface() {} + + // The current parameter value. Is also available in the test fixture's + // constructor.
This member function is non-static, even though it only + // references static data, to reduce the opportunity for incorrect uses + // like writing 'WithParamInterface<bool>::GetParam()' for a test that + // uses a fixture whose parameter type is int. + const ParamType& GetParam() const { return *parameter_; } + + private: + // Sets parameter value. The caller is responsible for making sure the value + // remains alive and unchanged throughout the current test. + static void SetParam(const ParamType* parameter) { + parameter_ = parameter; + } + + // Static value used for accessing parameter during a test lifetime. + static const ParamType* parameter_; + + // TestClass must be a subclass of WithParamInterface<T> and Test. + template <class TestClass> friend class internal::ParameterizedTestFactory; +}; + +template <typename T> +const T* WithParamInterface<T>::parameter_ = NULL; + +// Most value-parameterized classes can ignore the existence of +// WithParamInterface, and can just inherit from ::testing::TestWithParam. + +template <typename T> +class TestWithParam : public Test, public WithParamInterface<T> { +}; + +#endif // GTEST_HAS_PARAM_TEST + +// Macros for indicating success/failure in test code. + +// ADD_FAILURE unconditionally adds a failure to the current test. +// SUCCEED generates a success - it doesn't automatically make the +// current test successful, as a test is only successful when it has +// no failure. +// +// EXPECT_* verifies that a certain condition is satisfied. If not, +// it behaves like ADD_FAILURE. In particular: +// +// EXPECT_TRUE verifies that a Boolean condition is true. +// EXPECT_FALSE verifies that a Boolean condition is false. +// +// FAIL and ASSERT_* are similar to ADD_FAILURE and EXPECT_*, except +// that they will also abort the current function on failure. People +// usually want the fail-fast behavior of FAIL and ASSERT_*, but those +// writing data-driven tests often find themselves using ADD_FAILURE +// and EXPECT_* more.
+// +// Examples: +// +// EXPECT_TRUE(server.StatusIsOK()); +// ASSERT_FALSE(server.HasPendingRequest(port)) +// << "There are still pending requests " << "on port " << port; + +// Generates a nonfatal failure with a generic message. +#define ADD_FAILURE() GTEST_NONFATAL_FAILURE_("Failed") + +// Generates a nonfatal failure at the given source file location with +// a generic message. +#define ADD_FAILURE_AT(file, line) \ + GTEST_MESSAGE_AT_(file, line, "Failed", \ + ::testing::TestPartResult::kNonFatalFailure) + +// Generates a fatal failure with a generic message. +#define GTEST_FAIL() GTEST_FATAL_FAILURE_("Failed") + +// Define this macro to 1 to omit the definition of FAIL(), which is a +// generic name and clashes with some other libraries. +#if !GTEST_DONT_DEFINE_FAIL +# define FAIL() GTEST_FAIL() +#endif + +// Generates a success with a generic message. +#define GTEST_SUCCEED() GTEST_SUCCESS_("Succeeded") + +// Define this macro to 1 to omit the definition of SUCCEED(), which +// is a generic name and clashes with some other libraries. +#if !GTEST_DONT_DEFINE_SUCCEED +# define SUCCEED() GTEST_SUCCEED() +#endif + +// Macros for testing exceptions. +// +// * {ASSERT|EXPECT}_THROW(statement, expected_exception): +// Tests that the statement throws the expected exception. +// * {ASSERT|EXPECT}_NO_THROW(statement): +// Tests that the statement doesn't throw any exception. +// * {ASSERT|EXPECT}_ANY_THROW(statement): +// Tests that the statement throws an exception. 
+ +#define EXPECT_THROW(statement, expected_exception) \ + GTEST_TEST_THROW_(statement, expected_exception, GTEST_NONFATAL_FAILURE_) +#define EXPECT_NO_THROW(statement) \ + GTEST_TEST_NO_THROW_(statement, GTEST_NONFATAL_FAILURE_) +#define EXPECT_ANY_THROW(statement) \ + GTEST_TEST_ANY_THROW_(statement, GTEST_NONFATAL_FAILURE_) +#define ASSERT_THROW(statement, expected_exception) \ + GTEST_TEST_THROW_(statement, expected_exception, GTEST_FATAL_FAILURE_) +#define ASSERT_NO_THROW(statement) \ + GTEST_TEST_NO_THROW_(statement, GTEST_FATAL_FAILURE_) +#define ASSERT_ANY_THROW(statement) \ + GTEST_TEST_ANY_THROW_(statement, GTEST_FATAL_FAILURE_) + +// Boolean assertions. Condition can be either a Boolean expression or an +// AssertionResult. For more information on how to use AssertionResult with +// these macros see comments on that class. +#define EXPECT_TRUE(condition) \ + GTEST_TEST_BOOLEAN_(condition, #condition, false, true, \ + GTEST_NONFATAL_FAILURE_) +#define EXPECT_FALSE(condition) \ + GTEST_TEST_BOOLEAN_(!(condition), #condition, true, false, \ + GTEST_NONFATAL_FAILURE_) +#define ASSERT_TRUE(condition) \ + GTEST_TEST_BOOLEAN_(condition, #condition, false, true, \ + GTEST_FATAL_FAILURE_) +#define ASSERT_FALSE(condition) \ + GTEST_TEST_BOOLEAN_(!(condition), #condition, true, false, \ + GTEST_FATAL_FAILURE_) + +// Includes the auto-generated header that implements a family of +// generic predicate assertion macros. +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. 
+// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file is AUTOMATICALLY GENERATED on 09/24/2010 by command +// 'gen_gtest_pred_impl.py 5'. DO NOT EDIT BY HAND! +// +// Implements a family of generic predicate assertion macros. + +#ifndef GTEST_INCLUDE_GTEST_GTEST_PRED_IMPL_H_ +#define GTEST_INCLUDE_GTEST_GTEST_PRED_IMPL_H_ + +// Makes sure this header is not included before gtest.h. +#ifndef GTEST_INCLUDE_GTEST_GTEST_H_ +# error Do not include gtest_pred_impl.h directly. Include gtest.h instead. +#endif // GTEST_INCLUDE_GTEST_GTEST_H_ + +// This header implements a family of generic predicate assertion +// macros: +// +// ASSERT_PRED_FORMAT1(pred_format, v1) +// ASSERT_PRED_FORMAT2(pred_format, v1, v2) +// ... 
+// +// where pred_format is a function or functor that takes n (in the +// case of ASSERT_PRED_FORMATn) values and their source expression +// text, and returns a testing::AssertionResult. See the definition +// of ASSERT_EQ in gtest.h for an example. +// +// If you don't care about formatting, you can use the more +// restrictive version: +// +// ASSERT_PRED1(pred, v1) +// ASSERT_PRED2(pred, v1, v2) +// ... +// +// where pred is an n-ary function or functor that returns bool, +// and the values v1, v2, ..., must support the << operator for +// streaming to std::ostream. +// +// We also define the EXPECT_* variations. +// +// For now we only support predicates whose arity is at most 5. +// Please email googletestframework@googlegroups.com if you need +// support for higher arities. + +// GTEST_ASSERT_ is the basic statement to which all of the assertions +// in this file reduce. Don't use this in your code. + +#define GTEST_ASSERT_(expression, on_failure) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (const ::testing::AssertionResult gtest_ar = (expression)) \ + ; \ + else \ + on_failure(gtest_ar.failure_message()) + + +// Helper function for implementing {EXPECT|ASSERT}_PRED1. Don't use +// this in your code. +template <typename Pred, typename T1> +AssertionResult AssertPred1Helper(const char* pred_text, + const char* e1, + Pred pred, + const T1& v1) { + if (pred(v1)) return AssertionSuccess(); + + return AssertionFailure() << pred_text << "(" + << e1 << ") evaluates to false, where" + << "\n" << e1 << " evaluates to " << v1; +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT1. +// Don't use this in your code. +#define GTEST_PRED_FORMAT1_(pred_format, v1, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, v1),\ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED1. Don't use +// this in your code.
+#define GTEST_PRED1_(pred, v1, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred1Helper(#pred, \ + #v1, \ + pred, \ + v1), on_failure) + +// Unary predicate assertion macros. +#define EXPECT_PRED_FORMAT1(pred_format, v1) \ + GTEST_PRED_FORMAT1_(pred_format, v1, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED1(pred, v1) \ + GTEST_PRED1_(pred, v1, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT1(pred_format, v1) \ + GTEST_PRED_FORMAT1_(pred_format, v1, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED1(pred, v1) \ + GTEST_PRED1_(pred, v1, GTEST_FATAL_FAILURE_) + + + +// Helper function for implementing {EXPECT|ASSERT}_PRED2. Don't use +// this in your code. +template <typename Pred, typename T1, typename T2> +AssertionResult AssertPred2Helper(const char* pred_text, + const char* e1, + const char* e2, + Pred pred, + const T1& v1, + const T2& v2) { + if (pred(v1, v2)) return AssertionSuccess(); + + return AssertionFailure() << pred_text << "(" + << e1 << ", " + << e2 << ") evaluates to false, where" + << "\n" << e1 << " evaluates to " << v1 + << "\n" << e2 << " evaluates to " << v2; +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT2. +// Don't use this in your code. +#define GTEST_PRED_FORMAT2_(pred_format, v1, v2, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, #v2, v1, v2),\ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED2. Don't use +// this in your code. +#define GTEST_PRED2_(pred, v1, v2, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred2Helper(#pred, \ + #v1, \ + #v2, \ + pred, \ + v1, \ + v2), on_failure) + +// Binary predicate assertion macros.
+#define EXPECT_PRED_FORMAT2(pred_format, v1, v2) \ + GTEST_PRED_FORMAT2_(pred_format, v1, v2, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED2(pred, v1, v2) \ + GTEST_PRED2_(pred, v1, v2, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT2(pred_format, v1, v2) \ + GTEST_PRED_FORMAT2_(pred_format, v1, v2, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED2(pred, v1, v2) \ + GTEST_PRED2_(pred, v1, v2, GTEST_FATAL_FAILURE_) + + + +// Helper function for implementing {EXPECT|ASSERT}_PRED3. Don't use +// this in your code. +template <typename Pred, typename T1, typename T2, typename T3> +AssertionResult AssertPred3Helper(const char* pred_text, + const char* e1, + const char* e2, + const char* e3, + Pred pred, + const T1& v1, + const T2& v2, + const T3& v3) { + if (pred(v1, v2, v3)) return AssertionSuccess(); + + return AssertionFailure() << pred_text << "(" + << e1 << ", " + << e2 << ", " + << e3 << ") evaluates to false, where" + << "\n" << e1 << " evaluates to " << v1 + << "\n" << e2 << " evaluates to " << v2 + << "\n" << e3 << " evaluates to " << v3; +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT3. +// Don't use this in your code. +#define GTEST_PRED_FORMAT3_(pred_format, v1, v2, v3, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, #v2, #v3, v1, v2, v3),\ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED3. Don't use +// this in your code. +#define GTEST_PRED3_(pred, v1, v2, v3, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred3Helper(#pred, \ + #v1, \ + #v2, \ + #v3, \ + pred, \ + v1, \ + v2, \ + v3), on_failure) + +// Ternary predicate assertion macros.
+#define EXPECT_PRED_FORMAT3(pred_format, v1, v2, v3) \ + GTEST_PRED_FORMAT3_(pred_format, v1, v2, v3, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED3(pred, v1, v2, v3) \ + GTEST_PRED3_(pred, v1, v2, v3, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT3(pred_format, v1, v2, v3) \ + GTEST_PRED_FORMAT3_(pred_format, v1, v2, v3, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED3(pred, v1, v2, v3) \ + GTEST_PRED3_(pred, v1, v2, v3, GTEST_FATAL_FAILURE_) + + + +// Helper function for implementing {EXPECT|ASSERT}_PRED4. Don't use +// this in your code. +template <typename Pred, typename T1, typename T2, typename T3, typename T4> +AssertionResult AssertPred4Helper(const char* pred_text, + const char* e1, + const char* e2, + const char* e3, + const char* e4, + Pred pred, + const T1& v1, + const T2& v2, + const T3& v3, + const T4& v4) { + if (pred(v1, v2, v3, v4)) return AssertionSuccess(); + + return AssertionFailure() << pred_text << "(" + << e1 << ", " + << e2 << ", " + << e3 << ", " + << e4 << ") evaluates to false, where" + << "\n" << e1 << " evaluates to " << v1 + << "\n" << e2 << " evaluates to " << v2 + << "\n" << e3 << " evaluates to " << v3 + << "\n" << e4 << " evaluates to " << v4; +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT4. +// Don't use this in your code. +#define GTEST_PRED_FORMAT4_(pred_format, v1, v2, v3, v4, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, #v2, #v3, #v4, v1, v2, v3, v4),\ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED4. Don't use +// this in your code. +#define GTEST_PRED4_(pred, v1, v2, v3, v4, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred4Helper(#pred, \ + #v1, \ + #v2, \ + #v3, \ + #v4, \ + pred, \ + v1, \ + v2, \ + v3, \ + v4), on_failure) + +// 4-ary predicate assertion macros.
+#define EXPECT_PRED_FORMAT4(pred_format, v1, v2, v3, v4) \ + GTEST_PRED_FORMAT4_(pred_format, v1, v2, v3, v4, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED4(pred, v1, v2, v3, v4) \ + GTEST_PRED4_(pred, v1, v2, v3, v4, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT4(pred_format, v1, v2, v3, v4) \ + GTEST_PRED_FORMAT4_(pred_format, v1, v2, v3, v4, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED4(pred, v1, v2, v3, v4) \ + GTEST_PRED4_(pred, v1, v2, v3, v4, GTEST_FATAL_FAILURE_) + + + +// Helper function for implementing {EXPECT|ASSERT}_PRED5. Don't use +// this in your code. +template <typename Pred, typename T1, typename T2, typename T3, typename T4, typename T5> +AssertionResult AssertPred5Helper(const char* pred_text, + const char* e1, + const char* e2, + const char* e3, + const char* e4, + const char* e5, + Pred pred, + const T1& v1, + const T2& v2, + const T3& v3, + const T4& v4, + const T5& v5) { + if (pred(v1, v2, v3, v4, v5)) return AssertionSuccess(); + + return AssertionFailure() << pred_text << "(" + << e1 << ", " + << e2 << ", " + << e3 << ", " + << e4 << ", " + << e5 << ") evaluates to false, where" + << "\n" << e1 << " evaluates to " << v1 + << "\n" << e2 << " evaluates to " << v2 + << "\n" << e3 << " evaluates to " << v3 + << "\n" << e4 << " evaluates to " << v4 + << "\n" << e5 << " evaluates to " << v5; +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT5. +// Don't use this in your code. +#define GTEST_PRED_FORMAT5_(pred_format, v1, v2, v3, v4, v5, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, #v2, #v3, #v4, #v5, v1, v2, v3, v4, v5),\ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED5. Don't use +// this in your code. +#define GTEST_PRED5_(pred, v1, v2, v3, v4, v5, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred5Helper(#pred, \ + #v1, \ + #v2, \ + #v3, \ + #v4, \ + #v5, \ + pred, \ + v1, \ + v2, \ + v3, \ + v4, \ + v5), on_failure) + +// 5-ary predicate assertion macros.
+#define EXPECT_PRED_FORMAT5(pred_format, v1, v2, v3, v4, v5) \ + GTEST_PRED_FORMAT5_(pred_format, v1, v2, v3, v4, v5, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED5(pred, v1, v2, v3, v4, v5) \ + GTEST_PRED5_(pred, v1, v2, v3, v4, v5, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT5(pred_format, v1, v2, v3, v4, v5) \ + GTEST_PRED_FORMAT5_(pred_format, v1, v2, v3, v4, v5, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED5(pred, v1, v2, v3, v4, v5) \ + GTEST_PRED5_(pred, v1, v2, v3, v4, v5, GTEST_FATAL_FAILURE_) + + + +#endif // GTEST_INCLUDE_GTEST_GTEST_PRED_IMPL_H_ + +// Macros for testing equalities and inequalities. +// +// * {ASSERT|EXPECT}_EQ(expected, actual): Tests that expected == actual +// * {ASSERT|EXPECT}_NE(v1, v2): Tests that v1 != v2 +// * {ASSERT|EXPECT}_LT(v1, v2): Tests that v1 < v2 +// * {ASSERT|EXPECT}_LE(v1, v2): Tests that v1 <= v2 +// * {ASSERT|EXPECT}_GT(v1, v2): Tests that v1 > v2 +// * {ASSERT|EXPECT}_GE(v1, v2): Tests that v1 >= v2 +// +// When they are not, Google Test prints both the tested expressions and +// their actual values. The values must be compatible built-in types, +// or you will get a compiler error. By "compatible" we mean that the +// values can be compared by the respective operator. +// +// Note: +// +// 1. It is possible to make a user-defined type work with +// {ASSERT|EXPECT}_??(), but that requires overloading the +// comparison operators and is thus discouraged by the Google C++ +// Style Guide. Therefore, you are advised to use the +// {ASSERT|EXPECT}_TRUE() macro to assert that two objects are +// equal. +// +// 2. The {ASSERT|EXPECT}_??() macros do pointer comparisons on +// pointers (in particular, C strings). Therefore, if you use it +// with two C strings, you are testing how their locations in memory +// are related, not how their content is related. To compare two C +// strings by content, use {ASSERT|EXPECT}_STR*(). +// +// 3.
{ASSERT|EXPECT}_EQ(expected, actual) is preferred to +// {ASSERT|EXPECT}_TRUE(expected == actual), as the former tells you +// what the actual value is when it fails, and similarly for the +// other comparisons. +// +// 4. Do not depend on the order in which {ASSERT|EXPECT}_??() +// evaluate their arguments, which is undefined. +// +// 5. These macros evaluate their arguments exactly once. +// +// Examples: +// +// EXPECT_NE(5, Foo()); +// EXPECT_EQ(NULL, a_pointer); +// ASSERT_LT(i, array_size); +// ASSERT_GT(records.size(), 0) << "There is no record left."; + +#define EXPECT_EQ(expected, actual) \ + EXPECT_PRED_FORMAT2(::testing::internal:: \ + EqHelper<GTEST_IS_NULL_LITERAL_(expected)>::Compare, \ + expected, actual) +#define EXPECT_NE(expected, actual) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperNE, expected, actual) +#define EXPECT_LE(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperLE, val1, val2) +#define EXPECT_LT(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperLT, val1, val2) +#define EXPECT_GE(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperGE, val1, val2) +#define EXPECT_GT(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperGT, val1, val2) + +#define GTEST_ASSERT_EQ(expected, actual) \ + ASSERT_PRED_FORMAT2(::testing::internal:: \ + EqHelper<GTEST_IS_NULL_LITERAL_(expected)>::Compare, \ + expected, actual) +#define GTEST_ASSERT_NE(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperNE, val1, val2) +#define GTEST_ASSERT_LE(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperLE, val1, val2) +#define GTEST_ASSERT_LT(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperLT, val1, val2) +#define GTEST_ASSERT_GE(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperGE, val1, val2) +#define GTEST_ASSERT_GT(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperGT, val1, val2) + +// Define macro GTEST_DONT_DEFINE_ASSERT_XY to 1 to omit the definition of +// ASSERT_XY(), which clashes with
some users' own code. + +#if !GTEST_DONT_DEFINE_ASSERT_EQ +# define ASSERT_EQ(val1, val2) GTEST_ASSERT_EQ(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_NE +# define ASSERT_NE(val1, val2) GTEST_ASSERT_NE(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_LE +# define ASSERT_LE(val1, val2) GTEST_ASSERT_LE(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_LT +# define ASSERT_LT(val1, val2) GTEST_ASSERT_LT(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_GE +# define ASSERT_GE(val1, val2) GTEST_ASSERT_GE(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_GT +# define ASSERT_GT(val1, val2) GTEST_ASSERT_GT(val1, val2) +#endif + +// C String Comparisons. All tests treat NULL and any non-NULL string +// as different. Two NULLs are equal. +// +// * {ASSERT|EXPECT}_STREQ(s1, s2): Tests that s1 == s2 +// * {ASSERT|EXPECT}_STRNE(s1, s2): Tests that s1 != s2 +// * {ASSERT|EXPECT}_STRCASEEQ(s1, s2): Tests that s1 == s2, ignoring case +// * {ASSERT|EXPECT}_STRCASENE(s1, s2): Tests that s1 != s2, ignoring case +// +// For wide or narrow string objects, you can use the +// {ASSERT|EXPECT}_??() macros. +// +// Don't depend on the order in which the arguments are evaluated, +// which is undefined. +// +// These macros evaluate their arguments exactly once. 
+ +#define EXPECT_STREQ(expected, actual) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTREQ, expected, actual) +#define EXPECT_STRNE(s1, s2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTRNE, s1, s2) +#define EXPECT_STRCASEEQ(expected, actual) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASEEQ, expected, actual) +#define EXPECT_STRCASENE(s1, s2)\ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASENE, s1, s2) + +#define ASSERT_STREQ(expected, actual) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTREQ, expected, actual) +#define ASSERT_STRNE(s1, s2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTRNE, s1, s2) +#define ASSERT_STRCASEEQ(expected, actual) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASEEQ, expected, actual) +#define ASSERT_STRCASENE(s1, s2)\ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASENE, s1, s2) + +// Macros for comparing floating-point numbers. +// +// * {ASSERT|EXPECT}_FLOAT_EQ(expected, actual): +// Tests that two float values are almost equal. +// * {ASSERT|EXPECT}_DOUBLE_EQ(expected, actual): +// Tests that two double values are almost equal. +// * {ASSERT|EXPECT}_NEAR(v1, v2, abs_error): +// Tests that v1 and v2 are within the given distance to each other. +// +// Google Test uses ULP-based comparison to automatically pick a default +// error bound that is appropriate for the operands. See the +// FloatingPoint template class in gtest-internal.h if you are +// interested in the implementation details. 
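To make "ULP-based comparison" concrete, here is a minimal sketch for `float`. It is an illustration, not the library's code: `ToOrderedBits` and `AlmostEqualUlps` are hypothetical names, and the default tolerance of 4 ULPs is an assumption chosen for this example. The FloatingPoint template mentioned above remains the reference for what Google Test actually does.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical helper: maps the raw IEEE-754 bits of a float onto an
// unsigned scale that grows monotonically with the float's value.
inline std::uint32_t ToOrderedBits(float value) {
  std::uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));  // well-defined type punning
  const std::uint32_t kSignBit = 0x80000000u;
  // Negative values: negate the sign-and-magnitude pattern so that a
  // larger magnitude maps lower. Non-negative values: set the top bit.
  // Both +0.0f and -0.0f end up at 0x80000000.
  return (bits & kSignBit) ? (~bits + 1) : (bits | kSignBit);
}

// Hypothetical helper: true if a and b are at most max_ulps
// representable floats apart. NaN compares unequal to everything.
inline bool AlmostEqualUlps(float a, float b, std::uint32_t max_ulps = 4) {
  if (std::isnan(a) || std::isnan(b)) return false;
  const std::uint32_t ua = ToOrderedBits(a);
  const std::uint32_t ub = ToOrderedBits(b);
  const std::uint32_t distance = (ua >= ub) ? (ua - ub) : (ub - ua);
  return distance <= max_ulps;
}
```

The point of counting representable values rather than using a fixed epsilon is that the tolerance scales with the magnitude of the operands, which is why the comment above says the error bound is picked automatically. When an absolute tolerance is what you want, the `NEAR` macros are the better fit.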
+ +#define EXPECT_FLOAT_EQ(expected, actual)\ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + expected, actual) + +#define EXPECT_DOUBLE_EQ(expected, actual)\ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + expected, actual) + +#define ASSERT_FLOAT_EQ(expected, actual)\ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + expected, actual) + +#define ASSERT_DOUBLE_EQ(expected, actual)\ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + expected, actual) + +#define EXPECT_NEAR(val1, val2, abs_error)\ + EXPECT_PRED_FORMAT3(::testing::internal::DoubleNearPredFormat, \ + val1, val2, abs_error) + +#define ASSERT_NEAR(val1, val2, abs_error)\ + ASSERT_PRED_FORMAT3(::testing::internal::DoubleNearPredFormat, \ + val1, val2, abs_error) + +// These predicate format functions work on floating-point values, and +// can be used in {ASSERT|EXPECT}_PRED_FORMAT2*(), e.g. +// +// EXPECT_PRED_FORMAT2(testing::DoubleLE, Foo(), 5.0); + +// Asserts that val1 is less than, or almost equal to, val2. Fails +// otherwise. In particular, it fails if either val1 or val2 is NaN. +GTEST_API_ AssertionResult FloatLE(const char* expr1, const char* expr2, + float val1, float val2); +GTEST_API_ AssertionResult DoubleLE(const char* expr1, const char* expr2, + double val1, double val2); + + +#if GTEST_OS_WINDOWS + +// Macros that test for HRESULT failure and success, these are only useful +// on Windows, and rely on Windows SDK macros and APIs to compile. +// +// * {ASSERT|EXPECT}_HRESULT_{SUCCEEDED|FAILED}(expr) +// +// When expr unexpectedly fails or succeeds, Google Test prints the +// expected result and the actual result with both a human-readable +// string representation of the error, if available, as well as the +// hex result code. 
+# define EXPECT_HRESULT_SUCCEEDED(expr) \ + EXPECT_PRED_FORMAT1(::testing::internal::IsHRESULTSuccess, (expr)) + +# define ASSERT_HRESULT_SUCCEEDED(expr) \ + ASSERT_PRED_FORMAT1(::testing::internal::IsHRESULTSuccess, (expr)) + +# define EXPECT_HRESULT_FAILED(expr) \ + EXPECT_PRED_FORMAT1(::testing::internal::IsHRESULTFailure, (expr)) + +# define ASSERT_HRESULT_FAILED(expr) \ + ASSERT_PRED_FORMAT1(::testing::internal::IsHRESULTFailure, (expr)) + +#endif // GTEST_OS_WINDOWS + +// Macros that execute statement and check that it doesn't generate new fatal +// failures in the current thread. +// +// * {ASSERT|EXPECT}_NO_FATAL_FAILURE(statement); +// +// Examples: +// +// EXPECT_NO_FATAL_FAILURE(Process()); +// ASSERT_NO_FATAL_FAILURE(Process()) << "Process() failed"; +// +#define ASSERT_NO_FATAL_FAILURE(statement) \ + GTEST_TEST_NO_FATAL_FAILURE_(statement, GTEST_FATAL_FAILURE_) +#define EXPECT_NO_FATAL_FAILURE(statement) \ + GTEST_TEST_NO_FATAL_FAILURE_(statement, GTEST_NONFATAL_FAILURE_) + +// Causes a trace (including the source file path, the current line +// number, and the given message) to be included in every test failure +// message generated by code in the current scope. The effect is +// undone when the control leaves the current scope. +// +// The message argument can be anything streamable to std::ostream. +// +// In the implementation, we include the current line number as part +// of the dummy variable name, thus allowing multiple SCOPED_TRACE()s +// to appear in the same block - as long as they are on different +// lines. +#define SCOPED_TRACE(message) \ + ::testing::internal::ScopedTrace GTEST_CONCAT_TOKEN_(gtest_trace_, __LINE__)(\ + __FILE__, __LINE__, ::testing::Message() << (message)) + +// Compile-time assertion for type equality. +// StaticAssertTypeEq() compiles iff type1 and type2 are +// the same type. The value it returns is not interesting. 
+// +// Instead of making StaticAssertTypeEq a class template, we make it a +// function template that invokes a helper class template. This +// prevents a user from misusing StaticAssertTypeEq by +// defining objects of that type. +// +// CAVEAT: +// +// When used inside a method of a class template, +// StaticAssertTypeEq() is effective ONLY IF the method is +// instantiated. For example, given: +// +// template class Foo { +// public: +// void Bar() { testing::StaticAssertTypeEq(); } +// }; +// +// the code: +// +// void Test1() { Foo foo; } +// +// will NOT generate a compiler error, as Foo::Bar() is never +// actually instantiated. Instead, you need: +// +// void Test2() { Foo foo; foo.Bar(); } +// +// to cause a compiler error. +template +bool StaticAssertTypeEq() { + (void)internal::StaticAssertTypeEqHelper(); + return true; +} + +// Defines a test. +// +// The first parameter is the name of the test case, and the second +// parameter is the name of the test within the test case. +// +// The convention is to end the test case name with "Test". For +// example, a test case for the Foo class can be named FooTest. +// +// The user should put his test code between braces after using this +// macro. Example: +// +// TEST(FooTest, InitializesCorrectly) { +// Foo foo; +// EXPECT_TRUE(foo.StatusIsOK()); +// } + +// Note that we call GetTestTypeId() instead of GetTypeId< +// ::testing::Test>() here to get the type ID of testing::Test. This +// is to work around a suspected linker bug when using Google Test as +// a framework on Mac OS X. The bug causes GetTypeId< +// ::testing::Test>() to return different values depending on whether +// the call is from the Google Test framework itself or from user test +// code. GetTestTypeId() is guaranteed to always return the same +// value, as it always calls GetTypeId<>() from the Google Test +// framework. 
+#define GTEST_TEST(test_case_name, test_name)\ + GTEST_TEST_(test_case_name, test_name, \ + ::testing::Test, ::testing::internal::GetTestTypeId()) + +// Define this macro to 1 to omit the definition of TEST(), which +// is a generic name and clashes with some other libraries. +#if !GTEST_DONT_DEFINE_TEST +# define TEST(test_case_name, test_name) GTEST_TEST(test_case_name, test_name) +#endif + +// Defines a test that uses a test fixture. +// +// The first parameter is the name of the test fixture class, which +// also doubles as the test case name. The second parameter is the +// name of the test within the test case. +// +// A test fixture class must be declared earlier. The user should put +// his test code between braces after using this macro. Example: +// +// class FooTest : public testing::Test { +// protected: +// virtual void SetUp() { b_.AddElement(3); } +// +// Foo a_; +// Foo b_; +// }; +// +// TEST_F(FooTest, InitializesCorrectly) { +// EXPECT_TRUE(a_.StatusIsOK()); +// } +// +// TEST_F(FooTest, ReturnsElementCountCorrectly) { +// EXPECT_EQ(0, a_.size()); +// EXPECT_EQ(1, b_.size()); +// } + +#define TEST_F(test_fixture, test_name)\ + GTEST_TEST_(test_fixture, test_name, test_fixture, \ + ::testing::internal::GetTypeId()) + +// Use this macro in main() to run all tests. It returns 0 if all +// tests are successful, or 1 otherwise. +// +// RUN_ALL_TESTS() should be invoked after the command line has been +// parsed by InitGoogleTest(). + +#define RUN_ALL_TESTS()\ + (::testing::UnitTest::GetInstance()->Run()) + +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_GTEST_H_ diff --git a/caffe-crfrnn/src/gtest/gtest_main.cc b/caffe-crfrnn/src/gtest/gtest_main.cc new file mode 100644 index 00000000..a09bbe0c --- /dev/null +++ b/caffe-crfrnn/src/gtest/gtest_main.cc @@ -0,0 +1,39 @@ +// Copyright 2006, Google Inc. +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#include <iostream> + +#include "gtest/gtest.h" + +GTEST_API_ int main(int argc, char **argv) { + std::cout << "Running main() from gtest_main.cc\n"; + + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/caffe-crfrnn/tools/CMakeLists.txt b/caffe-crfrnn/tools/CMakeLists.txt new file mode 100644 index 00000000..02fbd5ca --- /dev/null +++ b/caffe-crfrnn/tools/CMakeLists.txt @@ -0,0 +1,29 @@ +# Collect source files +file(GLOB_RECURSE srcs ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) + +# Build each source file independently +foreach(source ${srcs}) + get_filename_component(name ${source} NAME_WE) + + # caffe target already exists + if(name MATCHES "caffe") + set(name ${name}.bin) + endif() + + # target + add_executable(${name} ${source}) + target_link_libraries(${name} ${Caffe_LINK}) + caffe_default_properties(${name}) + + # set back RUNTIME_OUTPUT_DIRECTORY + caffe_set_runtime_directory(${name} "${PROJECT_BINARY_DIR}/tools") + caffe_set_solution_folder(${name} tools) + + # restore output name without suffix + if(name MATCHES "caffe.bin") + set_target_properties(${name} PROPERTIES OUTPUT_NAME caffe) + endif() + + # Install + install(TARGETS ${name} DESTINATION bin) +endforeach(source) diff --git a/caffe-crfrnn/tools/caffe.cpp b/caffe-crfrnn/tools/caffe.cpp new file mode 100644 index 00000000..9f9d975d --- /dev/null +++ b/caffe-crfrnn/tools/caffe.cpp @@ -0,0 +1,303 @@ +#include + +#include +#include +#include +#include + +#include "caffe/caffe.hpp" + +using caffe::Blob; +using caffe::Caffe; +using caffe::Net; +using caffe::Layer; +using caffe::shared_ptr; +using caffe::Timer; +using caffe::vector; + + +DEFINE_int32(gpu, -1, + "Run in GPU mode on given device ID."); +DEFINE_string(solver, "", + "The solver definition protocol buffer text file."); +DEFINE_string(model, "", + "The model definition protocol buffer text file."); +DEFINE_string(snapshot, "", + "Optional; the snapshot solver state to resume training."); +DEFINE_string(weights, "", + "Optional; the 
pretrained weights to initialize finetuning. " + "Cannot be set simultaneously with snapshot."); +DEFINE_int32(iterations, 50, + "The number of iterations to run."); + +// A simple registry for caffe commands. +typedef int (*BrewFunction)(); +typedef std::map BrewMap; +BrewMap g_brew_map; + +#define RegisterBrewFunction(func) \ +namespace { \ +class __Registerer_##func { \ + public: /* NOLINT */ \ + __Registerer_##func() { \ + g_brew_map[#func] = &func; \ + } \ +}; \ +__Registerer_##func g_registerer_##func; \ +} + +static BrewFunction GetBrewFunction(const caffe::string& name) { + if (g_brew_map.count(name)) { + return g_brew_map[name]; + } else { + LOG(ERROR) << "Available caffe actions:"; + for (BrewMap::iterator it = g_brew_map.begin(); + it != g_brew_map.end(); ++it) { + LOG(ERROR) << "\t" << it->first; + } + LOG(FATAL) << "Unknown action: " << name; + return NULL; // not reachable, just to suppress old compiler warnings. + } +} + +// caffe commands to call by +// caffe +// +// To add a command, define a function "int command()" and register it with +// RegisterBrewFunction(action); + +// Device Query: show diagnostic information for a GPU device. +int device_query() { + CHECK_GT(FLAGS_gpu, -1) << "Need a device ID to query."; + LOG(INFO) << "Querying device ID = " << FLAGS_gpu; + caffe::Caffe::SetDevice(FLAGS_gpu); + caffe::Caffe::DeviceQuery(); + return 0; +} +RegisterBrewFunction(device_query); + + +// Train / Finetune a model. +int train() { + CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train."; + CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size()) + << "Give a snapshot to resume training or weights to finetune " + "but not both."; + + caffe::SolverParameter solver_param; + caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param); + + // If the gpu flag is not provided, allow the mode and device to be set + // in the solver prototxt. 
+ if (FLAGS_gpu < 0 + && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) { + FLAGS_gpu = solver_param.device_id(); + } + + // Set device id and mode + if (FLAGS_gpu >= 0) { + LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu; + Caffe::SetDevice(FLAGS_gpu); + Caffe::set_mode(Caffe::GPU); + } else { + LOG(INFO) << "Use CPU."; + Caffe::set_mode(Caffe::CPU); + } + + LOG(INFO) << "Starting Optimization"; + shared_ptr > + solver(caffe::GetSolver(solver_param)); + + if (FLAGS_snapshot.size()) { + LOG(INFO) << "Resuming from " << FLAGS_snapshot; + solver->Solve(FLAGS_snapshot); + } else if (FLAGS_weights.size()) { + LOG(INFO) << "Finetuning from " << FLAGS_weights; + solver->net()->CopyTrainedLayersFrom(FLAGS_weights); + solver->Solve(); + } else { + solver->Solve(); + } + LOG(INFO) << "Optimization Done."; + return 0; +} +RegisterBrewFunction(train); + + +// Test: score a model. +int test() { + CHECK_GT(FLAGS_model.size(), 0) << "Need a model definition to score."; + CHECK_GT(FLAGS_weights.size(), 0) << "Need model weights to score."; + + // Set device id and mode + if (FLAGS_gpu >= 0) { + LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu; + Caffe::SetDevice(FLAGS_gpu); + Caffe::set_mode(Caffe::GPU); + } else { + LOG(INFO) << "Use CPU."; + Caffe::set_mode(Caffe::CPU); + } + // Instantiate the caffe net. 
+ Caffe::set_phase(Caffe::TEST); + Net caffe_net(FLAGS_model); + caffe_net.CopyTrainedLayersFrom(FLAGS_weights); + LOG(INFO) << "Running for " << FLAGS_iterations << " iterations."; + + vector* > bottom_vec; + vector test_score_output_id; + vector test_score; + float loss = 0; + for (int i = 0; i < FLAGS_iterations; ++i) { + float iter_loss; + const vector*>& result = + caffe_net.Forward(bottom_vec, &iter_loss); + loss += iter_loss; + int idx = 0; + for (int j = 0; j < result.size(); ++j) { + const float* result_vec = result[j]->cpu_data(); + for (int k = 0; k < result[j]->count(); ++k, ++idx) { + const float score = result_vec[k]; + if (i == 0) { + test_score.push_back(score); + test_score_output_id.push_back(j); + } else { + test_score[idx] += score; + } + const std::string& output_name = caffe_net.blob_names()[ + caffe_net.output_blob_indices()[j]]; + LOG(INFO) << "Batch " << i << ", " << output_name << " = " << score; + } + } + } + loss /= FLAGS_iterations; + LOG(INFO) << "Loss: " << loss; + for (int i = 0; i < test_score.size(); ++i) { + const std::string& output_name = caffe_net.blob_names()[ + caffe_net.output_blob_indices()[test_score_output_id[i]]]; + const float loss_weight = + caffe_net.blob_loss_weights()[caffe_net.output_blob_indices()[i]]; + std::ostringstream loss_msg_stream; + const float mean_score = test_score[i] / FLAGS_iterations; + if (loss_weight) { + loss_msg_stream << " (* " << loss_weight + << " = " << loss_weight * mean_score << " loss)"; + } + LOG(INFO) << output_name << " = " << mean_score << loss_msg_stream.str(); + } + + return 0; +} +RegisterBrewFunction(test); + + +// Time: benchmark the execution time of a model. 
+int time() { + CHECK_GT(FLAGS_model.size(), 0) << "Need a model definition to time."; + + // Set device id and mode + if (FLAGS_gpu >= 0) { + LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu; + Caffe::SetDevice(FLAGS_gpu); + Caffe::set_mode(Caffe::GPU); + } else { + LOG(INFO) << "Use CPU."; + Caffe::set_mode(Caffe::CPU); + } + // Instantiate the caffe net. + Caffe::set_phase(Caffe::TRAIN); + Net caffe_net(FLAGS_model); + + // Do a clean forward and backward pass, so that memory allocations are done + // and future iterations will be more stable. + LOG(INFO) << "Performing Forward"; + // Note that for the speed benchmark, we will assume that the network does + // not take any input blobs. + float initial_loss; + caffe_net.Forward(vector*>(), &initial_loss); + LOG(INFO) << "Initial loss: " << initial_loss; + LOG(INFO) << "Performing Backward"; + caffe_net.Backward(); + + const vector > >& layers = caffe_net.layers(); + vector*> >& bottom_vecs = caffe_net.bottom_vecs(); + vector*> >& top_vecs = caffe_net.top_vecs(); + const vector >& bottom_need_backward = + caffe_net.bottom_need_backward(); + LOG(INFO) << "*** Benchmark begins ***"; + LOG(INFO) << "Testing for " << FLAGS_iterations << " iterations."; + Timer total_timer; + total_timer.Start(); + Timer forward_timer; + Timer backward_timer; + Timer timer; + std::vector forward_time_per_layer(layers.size(), 0.0); + std::vector backward_time_per_layer(layers.size(), 0.0); + double forward_time = 0.0; + double backward_time = 0.0; + for (int j = 0; j < FLAGS_iterations; ++j) { + Timer iter_timer; + iter_timer.Start(); + forward_timer.Start(); + for (int i = 0; i < layers.size(); ++i) { + timer.Start(); + // Although Reshape should be essentially free, we include it here + // so that we will notice Reshape performance bugs. 
+ layers[i]->Reshape(bottom_vecs[i], top_vecs[i]); + layers[i]->Forward(bottom_vecs[i], top_vecs[i]); + forward_time_per_layer[i] += timer.MicroSeconds(); + } + forward_time += forward_timer.MicroSeconds(); + backward_timer.Start(); + for (int i = layers.size() - 1; i >= 0; --i) { + timer.Start(); + layers[i]->Backward(top_vecs[i], bottom_need_backward[i], + bottom_vecs[i]); + backward_time_per_layer[i] += timer.MicroSeconds(); + } + backward_time += backward_timer.MicroSeconds(); + LOG(INFO) << "Iteration: " << j + 1 << " forward-backward time: " + << iter_timer.MilliSeconds() << " ms."; + } + LOG(INFO) << "Average time per layer: "; + for (int i = 0; i < layers.size(); ++i) { + const caffe::string& layername = layers[i]->layer_param().name(); + LOG(INFO) << std::setfill(' ') << std::setw(10) << layername << + "\tforward: " << forward_time_per_layer[i] / 1000 / + FLAGS_iterations << " ms."; + LOG(INFO) << std::setfill(' ') << std::setw(10) << layername << + "\tbackward: " << backward_time_per_layer[i] / 1000 / + FLAGS_iterations << " ms."; + } + total_timer.Stop(); + LOG(INFO) << "Average Forward pass: " << forward_time / 1000 / + FLAGS_iterations << " ms."; + LOG(INFO) << "Average Backward pass: " << backward_time / 1000 / + FLAGS_iterations << " ms."; + LOG(INFO) << "Average Forward-Backward: " << total_timer.MilliSeconds() / + FLAGS_iterations << " ms."; + LOG(INFO) << "Total Time: " << total_timer.MilliSeconds() << " ms."; + LOG(INFO) << "*** Benchmark ends ***"; + return 0; +} +RegisterBrewFunction(time); + +int main(int argc, char** argv) { + // Print output to stderr (while still logging). + FLAGS_alsologtostderr = 1; + // Usage message. + gflags::SetUsageMessage("command line brew\n" + "usage: caffe \n\n" + "commands:\n" + " train train or finetune a model\n" + " test score a model\n" + " device_query show GPU diagnostic information\n" + " time benchmark model execution time"); + // Run tool or show usage. 
+ caffe::GlobalInit(&argc, &argv); + if (argc == 2) { + return GetBrewFunction(caffe::string(argv[1]))(); + } else { + gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe"); + } +} diff --git a/caffe-crfrnn/tools/compute_image_mean.cpp b/caffe-crfrnn/tools/compute_image_mean.cpp new file mode 100644 index 00000000..358f57e3 --- /dev/null +++ b/caffe-crfrnn/tools/compute_image_mean.cpp @@ -0,0 +1,123 @@ +#include +#include +#include + +#include +#include +#include +#include + +#include "caffe/dataset_factory.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/io.hpp" + +using caffe::Dataset; +using caffe::Datum; +using caffe::BlobProto; +using std::max; +using std::pair; + + +DEFINE_string(backend, "lmdb", "The backend for containing the images"); + +int main(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + +#ifndef GFLAGS_GFLAGS_H_ + namespace gflags = google; +#endif + + gflags::SetUsageMessage("Compute the mean_image of a set of images given by" + " a leveldb/lmdb or a list of images\n" + "Usage:\n" + " compute_image_mean [FLAGS] INPUT_DB [OUTPUT_FILE]\n"); + + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (argc < 2 || argc > 3) { + gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/compute_image_mean"); + return 1; + } + + std::string db_backend = FLAGS_backend; + + caffe::shared_ptr > dataset = + caffe::DatasetFactory(db_backend); + + // Open db + CHECK(dataset->open(argv[1], Dataset::ReadOnly)); + + BlobProto sum_blob; + int count = 0; + // load first datum + Dataset::const_iterator iter = dataset->begin(); + Datum datum = iter->value; + + if (DecodeDatum(&datum)) { + LOG(INFO) << "Decoding Datum"; + } + + sum_blob.set_num(1); + sum_blob.set_channels(datum.channels()); + sum_blob.set_height(datum.height()); + sum_blob.set_width(datum.width()); + const int data_size = datum.channels() * datum.height() * datum.width(); + int size_in_datum = std::max(datum.data().size(), + datum.float_data_size()); + for (int i = 0; i 
< size_in_datum; ++i) { + sum_blob.add_data(0.); + } + LOG(INFO) << "Starting Iteration"; + for (Dataset::const_iterator iter = dataset->begin(); + iter != dataset->end(); ++iter) { + Datum datum = iter->value; + DecodeDatum(&datum); + + const std::string& data = datum.data(); + size_in_datum = std::max(datum.data().size(), + datum.float_data_size()); + CHECK_EQ(size_in_datum, data_size) << "Incorrect data field size " << + size_in_datum; + if (data.size() != 0) { + CHECK_EQ(data.size(), size_in_datum); + for (int i = 0; i < size_in_datum; ++i) { + sum_blob.set_data(i, sum_blob.data(i) + (uint8_t)data[i]); + } + } else { + CHECK_EQ(datum.float_data_size(), size_in_datum); + for (int i = 0; i < size_in_datum; ++i) { + sum_blob.set_data(i, sum_blob.data(i) + + static_cast(datum.float_data(i))); + } + } + ++count; + if (count % 10000 == 0) { + LOG(INFO) << "Processed " << count << " files."; + } + } + + if (count % 10000 != 0) { + LOG(INFO) << "Processed " << count << " files."; + } + for (int i = 0; i < sum_blob.data_size(); ++i) { + sum_blob.set_data(i, sum_blob.data(i) / count); + } + // Write to disk + if (argc == 3) { + LOG(INFO) << "Write to " << argv[2]; + WriteProtoToBinaryFile(sum_blob, argv[2]); + } + const int channels = sum_blob.channels(); + const int dim = sum_blob.height() * sum_blob.width(); + std::vector mean_values(channels, 0.0); + LOG(INFO) << "Number of channels: " << channels; + for (int c = 0; c < channels; ++c) { + for (int i = 0; i < dim; ++i) { + mean_values[c] += sum_blob.data(dim * c + i); + } + LOG(INFO) << "mean_value channel [" << c << "]:" << mean_values[c] / dim; + } + // Clean up + dataset->close(); + return 0; +} diff --git a/caffe-crfrnn/tools/convert_imageset.cpp b/caffe-crfrnn/tools/convert_imageset.cpp new file mode 100644 index 00000000..c554ed38 --- /dev/null +++ b/caffe-crfrnn/tools/convert_imageset.cpp @@ -0,0 +1,146 @@ +// This program converts a set of images to a lmdb/leveldb by storing them +// as Datum proto buffers. 
+// Usage: +// convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME +// +// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE +// should be a list of files as well as their labels, in the format +// subfolder1/file1.JPEG 7 +// .... + +#include +#include + +#include +#include // NOLINT(readability/streams) +#include +#include +#include + +#include "caffe/dataset_factory.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/io.hpp" +#include "caffe/util/rng.hpp" + +using namespace caffe; // NOLINT(build/namespaces) +using std::pair; + +DEFINE_bool(gray, false, + "When this option is on, treat images as grayscale ones"); +DEFINE_bool(shuffle, false, + "Randomly shuffle the order of images and their labels"); +DEFINE_string(backend, "lmdb", "The backend for storing the result"); +DEFINE_int32(resize_width, 0, "Width images are resized to"); +DEFINE_int32(resize_height, 0, "Height images are resized to"); +DEFINE_bool(check_size, false, + "When this option is on, check that every datum has the same size"); +DEFINE_bool(encoded, false, + "When this option is on, the encoded image will be saved in the datum"); + +int main(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + +#ifndef GFLAGS_GFLAGS_H_ + namespace gflags = google; +#endif + + gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n" + "format used as input for Caffe.\n" + "Usage:\n" + " convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n" + "The ImageNet dataset for the training demo is at\n" + " http://www.image-net.org/download-images\n"); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (argc != 4) { + gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset"); + return 1; + } + + const bool is_color = !FLAGS_gray; + const bool check_size = FLAGS_check_size; + const bool encoded = FLAGS_encoded; + + std::ifstream infile(argv[2]); + std::vector > lines; + std::string filename; + int label; + while (infile >> filename 
>> label) { + lines.push_back(std::make_pair(filename, label)); + } + if (FLAGS_shuffle) { + // randomly shuffle data + LOG(INFO) << "Shuffling data"; + shuffle(lines.begin(), lines.end()); + } + LOG(INFO) << "A total of " << lines.size() << " images."; + + const std::string& db_backend = FLAGS_backend; + const char* db_path = argv[3]; + + if (encoded) { + CHECK_EQ(FLAGS_resize_height, 0) << "With encoded don't resize images"; + CHECK_EQ(FLAGS_resize_width, 0) << "With encoded don't resize images"; + CHECK(!check_size) << "With encoded cannot check_size"; + } + + int resize_height = std::max(0, FLAGS_resize_height); + int resize_width = std::max(0, FLAGS_resize_width); + + // Open new db + shared_ptr > dataset = + DatasetFactory(db_backend); + + // Open db + CHECK(dataset->open(db_path, Dataset::New)); + + // Storing to db + std::string root_folder(argv[1]); + Datum datum; + int count = 0; + const int kMaxKeyLength = 256; + char key_cstr[kMaxKeyLength]; + int data_size; + bool data_size_initialized = false; + + for (int line_id = 0; line_id < lines.size(); ++line_id) { + bool status; + if (encoded) { + status = ReadFileToDatum(root_folder + lines[line_id].first, + lines[line_id].second, &datum); + } else { + status = ReadImageToDatum(root_folder + lines[line_id].first, + lines[line_id].second, resize_height, resize_width, is_color, &datum); + } + if (status == false) continue; + if (check_size) { + if (!data_size_initialized) { + data_size = datum.channels() * datum.height() * datum.width(); + data_size_initialized = true; + } else { + const std::string& data = datum.data(); + CHECK_EQ(data.size(), data_size) << "Incorrect data field size " + << data.size(); + } + } + // sequential + int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id, + lines[line_id].first.c_str()); + + // Put in db + CHECK(dataset->put(string(key_cstr, length), datum)); + + if (++count % 1000 == 0) { + // Commit txn + CHECK(dataset->commit()); + LOG(ERROR) << "Processed " << count 
<< " files."; + } + } + // write the last batch + if (count % 1000 != 0) { + CHECK(dataset->commit()); + LOG(ERROR) << "Processed " << count << " files."; + } + dataset->close(); + return 0; +} diff --git a/caffe-crfrnn/tools/device_query.cpp b/caffe-crfrnn/tools/device_query.cpp new file mode 100644 index 00000000..03799e52 --- /dev/null +++ b/caffe-crfrnn/tools/device_query.cpp @@ -0,0 +1,7 @@ +#include "caffe/common.hpp" + +int main(int argc, char** argv) { + LOG(FATAL) << "Deprecated. Use caffe device_query " + "[--device_id=0] instead."; + return 0; +} diff --git a/caffe-crfrnn/tools/dump_network.cpp b/caffe-crfrnn/tools/dump_network.cpp new file mode 100644 index 00000000..9cb996ef --- /dev/null +++ b/caffe-crfrnn/tools/dump_network.cpp @@ -0,0 +1,82 @@ +// This program takes in a trained network and an input blob, and then dumps +// all the intermediate blobs produced by the net to individual binary +// files stored in protobuffer binary formats. +// Usage: +// dump_network input_net_param trained_net_param +// input_blob output_prefix 0/1 +// if input_net_param is 'none', we will directly load the network from +// trained_net_param. If the last argv is 1, we will do a forward-backward pass +// before dumping everything, and also dump the whole network. 
+ +#include +#include + +#include "fcntl.h" +#include "google/protobuf/text_format.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/net.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" +#include "caffe/util/io.hpp" + +using boost::shared_ptr; +using caffe::Blob; +using caffe::BlobProto; +using caffe::Caffe; +using caffe::Net; +using caffe::NetParameter; + +int main(int argc, char** argv) { + Caffe::set_mode(Caffe::GPU); + Caffe::set_phase(Caffe::TEST); + + shared_ptr > caffe_net; + if (strcmp(argv[1], "none") == 0) { + // We directly load the net param from trained file + caffe_net.reset(new Net(argv[2])); + } else { + caffe_net.reset(new Net(argv[1])); + } + caffe_net->CopyTrainedLayersFrom(argv[2]); + + std::vector* > input_vec; + shared_ptr > input_blob(new Blob()); + if (strcmp(argv[3], "none") != 0) { + BlobProto input_blob_proto; + ReadProtoFromBinaryFile(argv[3], &input_blob_proto); + input_blob->FromProto(input_blob_proto); + input_vec.push_back(input_blob.get()); + } + + std::string output_prefix(argv[4]); + // Run the network without training. 
+ LOG(ERROR) << "Performing Forward"; + caffe_net->Forward(input_vec); + if (argc > 5 && strcmp(argv[5], "1") == 0) { + LOG(ERROR) << "Performing Backward"; + Caffe::set_phase(Caffe::TRAIN); + caffe_net->Backward(); + // Dump the network + NetParameter output_net_param; + caffe_net->ToProto(&output_net_param, true); + WriteProtoToBinaryFile(output_net_param, + output_prefix + output_net_param.name()); + } + // Now, let's dump all the layers + + const std::vector& blob_names = caffe_net->blob_names(); + const std::vector > >& blobs = caffe_net->blobs(); + for (int blobid = 0; blobid < caffe_net->blobs().size(); ++blobid) { + // Serialize blob + LOG(ERROR) << "Dumping " << blob_names[blobid]; + BlobProto output_blob_proto; + blobs[blobid]->ToProto(&output_blob_proto); + WriteProtoToBinaryFile(output_blob_proto, + output_prefix + blob_names[blobid]); + } + + return 0; +} diff --git a/caffe-crfrnn/tools/extra/extract_seconds.py b/caffe-crfrnn/tools/extra/extract_seconds.py new file mode 100755 index 00000000..591a51f9 --- /dev/null +++ b/caffe-crfrnn/tools/extra/extract_seconds.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python +import datetime +import os +import sys + +def extract_datetime_from_line(line, year): + # Expected format: I0210 13:39:22.381027 25210 solver.cpp:204] Iteration 100, lr = 0.00992565 + line = line.strip().split() + month = int(line[0][1:3]) + day = int(line[0][3:]) + timestamp = line[1] + pos = timestamp.rfind('.') + ts = [int(x) for x in timestamp[:pos].split(':')] + hour = ts[0] + minute = ts[1] + second = ts[2] + microsecond = int(timestamp[pos + 1:]) + dt = datetime.datetime(year, month, day, hour, minute, second, microsecond) + return dt + + +def get_log_created_year(input_file): + """Get year from log file system timestamp + """ + + log_created_time = os.path.getctime(input_file) + log_created_year = datetime.datetime.fromtimestamp(log_created_time).year + return log_created_year + + +def get_start_time(line_iterable, year): + """Find start time 
from group of lines + """ + + start_datetime = None + for line in line_iterable: + line = line.strip() + if line.find('Solving') != -1: + start_datetime = extract_datetime_from_line(line, year) + break + return start_datetime + + +def extract_seconds(input_file, output_file): + with open(input_file, 'r') as f: + lines = f.readlines() + log_created_year = get_log_created_year(input_file) + start_datetime = get_start_time(lines, log_created_year) + assert start_datetime, 'Start time not found' + + out = open(output_file, 'w') + for line in lines: + line = line.strip() + if line.find('Iteration') != -1: + dt = extract_datetime_from_line(line, log_created_year) + elapsed_seconds = (dt - start_datetime).total_seconds() + out.write('%f\n' % elapsed_seconds) + out.close() + +if __name__ == '__main__': + if len(sys.argv) < 3: + print('Usage: ./extract_seconds input_file output_file') + exit(1) + extract_seconds(sys.argv[1], sys.argv[2]) diff --git a/caffe-crfrnn/tools/extra/launch_resize_and_crop_images.sh b/caffe-crfrnn/tools/extra/launch_resize_and_crop_images.sh new file mode 100755 index 00000000..84ca858c --- /dev/null +++ b/caffe-crfrnn/tools/extra/launch_resize_and_crop_images.sh @@ -0,0 +1,24 @@ +#!/bin/bash +#### https://github.com/Yangqing/mincepie/wiki/Launch-Your-Mapreducer + +# If you encounter error that the address already in use, kill the process. +# 11235 is the port of server process +# https://github.com/Yangqing/mincepie/blob/master/mincepie/mince.py +# sudo netstat -ap | grep 11235 +# The last column of the output is PID/Program name +# kill -9 PID +# Second solution: +# nmap localhost +# fuser -k 11235/tcp +# Or just wait a few seconds. + +## Launch your Mapreduce locally +# num_clients: number of processes +# image_lib: OpenCV or PIL, case insensitive. The default value is the faster OpenCV. 
+# input: the file containing one image path relative to input_folder each line +# input_folder: where are the original images +# output_folder: where to save the resized and cropped images +./resize_and_crop_images.py --num_clients=8 --image_lib=opencv --input=/home/user/Datasets/ImageNet/ILSVRC2010/ILSVRC2010_images.txt --input_folder=/home/user/Datasets/ImageNet/ILSVRC2010/ILSVRC2010_images_train/ --output_folder=/home/user/Datasets/ImageNet/ILSVRC2010/ILSVRC2010_images_train_resized/ + +## Launch your Mapreduce with MPI +# mpirun -n 8 --launch=mpi resize_and_crop_images.py --image_lib=opencv --input=/home/user/Datasets/ImageNet/ILSVRC2010/ILSVRC2010_images.txt --input_folder=/home/user/Datasets/ImageNet/ILSVRC2010/ILSVRC2010_images_train/ --output_folder=/home/user/Datasets/ImageNet/ILSVRC2010/ILSVRC2010_images_train_resized/ diff --git a/caffe-crfrnn/tools/extra/parse_log.py b/caffe-crfrnn/tools/extra/parse_log.py new file mode 100755 index 00000000..16ba077a --- /dev/null +++ b/caffe-crfrnn/tools/extra/parse_log.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python + +""" +Parse training log + +Competitor to parse_log.sh +""" + +import os +import re +import extract_seconds +import argparse +import csv + + +def get_line_type(line): + """Return either 'test' or 'train' depending on line type + """ + + line_type = None + if line.find('Train') != -1: + line_type = 'train' + elif line.find('Test') != -1: + line_type = 'test' + return line_type + + +def parse_log(path_to_log): + """Parse log file + Returns (train_dict_list, train_dict_names, test_dict_list, test_dict_names) + + train_dict_list and test_dict_list are lists of dicts that define the table + rows + + train_dict_names and test_dict_names are ordered tuples of the column names + for the two dict_lists + """ + + re_iteration = re.compile('Iteration (\d+)') + re_accuracy = re.compile('output #\d+: accuracy = ([\.\d]+)') + re_train_loss = re.compile('Iteration \d+, loss = ([\.\d]+)') + re_output_loss = 
re.compile('output #\d+: loss = ([\.\d]+)') + re_lr = re.compile('lr = ([\.\d]+)') + + # Pick out lines of interest + iteration = -1 + test_accuracy = -1 + learning_rate = float('NaN') + train_dict_list = [] + test_dict_list = [] + train_dict_names = ('NumIters', 'Seconds', 'TrainingLoss', 'LearningRate') + test_dict_names = ('NumIters', 'Seconds', 'TestAccuracy', 'TestLoss') + + logfile_year = extract_seconds.get_log_created_year(path_to_log) + with open(path_to_log) as f: + start_time = extract_seconds.get_start_time(f, logfile_year) + + for line in f: + iteration_match = re_iteration.search(line) + if iteration_match: + iteration = float(iteration_match.group(1)) + if iteration == -1: + # Only look for other stuff if we've found the first iteration + continue + + time = extract_seconds.extract_datetime_from_line(line, + logfile_year) + seconds = (time - start_time).total_seconds() + + lr_match = re_lr.search(line) + if lr_match: + learning_rate = float(lr_match.group(1)) + + accuracy_match = re_accuracy.search(line) + if accuracy_match and get_line_type(line) == 'test': + test_accuracy = float(accuracy_match.group(1)) + + train_loss_match = re_train_loss.search(line) + if train_loss_match: + train_loss = float(train_loss_match.group(1)) + train_dict_list.append({'NumIters': iteration, + 'Seconds': seconds, + 'TrainingLoss': train_loss, + 'LearningRate': learning_rate}) + + output_loss_match = re_output_loss.search(line) + if output_loss_match and get_line_type(line) == 'test': + test_loss = float(output_loss_match.group(1)) + # NOTE: we assume that (1) accuracy always comes right before + # loss for test data so the test_accuracy variable is already + # correctly populated and (2) there's one and only one output + # named "accuracy" for the test net + test_dict_list.append({'NumIters': iteration, + 'Seconds': seconds, + 'TestAccuracy': test_accuracy, + 'TestLoss': test_loss}) + + return train_dict_list, train_dict_names, test_dict_list, test_dict_names + + +def 
save_csv_files(logfile_path, output_dir, train_dict_list, train_dict_names, + test_dict_list, test_dict_names, verbose=False): + """Save CSV files to output_dir + + If the input log file is, e.g., caffe.INFO, the names will be + caffe.INFO.train and caffe.INFO.test + """ + + log_basename = os.path.basename(logfile_path) + train_filename = os.path.join(output_dir, log_basename + '.train') + write_csv(train_filename, train_dict_list, train_dict_names, verbose) + + test_filename = os.path.join(output_dir, log_basename + '.test') + write_csv(test_filename, test_dict_list, test_dict_names, verbose) + + +def write_csv(output_filename, dict_list, header_names, verbose=False): + """Write a CSV file + """ + + with open(output_filename, 'w') as f: + dict_writer = csv.DictWriter(f, header_names) + dict_writer.writeheader() + dict_writer.writerows(dict_list) + if verbose: + print 'Wrote %s' % output_filename + + +def parse_args(): + description = ('Parse a Caffe training log into two CSV files ' + 'containing training and testing information') + parser = argparse.ArgumentParser(description=description) + + parser.add_argument('logfile_path', + help='Path to log file') + + parser.add_argument('output_dir', + help='Directory in which to place output CSV files') + + parser.add_argument('--verbose', + action='store_true', + help='Print some extra info (e.g., output filenames)') + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + train_dict_list, train_dict_names, test_dict_list, test_dict_names = \ + parse_log(args.logfile_path) + save_csv_files(args.logfile_path, args.output_dir, train_dict_list, + train_dict_names, test_dict_list, test_dict_names, args.verbose) + + +if __name__ == '__main__': + main() diff --git a/caffe-crfrnn/tools/extra/parse_log.sh b/caffe-crfrnn/tools/extra/parse_log.sh new file mode 100755 index 00000000..98ef0a05 --- /dev/null +++ b/caffe-crfrnn/tools/extra/parse_log.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Usage parse_log.sh caffe.log +# It 
creates the following two text files, each containing a table: +# caffe.log.test (columns: '#Iters Seconds TestAccuracy TestLoss') +# caffe.log.train (columns: '#Iters Seconds TrainingLoss LearningRate') + + +# get the dirname of the script +DIR="$( cd "$(dirname "$0")" ; pwd -P )" + +if [ "$#" -lt 1 ] +then +echo "Usage parse_log.sh /path/to/your.log" +exit +fi +LOG=`basename $1` +grep -B 1 'Test ' $1 > aux.txt +grep 'Iteration ' aux.txt | sed 's/.*Iteration \([[:digit:]]*\).*/\1/g' > aux0.txt +grep 'Test net output #0' aux.txt | awk '{print $11}' > aux1.txt +grep 'Test net output #1' aux.txt | awk '{print $11}' > aux2.txt + +# Extracting elapsed seconds +# For extraction of time since this line contains the start time +grep '] Solving ' $1 > aux3.txt +grep 'Testing net' $1 >> aux3.txt +$DIR/extract_seconds.py aux3.txt aux4.txt + +# Generating +echo '#Iters Seconds TestAccuracy TestLoss'> $LOG.test +paste aux0.txt aux4.txt aux1.txt aux2.txt | column -t >> $LOG.test +rm aux.txt aux0.txt aux1.txt aux2.txt aux3.txt aux4.txt + +# For extraction of time since this line contains the start time +grep '] Solving ' $1 > aux.txt +grep ', loss = ' $1 >> aux.txt +grep 'Iteration ' aux.txt | sed 's/.*Iteration \([[:digit:]]*\).*/\1/g' > aux0.txt +grep ', loss = ' $1 | awk '{print $9}' > aux1.txt +grep ', lr = ' $1 | awk '{print $9}' > aux2.txt + +# Extracting elapsed seconds +$DIR/extract_seconds.py aux.txt aux3.txt + +# Generating +echo '#Iters Seconds TrainingLoss LearningRate'> $LOG.train +paste aux0.txt aux3.txt aux1.txt aux2.txt | column -t >> $LOG.train +rm aux.txt aux0.txt aux1.txt aux2.txt aux3.txt diff --git a/caffe-crfrnn/tools/extra/plot_log.gnuplot.example b/caffe-crfrnn/tools/extra/plot_log.gnuplot.example new file mode 100644 index 00000000..334ff1f2 --- /dev/null +++ b/caffe-crfrnn/tools/extra/plot_log.gnuplot.example @@ -0,0 +1,69 @@ +# These snippets serve only as basic examples. +# Customization is a must. 
+# You can copy, paste, edit them in whatever way you want. +# Be warned that the fields in the training log may change in the future. +# You had better check the data files before designing your own plots. + +# Please generate the necessary data files with +# /path/to/caffe/tools/extra/parse_log.sh before plotting. +# Example usage: +# ./parse_log.sh mnist.log +# Now you have mnist.log.train and mnist.log.test. +# gnuplot mnist.gnuplot + +# The fields present in the data files that are usually suitable to plot along +# the y axis are test accuracy, test loss, training loss, and learning rate. +# Those that should be plotted along the x axis are training iterations and seconds. +# Possible combinations: +# 1. Test accuracy (test score 0) vs. training iterations / time; +# 2. Test loss (test score 1) vs. training iterations / time; +# 3. Training loss vs. training iterations / time; +# 4. Learning rate vs. training iterations / time; +# A rarer one: Training time vs. iterations. + +# What is the difference between plotting against iterations and time? +# If the overhead in one iteration is too high, one algorithm might appear +# to be faster in terms of progress per iteration and slower when measured +# against time. And the reverse case is not entirely impossible. Thus, some +# papers chose to only publish the more favorable type. You are free +# to decide what to plot. + +reset +set terminal png +set output "your_chart_name.png" +set style data lines +set key right + +###### Fields in the data file your_log_name.log.train are +###### Iters Seconds TrainingLoss LearningRate + +# Training loss vs. training iterations +set title "Training loss vs. training iterations" +set xlabel "Training iterations" +set ylabel "Training loss" +plot "mnist.log.train" using 1:3 title "mnist" + +# Training loss vs. training time +# plot "mnist.log.train" using 2:3 title "mnist" + +# Learning rate vs. training iterations; +# plot "mnist.log.train" using 1:4 title "mnist" + +# Learning rate vs. 
training time; +# plot "mnist.log.train" using 2:4 title "mnist" + + +###### Fields in the data file your_log_name.log.test are +###### Iters Seconds TestAccuracy TestLoss + +# Test loss vs. training iterations +# plot "mnist.log.test" using 1:4 title "mnist" + +# Test accuracy vs. training iterations +# plot "mnist.log.test" using 1:3 title "mnist" + +# Test loss vs. training time +# plot "mnist.log.test" using 2:4 title "mnist" + +# Test accuracy vs. training time +# plot "mnist.log.test" using 2:3 title "mnist" diff --git a/caffe-crfrnn/tools/extra/plot_training_log.py.example b/caffe-crfrnn/tools/extra/plot_training_log.py.example new file mode 100755 index 00000000..b6fda54e --- /dev/null +++ b/caffe-crfrnn/tools/extra/plot_training_log.py.example @@ -0,0 +1,187 @@ +#!/usr/bin/env python +import inspect +import os +import random +import sys +import matplotlib.cm as cmx +import matplotlib.colors as colors +import matplotlib.pyplot as plt +import matplotlib.legend as lgd +import matplotlib.markers as mks + +def get_log_parsing_script(): + dirname = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) + return dirname + '/parse_log.sh' + +def get_log_file_suffix(): + return '.log' + +def get_chart_type_description_separator(): + return ' vs. 
' + +def is_x_axis_field(field): + x_axis_fields = ['Iters', 'Seconds'] + return field in x_axis_fields + +def create_field_index(): + train_key = 'Train' + test_key = 'Test' + field_index = {train_key:{'Iters':0, 'Seconds':1, train_key + ' loss':2, + train_key + ' learning rate':3}, + test_key:{'Iters':0, 'Seconds':1, test_key + ' accuracy':2, + test_key + ' loss':3}} + fields = set() + for data_file_type in field_index.keys(): + fields = fields.union(set(field_index[data_file_type].keys())) + fields = list(fields) + fields.sort() + return field_index, fields + +def get_supported_chart_types(): + field_index, fields = create_field_index() + num_fields = len(fields) + supported_chart_types = [] + for i in xrange(num_fields): + if not is_x_axis_field(fields[i]): + for j in xrange(num_fields): + if i != j and is_x_axis_field(fields[j]): + supported_chart_types.append('%s%s%s' % ( + fields[i], get_chart_type_description_separator(), + fields[j])) + return supported_chart_types + +def get_chart_type_description(chart_type): + supported_chart_types = get_supported_chart_types() + chart_type_description = supported_chart_types[chart_type] + return chart_type_description + +def get_data_file_type(chart_type): + description = get_chart_type_description(chart_type) + data_file_type = description.split()[0] + return data_file_type + +def get_data_file(chart_type, path_to_log): + return os.path.basename(path_to_log) + '.' 
+ get_data_file_type(chart_type).lower() + +def get_field_descriptions(chart_type): + description = get_chart_type_description(chart_type).split( + get_chart_type_description_separator()) + y_axis_field = description[0] + x_axis_field = description[1] + return x_axis_field, y_axis_field + +def get_field_indecies(x_axis_field, y_axis_field): + data_file_type = get_data_file_type(chart_type) + fields = create_field_index()[0][data_file_type] + return fields[x_axis_field], fields[y_axis_field] + +def load_data(data_file, field_idx0, field_idx1): + data = [[], []] + with open(data_file, 'r') as f: + for line in f: + line = line.strip() + if line[0] != '#': + fields = line.split() + data[0].append(float(fields[field_idx0].strip())) + data[1].append(float(fields[field_idx1].strip())) + return data + +def random_marker(): + markers = mks.MarkerStyle.markers + num = len(markers.values()) + idx = random.randint(0, num - 1) + return markers.values()[idx] + +def get_data_label(path_to_log): + label = path_to_log[path_to_log.rfind('/')+1 : path_to_log.rfind( + get_log_file_suffix())] + return label + +def get_legend_loc(chart_type): + x_axis, y_axis = get_field_descriptions(chart_type) + loc = 'lower right' + if y_axis.find('accuracy') != -1: + pass + if y_axis.find('loss') != -1 or y_axis.find('learning rate') != -1: + loc = 'upper right' + return loc + +def plot_chart(chart_type, path_to_png, path_to_log_list): + for path_to_log in path_to_log_list: + os.system('%s %s' % (get_log_parsing_script(), path_to_log)) + data_file = get_data_file(chart_type, path_to_log) + x_axis_field, y_axis_field = get_field_descriptions(chart_type) + x, y = get_field_indecies(x_axis_field, y_axis_field) + data = load_data(data_file, x, y) + ## TODO: more systematic color cycle for lines + color = [random.random(), random.random(), random.random()] + label = get_data_label(path_to_log) + linewidth = 0.75 + ## If there too many datapoints, do not use marker. 
+## use_marker = False + use_marker = True + if not use_marker: + plt.plot(data[0], data[1], label = label, color = color, + linewidth = linewidth) + else: + ok = False + ## Some markers throw ValueError: Unrecognized marker style + while not ok: + try: + marker = random_marker() + plt.plot(data[0], data[1], label = label, color = color, + marker = marker, linewidth = linewidth) + ok = True + except Exception: + pass + legend_loc = get_legend_loc(chart_type) + plt.legend(loc = legend_loc, ncol = 1) # adjust ncol to fit the space + plt.title(get_chart_type_description(chart_type)) + plt.xlabel(x_axis_field) + plt.ylabel(y_axis_field) + plt.savefig(path_to_png) + plt.show() + +def print_help(): + print """This script mainly serves as the basis of your customizations. +Customization is a must. +You can copy, paste, edit it in whatever way you want. +Be warned that the fields in the training log may change in the future. +You had better check the data files and change the mapping from field name to + field index in create_field_index before designing your own plots. +Usage: + ./plot_log.sh chart_type[0-%s] /where/to/save.png /path/to/first.log ... +Notes: + 1. Multiple logs are supported. + 2. Log file name must end with the lower-cased "%s". 
+Supported chart types:""" % (len(get_supported_chart_types()) - 1, + get_log_file_suffix()) + supported_chart_types = get_supported_chart_types() + num = len(supported_chart_types) + for i in xrange(num): + print ' %d: %s' % (i, supported_chart_types[i]) + sys.exit(1) + +def is_valid_chart_type(chart_type): + return chart_type >= 0 and chart_type < len(get_supported_chart_types()) + +if __name__ == '__main__': + if len(sys.argv) < 4: + print_help() + else: + chart_type = int(sys.argv[1]) + if not is_valid_chart_type(chart_type): + print_help() + path_to_png = sys.argv[2] + if not path_to_png.endswith('.png'): + print 'Path must end with .png: %s' % path_to_png + sys.exit(1) + path_to_logs = sys.argv[3:] + for path_to_log in path_to_logs: + if not os.path.exists(path_to_log): + print 'Path does not exist: %s' % path_to_log + sys.exit(1) + if not path_to_log.endswith(get_log_file_suffix()): + print_help() + ## plot_chart accepts multiple path_to_logs + plot_chart(chart_type, path_to_png, path_to_logs) diff --git a/caffe-crfrnn/tools/extra/resize_and_crop_images.py b/caffe-crfrnn/tools/extra/resize_and_crop_images.py new file mode 100755 index 00000000..c844f590 --- /dev/null +++ b/caffe-crfrnn/tools/extra/resize_and_crop_images.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python +from mincepie import mapreducer, launcher +import gflags +import os +import cv2 +from PIL import Image + +# gflags +gflags.DEFINE_string('image_lib', 'opencv', + 'OpenCV or PIL, case insensitive. 
The default value is the faster OpenCV.') +gflags.DEFINE_string('input_folder', '', + 'The folder that contains all input images, organized in synsets.') +gflags.DEFINE_integer('output_side_length', 256, + 'Expected side length of the output image.') +gflags.DEFINE_string('output_folder', '', + 'The folder that we write output resized and cropped images to') +FLAGS = gflags.FLAGS + +class OpenCVResizeCrop: + def resize_and_crop_image(self, input_file, output_file, output_side_length = 256): + '''Takes an image name, resize it and crop the center square + ''' + img = cv2.imread(input_file) + height, width, depth = img.shape + new_height = output_side_length + new_width = output_side_length + if height > width: + new_height = output_side_length * height / width + else: + new_width = output_side_length * width / height + resized_img = cv2.resize(img, (new_width, new_height)) + height_offset = (new_height - output_side_length) / 2 + width_offset = (new_width - output_side_length) / 2 + cropped_img = resized_img[height_offset:height_offset + output_side_length, + width_offset:width_offset + output_side_length] + cv2.imwrite(output_file, cropped_img) + +class PILResizeCrop: +## http://united-coders.com/christian-harms/image-resizing-tips-every-coder-should-know/ + def resize_and_crop_image(self, input_file, output_file, output_side_length = 256, fit = True): + '''Downsample the image. 
+ ''' + img = Image.open(input_file) + box = (output_side_length, output_side_length) + #preresize image with factor 2, 4, 8 and fast algorithm + factor = 1 + while img.size[0]/factor > 2*box[0] and img.size[1]*2/factor > 2*box[1]: + factor *=2 + if factor > 1: + img.thumbnail((img.size[0]/factor, img.size[1]/factor), Image.NEAREST) + + #calculate the cropping box and get the cropped part + if fit: + x1 = y1 = 0 + x2, y2 = img.size + wRatio = 1.0 * x2/box[0] + hRatio = 1.0 * y2/box[1] + if hRatio > wRatio: + y1 = int(y2/2-box[1]*wRatio/2) + y2 = int(y2/2+box[1]*wRatio/2) + else: + x1 = int(x2/2-box[0]*hRatio/2) + x2 = int(x2/2+box[0]*hRatio/2) + img = img.crop((x1,y1,x2,y2)) + + #Resize the image with best quality algorithm ANTI-ALIAS + img.thumbnail(box, Image.ANTIALIAS) + + #save it into a file-like object + with open(output_file, 'wb') as out: + img.save(out, 'JPEG', quality=75) + +class ResizeCropImagesMapper(mapreducer.BasicMapper): + '''The ImageNet Compute mapper. + The input value would be the file listing images' paths relative to input_folder. + ''' + def map(self, key, value): + if type(value) is not str: + value = str(value) + files = [value] + image_lib = FLAGS.image_lib.lower() + if image_lib == 'pil': + resize_crop = PILResizeCrop() + else: + resize_crop = OpenCVResizeCrop() + for i, line in enumerate(files): + try: + line = line.replace(FLAGS.input_folder, '').strip() + line = line.split() + image_file_name = line[0] + input_file = os.path.join(FLAGS.input_folder, image_file_name) + output_file = os.path.join(FLAGS.output_folder, image_file_name) + output_dir = output_file[:output_file.rfind('/')] + if not os.path.exists(output_dir): + os.makedirs(output_dir) + feat = resize_crop.resize_and_crop_image(input_file, output_file, + FLAGS.output_side_length) + except Exception, e: + # we ignore the exception (maybe the image is corrupted?) 
+ print line, Exception, e + yield value, FLAGS.output_folder + +mapreducer.REGISTER_DEFAULT_MAPPER(ResizeCropImagesMapper) + +mapreducer.REGISTER_DEFAULT_READER(mapreducer.FileReader) +mapreducer.REGISTER_DEFAULT_WRITER(mapreducer.FileWriter) + +if __name__ == '__main__': + launcher.launch() diff --git a/caffe-crfrnn/tools/extract_features.cpp b/caffe-crfrnn/tools/extract_features.cpp new file mode 100644 index 00000000..ddbce107 --- /dev/null +++ b/caffe-crfrnn/tools/extract_features.cpp @@ -0,0 +1,184 @@ +#include // for snprintf +#include +#include + +#include "boost/algorithm/string.hpp" +#include "google/protobuf/text_format.h" + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/dataset_factory.hpp" +#include "caffe/net.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/io.hpp" +#include "caffe/vision_layers.hpp" + +using boost::shared_ptr; +using caffe::Blob; +using caffe::Caffe; +using caffe::Dataset; +using caffe::DatasetFactory; +using caffe::Datum; +using caffe::Net; + +template +int feature_extraction_pipeline(int argc, char** argv); + +int main(int argc, char** argv) { + return feature_extraction_pipeline(argc, argv); +// return feature_extraction_pipeline(argc, argv); +} + +template +int feature_extraction_pipeline(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + const int num_required_args = 7; + if (argc < num_required_args) { + LOG(ERROR)<< + "This program takes in a trained network and an input data layer, and then" + " extract features of the input data produced by the net.\n" + "Usage: extract_features pretrained_net_param" + " feature_extraction_proto_file extract_feature_blob_name1[,name2,...]" + " save_feature_dataset_name1[,name2,...] num_mini_batches db_type" + " [CPU/GPU] [DEVICE_ID=0]\n" + "Note: you can extract multiple features in one pass by specifying" + " multiple feature blob names and dataset names seperated by ','." 
+ " The names cannot contain white space characters and the number of blobs" + " and datasets must be equal."; + return 1; + } + int arg_pos = num_required_args; + + arg_pos = num_required_args; + if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) { + LOG(ERROR)<< "Using GPU"; + uint device_id = 0; + if (argc > arg_pos + 1) { + device_id = atoi(argv[arg_pos + 1]); + CHECK_GE(device_id, 0); + } + LOG(ERROR) << "Using Device_id=" << device_id; + Caffe::SetDevice(device_id); + Caffe::set_mode(Caffe::GPU); + } else { + LOG(ERROR) << "Using CPU"; + Caffe::set_mode(Caffe::CPU); + } + Caffe::set_phase(Caffe::TEST); + + arg_pos = 0; // the name of the executable + std::string pretrained_binary_proto(argv[++arg_pos]); + + // Expected prototxt contains at least one data layer such as + // the layer data_layer_name and one feature blob such as the + // fc7 top blob to extract features. + /* + layers { + name: "data_layer_name" + type: DATA + data_param { + source: "/path/to/your/images/to/extract/feature/images_leveldb" + mean_file: "/path/to/your/image_mean.binaryproto" + batch_size: 128 + crop_size: 227 + mirror: false + } + top: "data_blob_name" + top: "label_blob_name" + } + layers { + name: "drop7" + type: DROPOUT + dropout_param { + dropout_ratio: 0.5 + } + bottom: "fc7" + top: "fc7" + } + */ + std::string feature_extraction_proto(argv[++arg_pos]); + shared_ptr > feature_extraction_net( + new Net(feature_extraction_proto)); + feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto); + + std::string extract_feature_blob_names(argv[++arg_pos]); + std::vector blob_names; + boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(",")); + + std::string save_feature_dataset_names(argv[++arg_pos]); + std::vector dataset_names; + boost::split(dataset_names, save_feature_dataset_names, + boost::is_any_of(",")); + CHECK_EQ(blob_names.size(), dataset_names.size()) << + " the number of blob names and dataset names must be equal"; + size_t 
num_features = blob_names.size(); + + for (size_t i = 0; i < num_features; i++) { + CHECK(feature_extraction_net->has_blob(blob_names[i])) + << "Unknown feature blob name " << blob_names[i] + << " in the network " << feature_extraction_proto; + } + + int num_mini_batches = atoi(argv[++arg_pos]); + + std::vector > > feature_dbs; + for (size_t i = 0; i < num_features; ++i) { + LOG(INFO)<< "Opening dataset " << dataset_names[i]; + shared_ptr > dataset = + DatasetFactory(argv[++arg_pos]); + CHECK(dataset->open(dataset_names.at(i), Dataset::New)); + feature_dbs.push_back(dataset); + } + + LOG(ERROR)<< "Extacting Features"; + + Datum datum; + const int kMaxKeyStrLength = 100; + char key_str[kMaxKeyStrLength]; + std::vector*> input_vec; + std::vector image_indices(num_features, 0); + for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) { + feature_extraction_net->Forward(input_vec); + for (int i = 0; i < num_features; ++i) { + const shared_ptr > feature_blob = feature_extraction_net + ->blob_by_name(blob_names[i]); + int batch_size = feature_blob->num(); + int dim_features = feature_blob->count() / batch_size; + const Dtype* feature_blob_data; + for (int n = 0; n < batch_size; ++n) { + datum.set_height(dim_features); + datum.set_width(1); + datum.set_channels(1); + datum.clear_data(); + datum.clear_float_data(); + feature_blob_data = feature_blob->cpu_data() + + feature_blob->offset(n); + for (int d = 0; d < dim_features; ++d) { + datum.add_float_data(feature_blob_data[d]); + } + int length = snprintf(key_str, kMaxKeyStrLength, "%d", + image_indices[i]); + CHECK(feature_dbs.at(i)->put(std::string(key_str, length), datum)); + ++image_indices[i]; + if (image_indices[i] % 1000 == 0) { + CHECK(feature_dbs.at(i)->commit()); + LOG(ERROR)<< "Extracted features of " << image_indices[i] << + " query images for feature blob " << blob_names[i]; + } + } // for (int n = 0; n < batch_size; ++n) + } // for (int i = 0; i < num_features; ++i) + } // for (int 
batch_index = 0; batch_index < num_mini_batches; ++batch_index) + // write the last batch + for (int i = 0; i < num_features; ++i) { + if (image_indices[i] % 1000 != 0) { + CHECK(feature_dbs.at(i)->commit()); + } + LOG(ERROR)<< "Extracted features of " << image_indices[i] << + " query images for feature blob " << blob_names[i]; + feature_dbs.at(i)->close(); + } + + LOG(ERROR)<< "Successfully extracted the features!"; + return 0; +} + diff --git a/caffe-crfrnn/tools/finetune_net.cpp b/caffe-crfrnn/tools/finetune_net.cpp new file mode 100644 index 00000000..81c0c354 --- /dev/null +++ b/caffe-crfrnn/tools/finetune_net.cpp @@ -0,0 +1,7 @@ +#include "caffe/caffe.hpp" + +int main(int argc, char** argv) { + LOG(FATAL) << "Deprecated. Use caffe train --solver=... " + "[--weights=...] instead."; + return 0; +} diff --git a/caffe-crfrnn/tools/net_speed_benchmark.cpp b/caffe-crfrnn/tools/net_speed_benchmark.cpp new file mode 100644 index 00000000..cd16e8d0 --- /dev/null +++ b/caffe-crfrnn/tools/net_speed_benchmark.cpp @@ -0,0 +1,7 @@ +#include "caffe/caffe.hpp" + +int main(int argc, char** argv) { + LOG(FATAL) << "Deprecated. Use caffe time --model=... " + "[--iterations=50] [--gpu] [--device_id=0]"; + return 0; +} diff --git a/caffe-crfrnn/tools/test_net.cpp b/caffe-crfrnn/tools/test_net.cpp new file mode 100644 index 00000000..92e14eee --- /dev/null +++ b/caffe-crfrnn/tools/test_net.cpp @@ -0,0 +1,7 @@ +#include "caffe/caffe.hpp" + +int main(int argc, char** argv) { + LOG(FATAL) << "Deprecated. Use caffe test --model=... " + "--weights=... instead."; + return 0; +} diff --git a/caffe-crfrnn/tools/train_net.cpp b/caffe-crfrnn/tools/train_net.cpp new file mode 100644 index 00000000..622bca31 --- /dev/null +++ b/caffe-crfrnn/tools/train_net.cpp @@ -0,0 +1,7 @@ +#include "caffe/caffe.hpp" + +int main(int argc, char** argv) { + LOG(FATAL) << "Deprecated. Use caffe train --solver=... " + "[--snapshot=...] 
instead."; + return 0; +} diff --git a/caffe-crfrnn/tools/upgrade_net_proto_binary.cpp b/caffe-crfrnn/tools/upgrade_net_proto_binary.cpp new file mode 100644 index 00000000..d7a62e32 --- /dev/null +++ b/caffe-crfrnn/tools/upgrade_net_proto_binary.cpp @@ -0,0 +1,44 @@ +// This is a script to upgrade "V0" network prototxts to the new format. +// Usage: +// upgrade_net_proto_binary v0_net_proto_file_in net_proto_file_out + +#include +#include // NOLINT(readability/streams) +#include // NOLINT(readability/streams) + +#include "caffe/caffe.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/upgrade_proto.hpp" + +using std::ofstream; + +using namespace caffe; // NOLINT(build/namespaces) + +int main(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + if (argc != 3) { + LOG(ERROR) << "Usage: " + << "upgrade_net_proto_binary v0_net_proto_file_in net_proto_file_out"; + return 1; + } + + NetParameter net_param; + if (!ReadProtoFromBinaryFile(argv[1], &net_param)) { + LOG(ERROR) << "Failed to parse input binary file as NetParameter: " + << argv[1]; + return 2; + } + bool need_upgrade = NetNeedsUpgrade(net_param); + bool success = true; + if (need_upgrade) { + NetParameter v0_net_param(net_param); + success = UpgradeV0Net(v0_net_param, &net_param); + } else { + LOG(ERROR) << "File already in V1 proto format: " << argv[1]; + } + + WriteProtoToBinaryFile(net_param, argv[2]); + + LOG(ERROR) << "Wrote upgraded NetParameter binary proto to " << argv[2]; + return !success; +} diff --git a/caffe-crfrnn/tools/upgrade_net_proto_text.cpp b/caffe-crfrnn/tools/upgrade_net_proto_text.cpp new file mode 100644 index 00000000..2f290fc5 --- /dev/null +++ b/caffe-crfrnn/tools/upgrade_net_proto_text.cpp @@ -0,0 +1,55 @@ +// This is a script to upgrade "V0" network prototxts to the new format. 
+// Usage: +// upgrade_net_proto_text v0_net_proto_file_in net_proto_file_out + +#include +#include // NOLINT(readability/streams) +#include // NOLINT(readability/streams) + +#include "caffe/caffe.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/upgrade_proto.hpp" + +using std::ofstream; + +using namespace caffe; // NOLINT(build/namespaces) + +int main(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + if (argc != 3) { + LOG(ERROR) << "Usage: " + << "upgrade_net_proto_text v0_net_proto_file_in net_proto_file_out"; + return 1; + } + + NetParameter net_param; + if (!ReadProtoFromTextFile(argv[1], &net_param)) { + LOG(ERROR) << "Failed to parse input text file as NetParameter: " + << argv[1]; + return 2; + } + bool need_upgrade = NetNeedsUpgrade(net_param); + bool need_data_upgrade = NetNeedsDataUpgrade(net_param); + bool success = true; + if (need_upgrade) { + NetParameter v0_net_param(net_param); + success = UpgradeV0Net(v0_net_param, &net_param); + } else { + LOG(ERROR) << "File already in V1 proto format: " << argv[1]; + } + + if (need_data_upgrade) { + UpgradeNetDataTransformation(&net_param); + } + + // Convert to a NetParameterPrettyPrint to print fields in desired + // order. + NetParameterPrettyPrint net_param_pretty; + NetParameterToPrettyPrint(net_param, &net_param_pretty); + + // Save new format prototxt. 
+ WriteProtoToTextFile(net_param_pretty, argv[2]); + + LOG(ERROR) << "Wrote upgraded NetParameter text proto to " << argv[2]; + return !success; +} diff --git a/matlab-scripts/2007_000033.png b/matlab-scripts/2007_000033.png new file mode 100644 index 00000000..bbeb3f44 Binary files /dev/null and b/matlab-scripts/2007_000033.png differ diff --git a/matlab-scripts/README.md b/matlab-scripts/README.md new file mode 100644 index 00000000..002fa0db --- /dev/null +++ b/matlab-scripts/README.md @@ -0,0 +1,23 @@ +--- +name: CRF-RNN Semantic Image Segmentation Model trained on COCO-VOC +caffemodel: TVG_CRFRNN_COCO_VOC.caffemodel +caffemodel_url: http://goo.gl/j7PrPZ +license: Non-commercial, for commercial use, please contact crfasrnn@gmail.com +sha1: bfda5c5149d566aa56695789fa9a08e7a7f3070a +--- + +This model is for the ICCV paper titled "Conditional Random Fields as Recurrent Neural Networks". + +Demo website is . + +This model was trained by +Shuai Zheng @bittnt +Sadeep Jayasumana @sadeepj +Bernardino Romera-Paredes @bernard24 + +Supervisor: +Philip Torr : + +## License +This model is for non-commercial use only. 
For other use, please contact crfasrnn@gmail.com + + diff --git a/matlab-scripts/TVG_CRFRNN_COCO_VOC.prototxt b/matlab-scripts/TVG_CRFRNN_COCO_VOC.prototxt new file mode 100644 index 00000000..afdab910 --- /dev/null +++ b/matlab-scripts/TVG_CRFRNN_COCO_VOC.prototxt @@ -0,0 +1,150 @@ +name: 'TVG_CRF_RNN_COCO_VOC' + +input: 'data' +input_dim: 1 +input_dim: 3 +input_dim: 500 +input_dim: 500 +force_backward: true + +layers { bottom: 'data' top: 'conv1_1' name: 'conv1_1' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 64 pad: 100 kernel_size: 3 } } +layers { bottom: 'conv1_1' top: 'conv1_1' name: 'relu1_1' type: RELU } +layers { bottom: 'conv1_1' top: 'conv1_2' name: 'conv1_2' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 64 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv1_2' top: 'conv1_2' name: 'relu1_2' type: RELU } +layers { name: 'pool1' bottom: 'conv1_2' top: 'pool1' type: POOLING + pooling_param { pool: MAX kernel_size: 2 stride: 2 } } +layers { name: 'conv2_1' bottom: 'pool1' top: 'conv2_1' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 128 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv2_1' top: 'conv2_1' name: 'relu2_1' type: RELU } +layers { bottom: 'conv2_1' top: 'conv2_2' name: 'conv2_2' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 128 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv2_2' top: 'conv2_2' name: 'relu2_2' type: RELU } +layers { bottom: 'conv2_2' top: 'pool2' name: 'pool2' type: POOLING + pooling_param { pool: MAX kernel_size: 2 stride: 2 } } +layers { bottom: 'pool2' top: 'conv3_1' name: 'conv3_1' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 256 pad: 1 
kernel_size: 3 } } +layers { bottom: 'conv3_1' top: 'conv3_1' name: 'relu3_1' type: RELU } +layers { bottom: 'conv3_1' top: 'conv3_2' name: 'conv3_2' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 256 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv3_2' top: 'conv3_2' name: 'relu3_2' type: RELU } +layers { bottom: 'conv3_2' top: 'conv3_3' name: 'conv3_3' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 256 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv3_3' top: 'conv3_3' name: 'relu3_3' type: RELU } +layers { bottom: 'conv3_3' top: 'pool3' name: 'pool3' type: POOLING + pooling_param { pool: MAX kernel_size: 2 stride: 2 } } +layers { bottom: 'pool3' top: 'conv4_1' name: 'conv4_1' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv4_1' top: 'conv4_1' name: 'relu4_1' type: RELU } +layers { bottom: 'conv4_1' top: 'conv4_2' name: 'conv4_2' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv4_2' top: 'conv4_2' name: 'relu4_2' type: RELU } +layers { bottom: 'conv4_2' top: 'conv4_3' name: 'conv4_3' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv4_3' top: 'conv4_3' name: 'relu4_3' type: RELU } +layers { bottom: 'conv4_3' top: 'pool4' name: 'pool4' type: POOLING + pooling_param { pool: MAX kernel_size: 2 stride: 2 } } +layers { bottom: 'pool4' top: 'conv5_1' name: 'conv5_1' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 } } +layers { bottom: 
'conv5_1' top: 'conv5_1' name: 'relu5_1' type: RELU } +layers { bottom: 'conv5_1' top: 'conv5_2' name: 'conv5_2' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv5_2' top: 'conv5_2' name: 'relu5_2' type: RELU } +layers { bottom: 'conv5_2' top: 'conv5_3' name: 'conv5_3' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv5_3' top: 'conv5_3' name: 'relu5_3' type: RELU } +layers { bottom: 'conv5_3' top: 'pool5' name: 'pool5' type: POOLING + pooling_param { pool: MAX kernel_size: 2 stride: 2 } } +layers { bottom: 'pool5' top: 'fc6' name: 'fc6' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE kernel_size: 7 num_output: 4096 } } +layers { bottom: 'fc6' top: 'fc6' name: 'relu6' type: RELU } +layers { bottom: 'fc6' top: 'fc6' name: 'drop6' type: DROPOUT + dropout_param { dropout_ratio: 0.5 } } +layers { bottom: 'fc6' top: 'fc7' name: 'fc7' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE kernel_size: 1 num_output: 4096 } } +layers { bottom: 'fc7' top: 'fc7' name: 'relu7' type: RELU } +layers { bottom: 'fc7' top: 'fc7' name: 'drop7' type: DROPOUT + dropout_param { dropout_ratio: 0.5 } } +layers { name: 'score-fr' type: CONVOLUTION bottom: 'fc7' top: 'score' + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 21 kernel_size: 1 } } + +layers { type: DECONVOLUTION name: 'score2' bottom: 'score' top: 'score2' + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { kernel_size: 4 stride: 2 num_output: 21 } } + +layers { name: 'score-pool4' type: CONVOLUTION bottom: 'pool4' top: 'score-pool4' + blobs_lr: 1 blobs_lr: 2 weight_decay: 
1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 21 kernel_size: 1 } } + +layers { type: CROP name: 'crop' bottom: 'score-pool4' bottom: 'score2' + top: 'score-pool4c' } + +layers { type: ELTWISE name: 'fuse' bottom: 'score2' bottom: 'score-pool4c' + top: 'score-fused' + eltwise_param { operation: SUM } } + +layers { type: DECONVOLUTION name: 'score4' bottom: 'score-fused' + top: 'score4' + blobs_lr: 1 weight_decay: 1 + convolution_param { bias_term: false kernel_size: 4 stride: 2 num_output: 21 } } + +layers { name: 'score-pool3' type: CONVOLUTION bottom: 'pool3' top: 'score-pool3' + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 21 kernel_size: 1 } } + +layers { type: CROP name: 'crop' bottom: 'score-pool3' bottom: 'score4' + top: 'score-pool3c' } + +layers { type: ELTWISE name: 'fuse' bottom: 'score4' bottom: 'score-pool3c' + top: 'score-final' + eltwise_param { operation: SUM } } + +layers { type: DECONVOLUTION name: 'upsample' + bottom: 'score-final' top: 'bigscore' + blobs_lr: 0 + convolution_param { bias_term: false num_output: 21 kernel_size: 16 stride: 8 } } + +layers { type: CROP name: 'crop' bottom: 'bigscore' bottom: 'data' top: 'coarse' } + +layers { type: SPLIT name: 'splitting' + bottom: 'coarse' top: 'unary' top: 'Q0' +} + +layers { + name: "inference1" + type: MULTI_STAGE_MEANFIELD + bottom: "unary" + bottom: "Q0" + bottom: "data" + top: "pred" + blobs_lr: 0.001 + blobs_lr: 0.001 + blobs_lr:0.01 #new parameter + multi_stage_meanfield_param { + num_iterations: 10 + compatibility_mode: POTTS + threshold: 2 + theta_alpha: 160 + theta_beta: 3 + theta_gamma: 3 + spatial_filter_weight: 3 + bilateral_filter_weight: 5 + } +} diff --git a/matlab-scripts/bilateral.par b/matlab-scripts/bilateral.par new file mode 100644 index 00000000..69130256 --- /dev/null +++ b/matlab-scripts/bilateral.par @@ -0,0 +1 @@ +5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 diff --git 
a/matlab-scripts/crfrnn_demo.m b/matlab-scripts/crfrnn_demo.m new file mode 100644 index 00000000..53c73795 --- /dev/null +++ b/matlab-scripts/crfrnn_demo.m @@ -0,0 +1,50 @@ +clear all; +close all; +%This is a software bundle "CRF-RNN", which is published in an ICCV paper titled "Conditional Random Fields as Recurrent Neural Networks". This is implemented as part of the Caffe library, written in C++/C. The current version is maintained by: +% +%Shuai Zheng : szheng@robots.ox.ac.uk Sadeep Jayasumana : sadeep@robots.ox.ac.uk Bernardino Romera Paredes : +% +%Supervisor: Philip Torr : philip.torr@eng.ox.ac.uk +% +%For more information about CRF-RNN please visit the project website http://crfasrnn.torr.vision. +% +caffe_path = '../caffe-crfrnn/'; + +model_def_file = 'TVG_CRFRNN_COCO_VOC.prototxt'; +model_file = 'TVG_CRFRNN_COCO_VOC.caffemodel'; + +use_gpu = 1; + +addpath(fullfile(caffe_path, 'matlab/caffe')); + +caffe('reset'); +caffe('set_device', 0);% change here if you have a powerful GPU in different device, nvidia-smi will help you check the device information. + +tvg_matcaffe_init(use_gpu, model_def_file, model_file); +[~, map] = imread('2007_000033.png'); + +im = imread('input.jpg'); + +[h, w, d] = size(im); + +if (d ~= 3) + error('Error! Wrong depth.\n'); +end + +if (h > 500 || w > 500) + error('Error! Wrong image size.\n'); +end + +prepared_im = tvg_prepare_image_fixed(im); + +inputData = {prepared_im}; +scores = caffe('forward', inputData); + +Q = scores{1}; + +[dumb, pred] = max(Q, [], 3); +pred = pred'; +pred = pred(1:h, 1:w); + +imwrite(pred, map, ['output.png'], 'png'); + diff --git a/matlab-scripts/devtools/tvg_VOCevalseg.m b/matlab-scripts/devtools/tvg_VOCevalseg.m new file mode 100644 index 00000000..2ef7f625 --- /dev/null +++ b/matlab-scripts/devtools/tvg_VOCevalseg.m @@ -0,0 +1,92 @@ +%VOCEVALSEG Evaluates a set of segmentation results. +% VOCEVALSEG(VOCopts,ID); prints out the per class and overall +% segmentation accuracies. 
Accuracies are given using the intersection/union +% metric: +% true positives / (true positives + false positives + false negatives) +% +% [ACCURACIES,AVACC,CONF] = VOCEVALSEG(VOCopts,ID) returns the per class +% percentage ACCURACIES, the average accuracy AVACC and the confusion +% matrix CONF. +% +% [ACCURACIES,AVACC,CONF,RAWCOUNTS] = VOCEVALSEG(VOCopts,ID) also returns +% the unnormalised confusion matrix, which contains raw pixel counts. +function [accuracies,avacc,conf,rawcounts] = tvg_VOCevalseg(VOCopts,id) + +% image test set +[gtids,t]=textread(sprintf(VOCopts.seg.imgsetpath,VOCopts.testset),'%s %d'); + +% number of labels = number of classes plus one for the background +num = VOCopts.nclasses+1; +confcounts = zeros(num); +count=0; +tic; +for i=1:length(gtids) + % display progress + if toc>1 + fprintf('test confusion: %d/%d\n',i,length(gtids)); + drawnow; + tic; + end + + imname = gtids{i}; + + % ground truth label file + gtfile = sprintf(VOCopts.seg.clsimgpath,imname); + [gtim,map] = imread(gtfile); + gtim = double(gtim); + + % results file + resfile = sprintf(VOCopts.seg.clsrespath,id,VOCopts.testset,imname); + [resim,map] = imread(resfile); + resim = double(resim); + + % Check validity of results image + maxlabel = max(resim(:)); + if (maxlabel>VOCopts.nclasses), + error('Results image ''%s'' has out of range value %d (the value should be <= %d)',imname,maxlabel,VOCopts.nclasses); + end + + szgtim = size(gtim); szresim = size(resim); + if any(szgtim~=szresim) + error('Results image ''%s'' is the wrong size, was %d x %d, should be %d x %d.',imname,szresim(1),szresim(2),szgtim(1),szgtim(2)); + end + + %pixel locations to include in computation + locs = gtim<255; + + % joint histogram + sumim = 1+gtim+resim*num; + hs = histc(sumim(locs),1:num*num); + count = count + numel(find(locs)); + confcounts(:) = confcounts(:) + hs(:); +end + +% confusion matrix - first index is true label, second is inferred label +%conf = zeros(num); +conf = 
100*confcounts./repmat(1E-20+sum(confcounts,2),[1 size(confcounts,2)]); +rawcounts = confcounts; + +% Percentage correct labels measure is no longer being used. Uncomment if +% you wish to see it anyway +%overall_acc = 100*sum(diag(confcounts)) / sum(confcounts(:)); +%fprintf('Percentage of pixels correctly labelled overall: %6.3f%%\n',overall_acc); + +accuracies = zeros(VOCopts.nclasses,1); +fprintf('Accuracy for each class (intersection/union measure)\n'); +for j=1:num + + gtj=sum(confcounts(j,:)); + resj=sum(confcounts(:,j)); + gtjresj=confcounts(j,j); + % The accuracy is: true positive / (true positive + false positive + false negative) + % which is equivalent to the following percentage: + accuracies(j)=100*gtjresj/(gtj+resj-gtjresj); + + clname = 'background'; + if (j>1), clname = VOCopts.classes{j-1};end; + fprintf(' %14s: %6.3f%%\n',clname,accuracies(j)); +end +accuracies = accuracies(1:end); +avacc = mean(accuracies); +fprintf('-------------------------\n'); +fprintf('Average accuracy: %6.3f%%\n',avacc); diff --git a/matlab-scripts/devtools/tvg_VOCinit.m b/matlab-scripts/devtools/tvg_VOCinit.m new file mode 100644 index 00000000..dce8c059 --- /dev/null +++ b/matlab-scripts/devtools/tvg_VOCinit.m @@ -0,0 +1,142 @@ +clear VOCopts + +% dataset +% +% Note for experienced users: the VOC2008-10 test sets are subsets +% of the VOC2010 test set. You don't need to do anything special +% to submit results for VOC2008-10. 
+ +VOCopts.dataset='VOC2012'; + +% get devkit directory with forward slashes +devkitroot=strrep(fileparts(fileparts(mfilename('fullpath'))),'\','/'); + +% change this path to point to your copy of the PASCAL VOC data +VOCopts.datadir=[devkitroot '/']; + +% change this path to a writable directory for your results +VOCopts.resdir=[devkitroot '/results/' VOCopts.dataset '/']; + +% change this path to a writable local directory for the example code +VOCopts.localdir=[devkitroot '/local/' VOCopts.dataset '/']; + +% initialize the training set + +VOCopts.trainset='train'; % use train for development +% VOCopts.trainset='trainval'; % use train+val for final challenge + +% initialize the test set + +VOCopts.testset='val'; % use validation data for development test set +% VOCopts.testset='test'; % use test set for final challenge + +% initialize main challenge paths + +VOCopts.annopath=[VOCopts.datadir VOCopts.dataset '/Annotations/%s.xml']; +VOCopts.imgpath=[VOCopts.datadir VOCopts.dataset '/JPEGImages/%s.jpg']; +VOCopts.imgsetpath=[VOCopts.datadir VOCopts.dataset '/ImageSets/Main/%s.txt']; +VOCopts.clsimgsetpath=[VOCopts.datadir VOCopts.dataset '/ImageSets/Main/%s_%s.txt']; +VOCopts.clsrespath=[VOCopts.resdir 'Main/%s_cls_' VOCopts.testset '_%s.txt']; +VOCopts.detrespath=[VOCopts.resdir 'Main/%s_det_' VOCopts.testset '_%s.txt']; + +% initialize segmentation task paths + +VOCopts.seg.clsimgpath=[VOCopts.datadir VOCopts.dataset '/SegmentationClass/%s.png']; +VOCopts.seg.instimgpath=[VOCopts.datadir VOCopts.dataset '/SegmentationObject/%s.png']; + +VOCopts.seg.imgsetpath=[VOCopts.datadir VOCopts.dataset '/ImageSets/Segmentation/%s.txt']; + +VOCopts.seg.clsresdir=[VOCopts.resdir 'Segmentation/%s_%s_cls']; +VOCopts.seg.instresdir=[VOCopts.resdir 'Segmentation/%s_%s_inst']; +VOCopts.seg.clsrespath=[VOCopts.seg.clsresdir '/%s.png']; +VOCopts.seg.instrespath=[VOCopts.seg.instresdir '/%s.png']; + +% initialize layout task paths + +VOCopts.layout.imgsetpath=[VOCopts.datadir 
VOCopts.dataset '/ImageSets/Layout/%s.txt']; +VOCopts.layout.respath=[VOCopts.resdir 'Layout/%s_layout_' VOCopts.testset '.xml']; + +% initialize action task paths + +VOCopts.action.imgsetpath=[VOCopts.datadir VOCopts.dataset '/ImageSets/Action/%s.txt']; +VOCopts.action.clsimgsetpath=[VOCopts.datadir VOCopts.dataset '/ImageSets/Action/%s_%s.txt']; +VOCopts.action.respath=[VOCopts.resdir 'Action/%s_action_' VOCopts.testset '_%s.txt']; + +% initialize the VOC challenge options + +% classes + +VOCopts.classes={... + 'aeroplane' + 'bicycle' + 'bird' + 'boat' + 'bottle' + 'bus' + 'car' + 'cat' + 'chair' + 'cow' + 'diningtable' + 'dog' + 'horse' + 'motorbike' + 'person' + 'pottedplant' + 'sheep' + 'sofa' + 'train' + 'tvmonitor'}; + +VOCopts.nclasses=length(VOCopts.classes); + +% poses + +VOCopts.poses={... + 'Unspecified' + 'Left' + 'Right' + 'Frontal' + 'Rear'}; + +VOCopts.nposes=length(VOCopts.poses); + +% layout parts + +VOCopts.parts={... + 'head' + 'hand' + 'foot'}; + +VOCopts.nparts=length(VOCopts.parts); + +VOCopts.maxparts=[1 2 2]; % max of each of above parts + +% actions + +VOCopts.actions={... 
+ 'other' % skip this when training classifiers + 'jumping' % new in VOC2011 + 'phoning' + 'playinginstrument' + 'reading' + 'ridingbike' + 'ridinghorse' + 'running' + 'takingphoto' + 'usingcomputer' + 'walking'}; + +VOCopts.nactions=length(VOCopts.actions); + +% overlap threshold + +VOCopts.minoverlap=0.5; + +% annotation cache for evaluation + +VOCopts.annocachepath=[VOCopts.localdir '%s_anno.mat']; + +% options for example implementations + +VOCopts.exfdpath=[VOCopts.localdir '%s_fd.mat']; diff --git a/matlab-scripts/download_trained_model.sh b/matlab-scripts/download_trained_model.sh new file mode 100755 index 00000000..00d3427b --- /dev/null +++ b/matlab-scripts/download_trained_model.sh @@ -0,0 +1 @@ +wget http://goo.gl/j7PrPZ diff --git a/matlab-scripts/input.jpg b/matlab-scripts/input.jpg new file mode 100644 index 00000000..73624779 Binary files /dev/null and b/matlab-scripts/input.jpg differ diff --git a/matlab-scripts/spatial.par b/matlab-scripts/spatial.par new file mode 100644 index 00000000..737f5c09 --- /dev/null +++ b/matlab-scripts/spatial.par @@ -0,0 +1 @@ +3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 diff --git a/python-scripts/README.md b/python-scripts/README.md new file mode 100644 index 00000000..002fa0db --- /dev/null +++ b/python-scripts/README.md @@ -0,0 +1,23 @@ +--- +name: CRF-RNN Semantic Image Segmentation Model trained on COCO-VOC +caffemodel: TVG_CRFRNN_COCO_VOC.caffemodel +caffemodel_url: http://goo.gl/j7PrPZ +license: Non-commercial, for commercial use, please contact crfasrnn@gmail.com +sha1: bfda5c5149d566aa56695789fa9a08e7a7f3070a +--- + +This model is for the ICCV paper titled "Conditional Random Fields as Recurrent Neural Networks". + +Demo website is . + +This model was trained by +Shuai Zheng @bittnt +Sadeep Jayasumana @sadeepj +Bernardino Romera-Paredes @bernard24 + +Supervisor: +Philip Torr : + +## License +This model is for non-commercial use only. 
For other use, please contact crfasrnn@gmail.com + + diff --git a/python-scripts/TVG_CRFRNN_COCO_VOC.prototxt b/python-scripts/TVG_CRFRNN_COCO_VOC.prototxt new file mode 100644 index 00000000..afdab910 --- /dev/null +++ b/python-scripts/TVG_CRFRNN_COCO_VOC.prototxt @@ -0,0 +1,150 @@ +name: 'TVG_CRF_RNN_COCO_VOC' + +input: 'data' +input_dim: 1 +input_dim: 3 +input_dim: 500 +input_dim: 500 +force_backward: true + +layers { bottom: 'data' top: 'conv1_1' name: 'conv1_1' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 64 pad: 100 kernel_size: 3 } } +layers { bottom: 'conv1_1' top: 'conv1_1' name: 'relu1_1' type: RELU } +layers { bottom: 'conv1_1' top: 'conv1_2' name: 'conv1_2' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 64 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv1_2' top: 'conv1_2' name: 'relu1_2' type: RELU } +layers { name: 'pool1' bottom: 'conv1_2' top: 'pool1' type: POOLING + pooling_param { pool: MAX kernel_size: 2 stride: 2 } } +layers { name: 'conv2_1' bottom: 'pool1' top: 'conv2_1' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 128 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv2_1' top: 'conv2_1' name: 'relu2_1' type: RELU } +layers { bottom: 'conv2_1' top: 'conv2_2' name: 'conv2_2' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 128 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv2_2' top: 'conv2_2' name: 'relu2_2' type: RELU } +layers { bottom: 'conv2_2' top: 'pool2' name: 'pool2' type: POOLING + pooling_param { pool: MAX kernel_size: 2 stride: 2 } } +layers { bottom: 'pool2' top: 'conv3_1' name: 'conv3_1' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 256 pad: 1 
kernel_size: 3 } } +layers { bottom: 'conv3_1' top: 'conv3_1' name: 'relu3_1' type: RELU } +layers { bottom: 'conv3_1' top: 'conv3_2' name: 'conv3_2' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 256 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv3_2' top: 'conv3_2' name: 'relu3_2' type: RELU } +layers { bottom: 'conv3_2' top: 'conv3_3' name: 'conv3_3' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 256 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv3_3' top: 'conv3_3' name: 'relu3_3' type: RELU } +layers { bottom: 'conv3_3' top: 'pool3' name: 'pool3' type: POOLING + pooling_param { pool: MAX kernel_size: 2 stride: 2 } } +layers { bottom: 'pool3' top: 'conv4_1' name: 'conv4_1' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv4_1' top: 'conv4_1' name: 'relu4_1' type: RELU } +layers { bottom: 'conv4_1' top: 'conv4_2' name: 'conv4_2' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv4_2' top: 'conv4_2' name: 'relu4_2' type: RELU } +layers { bottom: 'conv4_2' top: 'conv4_3' name: 'conv4_3' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv4_3' top: 'conv4_3' name: 'relu4_3' type: RELU } +layers { bottom: 'conv4_3' top: 'pool4' name: 'pool4' type: POOLING + pooling_param { pool: MAX kernel_size: 2 stride: 2 } } +layers { bottom: 'pool4' top: 'conv5_1' name: 'conv5_1' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 } } +layers { bottom: 
'conv5_1' top: 'conv5_1' name: 'relu5_1' type: RELU } +layers { bottom: 'conv5_1' top: 'conv5_2' name: 'conv5_2' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv5_2' top: 'conv5_2' name: 'relu5_2' type: RELU } +layers { bottom: 'conv5_2' top: 'conv5_3' name: 'conv5_3' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 512 pad: 1 kernel_size: 3 } } +layers { bottom: 'conv5_3' top: 'conv5_3' name: 'relu5_3' type: RELU } +layers { bottom: 'conv5_3' top: 'pool5' name: 'pool5' type: POOLING + pooling_param { pool: MAX kernel_size: 2 stride: 2 } } +layers { bottom: 'pool5' top: 'fc6' name: 'fc6' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE kernel_size: 7 num_output: 4096 } } +layers { bottom: 'fc6' top: 'fc6' name: 'relu6' type: RELU } +layers { bottom: 'fc6' top: 'fc6' name: 'drop6' type: DROPOUT + dropout_param { dropout_ratio: 0.5 } } +layers { bottom: 'fc6' top: 'fc7' name: 'fc7' type: CONVOLUTION + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE kernel_size: 1 num_output: 4096 } } +layers { bottom: 'fc7' top: 'fc7' name: 'relu7' type: RELU } +layers { bottom: 'fc7' top: 'fc7' name: 'drop7' type: DROPOUT + dropout_param { dropout_ratio: 0.5 } } +layers { name: 'score-fr' type: CONVOLUTION bottom: 'fc7' top: 'score' + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 21 kernel_size: 1 } } + +layers { type: DECONVOLUTION name: 'score2' bottom: 'score' top: 'score2' + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { kernel_size: 4 stride: 2 num_output: 21 } } + +layers { name: 'score-pool4' type: CONVOLUTION bottom: 'pool4' top: 'score-pool4' + blobs_lr: 1 blobs_lr: 2 weight_decay: 
1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 21 kernel_size: 1 } } + +layers { type: CROP name: 'crop' bottom: 'score-pool4' bottom: 'score2' + top: 'score-pool4c' } + +layers { type: ELTWISE name: 'fuse' bottom: 'score2' bottom: 'score-pool4c' + top: 'score-fused' + eltwise_param { operation: SUM } } + +layers { type: DECONVOLUTION name: 'score4' bottom: 'score-fused' + top: 'score4' + blobs_lr: 1 weight_decay: 1 + convolution_param { bias_term: false kernel_size: 4 stride: 2 num_output: 21 } } + +layers { name: 'score-pool3' type: CONVOLUTION bottom: 'pool3' top: 'score-pool3' + blobs_lr: 1 blobs_lr: 2 weight_decay: 1 weight_decay: 0 + convolution_param { engine: CAFFE num_output: 21 kernel_size: 1 } } + +layers { type: CROP name: 'crop' bottom: 'score-pool3' bottom: 'score4' + top: 'score-pool3c' } + +layers { type: ELTWISE name: 'fuse' bottom: 'score4' bottom: 'score-pool3c' + top: 'score-final' + eltwise_param { operation: SUM } } + +layers { type: DECONVOLUTION name: 'upsample' + bottom: 'score-final' top: 'bigscore' + blobs_lr: 0 + convolution_param { bias_term: false num_output: 21 kernel_size: 16 stride: 8 } } + +layers { type: CROP name: 'crop' bottom: 'bigscore' bottom: 'data' top: 'coarse' } + +layers { type: SPLIT name: 'splitting' + bottom: 'coarse' top: 'unary' top: 'Q0' +} + +layers { + name: "inference1" + type: MULTI_STAGE_MEANFIELD + bottom: "unary" + bottom: "Q0" + bottom: "data" + top: "pred" + blobs_lr: 0.001 + blobs_lr: 0.001 + blobs_lr:0.01 #new parameter + multi_stage_meanfield_param { + num_iterations: 10 + compatibility_mode: POTTS + threshold: 2 + theta_alpha: 160 + theta_beta: 3 + theta_gamma: 3 + spatial_filter_weight: 3 + bilateral_filter_weight: 5 + } +} diff --git a/python-scripts/bilateral.par b/python-scripts/bilateral.par new file mode 100644 index 00000000..69130256 --- /dev/null +++ b/python-scripts/bilateral.par @@ -0,0 +1 @@ +5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 diff --git 
a/python-scripts/crfasrnn_demo.py b/python-scripts/crfasrnn_demo.py new file mode 100644 index 00000000..e225f9b8 --- /dev/null +++ b/python-scripts/crfasrnn_demo.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +""" +This is a software bundle "CRF-RNN", which is published in an ICCV paper titled "Conditional Random Fields as Recurrent Neural Networks". This is implemented as part of the Caffe library, written in C++/C. The current version is maintained by: + +Shuai Zheng : szheng@robots.ox.ac.uk Sadeep Jayasumana : sadeep@robots.ox.ac.uk Bernardino Romera Paredes : + +Supervisor: Philip Torr : philip.torr@eng.ox.ac.uk + +For more information about CRF-RNN please visit the project website http://crfasrnn.torr.vision. +""" + +caffe_root = '../caffe-crfrnn/' +import sys +sys.path.insert(0,caffe_root+'python') + +import os +import cPickle +import logging +import numpy as np +import pandas as pd +from PIL import Image as PILImage +#import Image +import cStringIO as StringIO +import caffe +import matplotlib.pyplot as plt + + +MODEL_FILE = 'TVG_CRFRNN_COCO_VOC.prototxt' +PRETRAINED = 'TVG_CRFRNN_COCO_VOC.caffemodel' +IMAGE_FILE = '2007_000032.jpg' + + +#caffe.set_mode_gpu() +net = caffe.Segmenter(MODEL_FILE, PRETRAINED) +input_image = 255*caffe.io.load_image(IMAGE_FILE) + + +# The loaded array is laid out as (height, width, channels) +height = input_image.shape[0] +width = input_image.shape[1] +maxDim = max(width,height) + +image = PILImage.fromarray(np.uint8(input_image)) +image = np.array(image) + +pallete = [0,0,0, + 128,0,0, + 0,128,0, + 128,128,0, + 0,0,128, + 128,0,128, + 0,128,128, + 128,128,128, + 64,0,0, + 192,0,0, + 64,128,0, + 192,128,0, + 64,0,128, + 192,0,128, + 64,128,128, + 192,128,128, + 0,64,0, + 128,64,0, + 0,192,0, + 128,192,0, + 0,64,128, + 128,64,128, + 0,192,128, + 128,192,128, + 64,64,0, + 192,64,0, + 64,192,0, + 192,192,0] + +mean_vec = np.array([103.939, 116.779, 123.68], dtype=np.float32) +reshaped_mean_vec = mean_vec.reshape(1,1,3); + +# Rearrange channels to form BGR +im = image[:,:,::-1] +# Subtract mean +im 
= im - reshaped_mean_vec + +# Pad as necessary +cur_h, cur_w, cur_c = im.shape +pad_h = 500 - cur_h +pad_w = 500 - cur_w +im = np.pad(im, pad_width=((0, pad_h), (0, pad_w), (0, 0)), mode = 'constant', constant_values = 0) +# Get predictions +segmentation = net.predict([im]) +segmentation2 = segmentation[0:cur_h,0:cur_w] +output_im = PILImage.fromarray(segmentation2) +output_im.putpalette(pallete) + + +plt.imshow(output_im) +plt.savefig('output.png') diff --git a/python-scripts/download_trained_model.sh b/python-scripts/download_trained_model.sh new file mode 100755 index 00000000..92ef9e39 --- /dev/null +++ b/python-scripts/download_trained_model.sh @@ -0,0 +1 @@ +wget -O TVG_CRFRNN_COCO_VOC.caffemodel http://goo.gl/j7PrPZ diff --git a/python-scripts/input.jpg b/python-scripts/input.jpg new file mode 100644 index 00000000..73624779 Binary files /dev/null and b/python-scripts/input.jpg differ diff --git a/python-scripts/spatial.par b/python-scripts/spatial.par new file mode 100644 index 00000000..737f5c09 --- /dev/null +++ b/python-scripts/spatial.par @@ -0,0 +1 @@ +3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 diff --git a/sample.png b/sample.png new file mode 100644 index 00000000..ee8fd61f Binary files /dev/null and b/sample.png differ
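
Note on the hard-coded `pallete` list in `python-scripts/crfasrnn_demo.py`: its 28 RGB triples appear to follow the standard PASCAL VOC label colormap, in which the bits of the label index are distributed over the high bits of the R, G and B channels. The sketch below is not part of this change; `voc_palette` is a hypothetical helper name, and it only illustrates how the same flat list could be generated instead of typed out by hand.

```python
def voc_palette(num_labels):
    """Return a flat [r0, g0, b0, r1, g1, b1, ...] list for labels 0..num_labels-1.

    Hypothetical helper: for each label, consecutive groups of three bits
    (R, G, B) are written from the most significant channel bit downwards.
    """
    palette = []
    for label in range(num_labels):
        r = g = b = 0
        bits = label
        for shift in range(7, -1, -1):
            r |= (bits & 1) << shift         # bit 0 of the group goes to red
            g |= ((bits >> 1) & 1) << shift  # bit 1 goes to green
            b |= ((bits >> 2) & 1) << shift  # bit 2 goes to blue
            bits >>= 3
        palette.extend([r, g, b])
    return palette


generated = voc_palette(28)
print(generated[:12])  # [0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0]
```

Under that assumption, `voc_palette(28)` reproduces the list in the demo entry for entry, so a call such as `output_im.putpalette(voc_palette(28))` would be an equivalent way to colour the prediction. The network itself outputs 21 labels, so only the first 21 triples are ever used.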