Skip to content

Commit

Permalink
Merge pull request #53 from ewanwm/feature_python_interface
Browse files Browse the repository at this point in the history
Feature python interface
  • Loading branch information
ewanwm authored Sep 19, 2024
2 parents e15abe2 + 4993f4e commit 8b8c1cb
Show file tree
Hide file tree
Showing 16 changed files with 577 additions and 138 deletions.
30 changes: 30 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@ set(CMAKE_CXX_STANDARD 17)

project(nuTens)

# Changes default install path to be a subdirectory of the build dir.
# Can still be set at configure time with -DCMAKE_INSTALL_PREFIX=/install/path.
# CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT is true only when CMake supplied
# the platform default (e.g. /usr/local on Unix), so a user-specified prefix --
# even an explicit /usr/local -- is respected. The previous STREQUAL checks
# could not distinguish those two cases, and the elseif(NOT DEFINED ...) branch
# was dead code (an undefined variable expands to "" and was caught first).
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
  # FORCE is safe here: we only overwrite the CMake-supplied default,
  # never a value the user put in the cache themselves.
  set(CMAKE_INSTALL_PREFIX "${CMAKE_BINARY_DIR}" CACHE PATH
      "install prefix (defaults to the build directory)" FORCE)
endif()


# Need to add some special compile flags to check the code test coverage
OPTION(NT_TEST_COVERAGE "produce code coverage reports when running tests" OFF)
IF(NT_TEST_COVERAGE)
Expand Down Expand Up @@ -46,6 +56,19 @@ ELSE()
message("Won't benchmark")
ENDIF()

# Optional python interface: when enabled, fetch pybind11 through CPM so the
# binding targets (added later under python/) can be built.
option(NT_ENABLE_PYTHON "enable python interface" OFF)

if(NT_ENABLE_PYTHON)
    message("Enabling python")
    # Pin the pybind11 release so configure results are reproducible
    CPMAddPackage(GITHUB_REPOSITORY "pybind/pybind11" VERSION 2.13.5)
else()
    message("Won't enable python interface")
endif()


## check build times
## have this optional as it's not supported on all CMake platforms
Expand All @@ -66,6 +89,13 @@ IF(NT_ENABLE_BENCHMARKING)
add_subdirectory(benchmarks)
ENDIF()

IF(NT_ENABLE_PYTHON)
# A pybind11 extension module is a shared library, so every static library
# linked into it must be compiled as position independent code (-fPIC).
# NOTE(review): setting POSITION_INDEPENDENT_CODE on targets defined in other
# directories (tensor, propagator, logging) -- and on the third-party spdlog
# target -- couples this file to those definitions; consider setting the
# property where each target is created, or via CMAKE_POSITION_INDEPENDENT_CODE.
set_property( TARGET tensor PROPERTY POSITION_INDEPENDENT_CODE ON )
set_property( TARGET propagator PROPERTY POSITION_INDEPENDENT_CODE ON )
set_property( TARGET spdlog PROPERTY POSITION_INDEPENDENT_CODE ON )
set_property( TARGET logging PROPERTY POSITION_INDEPENDENT_CODE ON )
# The python/ subdirectory defines the actual binding module(s)
add_subdirectory(python)
ENDIF()

# Print out a handy message to more easily see the config options
message( STATUS "The following variables have been used to configure the build: " )
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ nuTens uses [Googles benchmark library](https://github.com/google/benchmark) to
- [x] Add suite of benchmarking tests
- [x] Integrate benchmarks into CI ( maybe use [hyperfine](https://github.com/sharkdp/hyperfine) and [bencher](https://bencher.dev/) for this? )
- [ ] Add proper unit tests
- [ ] Expand CI to include more platforms
- [x] Expand CI to include more platforms
- [ ] Add support for modules (see [PyTorch doc](https://pytorch.org/cppdocs/api/classtorch_1_1nn_1_1_module.html))
- [ ] Propagation in variable matter density
- [ ] Add support for Tensorflow backend
- [ ] Add python interface
- [x] Add python interface

2 changes: 1 addition & 1 deletion benchmarks/benchmarks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ static void BM_constMatterOscillations(benchmark::State &state)

// set up the propagator
Propagator matterProp(3, 100.0);
std::unique_ptr<BaseMatterSolver> matterSolver = std::make_unique<ConstDensityMatterSolver>(3, 2.6);
std::shared_ptr<BaseMatterSolver> matterSolver = std::make_shared<ConstDensityMatterSolver>(3, 2.6);
matterProp.setPMNS(PMNS);
matterProp.setMasses(masses);
matterProp.setMatterSolver(matterSolver);
Expand Down
1 change: 1 addition & 0 deletions nuTens/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# Thin logging target wrapping spdlog; header-only (single .hpp source)
add_library(logging logging.hpp)
# PUBLIC: anything linking `logging` also needs spdlog's headers and symbols.
# The keyword-less form of target_link_libraries has legacy semantics and
# cannot be mixed with keyworded calls on the same target, so be explicit.
target_link_libraries(logging PUBLIC spdlog::spdlog)
# No compilable sources, so CMake cannot infer the linker language
set_target_properties(logging PROPERTIES LINKER_LANGUAGE CXX)
target_include_directories(logging PUBLIC "${SPDLOG_INCLUDE_DIRS}")

set( NT_LOG_LEVEL "INFO" CACHE STRING "the level of detail to log to the console" )

Expand Down
4 changes: 2 additions & 2 deletions nuTens/propagator/const-density-solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ void ConstDensityMatterSolver::calculateEigenvalues(const Tensor &energies, Tens
for (int j = 0; j < nGenerations; j++)
{
hamiltonian.setValue({"...", i, j},
Tensor::div(diagMassMatrix.getValue({i, j}), energies.getValue({"...", 0})) -
electronOuter.getValue({i, j}));
Tensor::div(diagMassMatrix.getValues({i, j}), energies.getValues({"...", 0})) -
electronOuter.getValues({i, j}));
}
}

Expand Down
7 changes: 4 additions & 3 deletions nuTens/propagator/const-density-solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ class ConstDensityMatterSolver : public BaseMatterSolver

// construct the outer product of the electron neutrino row of the PMNS
// matrix used to construct the hamiltonian
electronOuter = Tensor::scale(Tensor::outer(PMNS.getValue({0, 0, "..."}), PMNS.getValue({0, 0, "..."}).conj()),
Constants::Groot2 * density);
electronOuter =
Tensor::scale(Tensor::outer(PMNS.getValues({0, 0, "..."}), PMNS.getValues({0, 0, "..."}).conj()),
Constants::Groot2 * density);
};

/// @brief Set new mass eigenvalues for this solver
Expand All @@ -64,7 +65,7 @@ class ConstDensityMatterSolver : public BaseMatterSolver

masses = newMasses;

Tensor m = masses.getValue({0, "..."});
Tensor m = masses.getValues({0, "..."});
Tensor diag = Tensor::scale(Tensor::mul(m, m), 0.5);

// construct the diagonal mass^2 matrix used in the hamiltonian
Expand Down
2 changes: 1 addition & 1 deletion nuTens/propagator/propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Tensor Propagator::_calculateProbs(const Tensor &energies, const Tensor &massesS
{
for (int j = 0; j < _nGenerations; j++)
{
weightMatrix.setValue({"...", i, j}, weightVector.getValue({"...", j}));
weightMatrix.setValue({"...", i, j}, weightVector.getValues({"...", j}));
}
}
weightMatrix.requiresGrad(true);
Expand Down
4 changes: 2 additions & 2 deletions nuTens/propagator/propagator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Propagator

/// @brief Set a matter solver to use to deal with matter effects
/// @param newSolver A derivative of BaseMatterSolver
inline void setMatterSolver(std::unique_ptr<BaseMatterSolver> &newSolver)
inline void setMatterSolver(std::shared_ptr<BaseMatterSolver> &newSolver)
{
NT_PROFILE();
_matterSolver = std::move(newSolver);
Expand Down Expand Up @@ -113,5 +113,5 @@ class Propagator
int _nGenerations;
float _baseline;

std::unique_ptr<BaseMatterSolver> _matterSolver;
std::shared_ptr<BaseMatterSolver> _matterSolver;
};
2 changes: 2 additions & 0 deletions nuTens/tensors/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ IF(NT_USE_PCH)
target_link_libraries(tensor PUBLIC nuTens-pch)

ELSE()
target_link_libraries(tensor PUBLIC logging)

IF(TORCH_FOUND)

target_link_libraries(tensor PUBLIC "${TORCH_LIBRARIES}")
Expand Down
27 changes: 26 additions & 1 deletion nuTens/tensors/dtypes.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#pragma once

#include <map>

#if USE_PYTORCH
#include <torch/torch.h>
#endif

/*!
* @file dtypes.hpp
* @brief Defines various datatypes used in the project
Expand All @@ -16,13 +20,34 @@ enum scalarType
kDouble,
kComplexFloat,
kComplexDouble,
kUninitScalar,
};

/// Devices that a Tensor can live on
enum deviceType
{
kCPU, // host CPU
kGPU, // GPU; mapped to torch::kCUDA in the pytorch backend maps below
kUninitDevice, // sentinel: device not yet assigned
};

#if USE_PYTORCH
/// map between the data types used in nuTens and those used by pytorch
/// NOTE: the kUninitScalar / kUninitDevice sentinels are deliberately absent
/// from these maps, so a lookup with an uninitialised value via .at() throws
/// (presumably intentional -- TODO confirm callers rely on that).
/// NOTE(review): `const static` maps at namespace scope in a header give each
/// translation unit its own copy; `inline const` (C++17) would share one.
const static std::map<scalarType, c10::ScalarType> scalarTypeMap = {{kFloat, torch::kFloat},
{kDouble, torch::kDouble},
{kComplexFloat, torch::kComplexFloat},
{kComplexDouble, torch::kComplexDouble}};

/// inverse map between the data types used in nuTens and those used by pytorch
const static std::map<c10::ScalarType, scalarType> invScalarTypeMap = {{torch::kFloat, kFloat},
{torch::kDouble, kDouble},
{torch::kComplexFloat, kComplexFloat},
{torch::kComplexDouble, kComplexDouble}};

/// map between the device types used in nuTens and those used by pytorch
const static std::map<deviceType, c10::DeviceType> deviceTypeMap = {{kCPU, torch::kCPU}, {kGPU, torch::kCUDA}};

/// inverse map between the device types used in nuTens and those used by pytorch
const static std::map<c10::DeviceType, deviceType> invDeviceTypeMap = {{torch::kCPU, kCPU}, {torch::kCUDA, kGPU}};
#endif
} // namespace NTdtypes
Loading

0 comments on commit 8b8c1cb

Please sign in to comment.