Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tuning] Interface for Fin #3330

Draft
wants to merge 34 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
bf1cc6d
Add interface for Fin
averinevg Oct 23, 2024
ec2ef18
Fix formatting
averinevg Oct 23, 2024
2b89653
Merge branch 'develop' into ea_interface_for_fin
averinevg Oct 23, 2024
be07add
Fix the description
averinevg Oct 25, 2024
655bb36
Merge branch 'develop' into ea_interface_for_fin
averinevg Oct 25, 2024
3652303
Fix formatting
averinevg Oct 25, 2024
d89c0ce
Fix tidy
averinevg Oct 25, 2024
ee46f04
Fix windows build
averinevg Oct 25, 2024
a56b391
Refactor
averinevg Oct 25, 2024
3a89c66
Fix formatting
averinevg Oct 25, 2024
d0d4463
Merge branch 'develop' into ea_interface_for_fin
averinevg Oct 25, 2024
ee2196f
Fix tidy
averinevg Oct 25, 2024
2a9b9a7
Fix windows build
averinevg Oct 25, 2024
c9abd64
Merge branch 'develop' into ea_interface_for_fin
averinevg Oct 28, 2024
7b47b41
Fix tidy
averinevg Oct 28, 2024
80261ed
Fix windows build
averinevg Oct 28, 2024
a8a432e
Refactor, add GetConvSolvers & GetBatchNormSolvers
averinevg Oct 28, 2024
f57bbac
Fix formatting
averinevg Oct 28, 2024
81ffac5
Fix tidy
averinevg Oct 29, 2024
230acd5
Add test for GetSolvers()
averinevg Oct 29, 2024
0f9bdf2
Refactor
averinevg Oct 29, 2024
61bfaea
Fix linker error
averinevg Oct 29, 2024
8fa4325
Fix formatting
averinevg Oct 29, 2024
da58038
Merge branch 'develop' into ea_interface_for_fin
averinevg Oct 30, 2024
c7525a9
Add test for IsApplicable() & GetWorkspaceSize()
averinevg Oct 30, 2024
6da9a98
Fix formatting
averinevg Oct 30, 2024
9a3591e
Fix windows build
averinevg Oct 31, 2024
41afab4
Merge branch 'develop' into ea_interface_for_fin
averinevg Nov 1, 2024
279d459
Implement methods
averinevg Nov 13, 2024
0ff5960
Fix formatting
averinevg Nov 13, 2024
8fffbc7
Merge branch 'develop' into ea_interface_for_fin
averinevg Nov 13, 2024
9109326
Merge branch 'develop' into ea_interface_for_fin
averinevg Nov 14, 2024
f67b5b5
Add tests for GetAllSolutions(), GetPerfCfgParams() & TestPerfCfgPara…
averinevg Nov 14, 2024
b315df9
Registry: remove smart pointer
averinevg Nov 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ set( MIOpen_Source
env.cpp
execution_context.cpp
expanduser.cpp
fin/fin_interface.cpp
find_controls.cpp
find_db.cpp
fused_api.cpp
Expand Down
286 changes: 286 additions & 0 deletions src/fin/fin_interface.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2024 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/

#include <type_traits>
#include <utility>

#include <miopen/fin/fin_interface.hpp>
#include <miopen/conv/solvers.hpp>
#include <miopen/solver_id.hpp>

namespace miopen {
namespace fin_interface {

// ================== Solver ==================
Solver::Solver(const miopen::solver::SolverBase* solver_base, uint64_t solver_id)
: sbase(solver_base), id(solver_id)
{
if(sbase == nullptr)
MIOPEN_THROW(miopenStatusInternalError);
}

Solver::Solver(const std::string& requested_name) : rname(requested_name) {}

bool Solver::IsValid() const { return sbase != nullptr; }

uint64_t Solver::GetId() const
{
if(sbase == nullptr)
MIOPEN_THROW(miopenStatusNotInitialized);

return id;
}

const std::string& Solver::GetName() const
{
if(sbase != nullptr)
return sbase->SolverDbId();
else
return rname;
}

bool Solver::IsTunable() const
{
if(sbase == nullptr)
MIOPEN_THROW(miopenStatusNotInitialized);

return sbase->IsTunable();
}

bool Solver::IsDynamic() const
{
if(sbase == nullptr)
MIOPEN_THROW(miopenStatusNotInitialized);

return sbase->IsDynamic();
}

// ================== SolverMixin ==================
template <class Context, class Problem>
bool SolverMixin<Context, Problem>::IsApplicable(const Context& ctx, const Problem& problem) const
{
if(sbase == nullptr)
MIOPEN_THROW(miopenStatusNotInitialized);

using SolverInterface = miopen::solver::SolverInterface<Context, Problem>;
return static_cast<const SolverInterface*>(sbase)->IsApplicable(ctx, problem);
}

template <class Context, class Problem>
size_t SolverMixin<Context, Problem>::GetWorkspaceSize(const Context& ctx,
const Problem& problem) const
{
if(sbase == nullptr)
MIOPEN_THROW(miopenStatusNotInitialized);

using SolverInterface = miopen::solver::SolverInterface<Context, Problem>;
return static_cast<const SolverInterface*>(sbase)->GetWorkspaceSize(ctx, problem);
}

template <class Context, class Problem>
miopen::solver::ConvSolution
SolverMixin<Context, Problem>::FindSolution(const Context& ctx,
const Problem& problem,
miopen::PerformanceDb& db,
const miopen::AnyInvokeParams& invoke_ctx,
const std::string& perf_cfg) const
{
std::ignore = ctx;
std::ignore = problem;
std::ignore = db;
std::ignore = invoke_ctx;
std::ignore = perf_cfg;

if(sbase == nullptr)
MIOPEN_THROW(miopenStatusNotInitialized);

/// \todo
MIOPEN_THROW(miopenStatusNotImplemented);
}

template <class Context, class Problem>
std::vector<miopen::solver::ConvSolution>
SolverMixin<Context, Problem>::GetAllSolutions(const Context& ctx, const Problem& problem) const
{
std::ignore = ctx;
std::ignore = problem;

if(sbase == nullptr)
MIOPEN_THROW(miopenStatusNotInitialized);

/// \todo
MIOPEN_THROW(miopenStatusNotImplemented);
}

template <class Context, class Problem>
std::string SolverMixin<Context, Problem>::GetPerfCfgParams(const Context& ctx,
const Problem& problem,
const PerformanceDb& db) const
{
std::ignore = ctx;
std::ignore = problem;
std::ignore = db;

if(sbase == nullptr)
MIOPEN_THROW(miopenStatusNotInitialized);

/// \todo
MIOPEN_THROW(miopenStatusNotImplemented);
}

template <class Context, class Problem>
bool SolverMixin<Context, Problem>::TestPerfCfgParams(const Context& ctx,
const Problem& problem,
const std::string& params) const
{
std::ignore = ctx;
std::ignore = problem;
std::ignore = params;

if(sbase == nullptr)
MIOPEN_THROW(miopenStatusNotInitialized);

/// \todo
MIOPEN_THROW(miopenStatusNotImplemented);
}

// Explicit instantiation
template class SolverMixin<miopen::ExecutionContext, miopen::conv::ProblemDescription>;
template class SolverMixin<miopen::ExecutionContext, miopen::batchnorm::ProblemDescription>;

// ================== ConvSolver ==================
ConvSolver::ConvSolver(const miopen::solver::SolverBase* solver_base,
uint64_t solver_id,
miopenConvAlgorithm_t algo_)
: SolverMixin(solver_base, solver_id), algo(algo_)
{
}

std::string ConvSolver::GetAlgo(miopen::conv::Direction dir) const
{
if(sbase == nullptr)
MIOPEN_THROW(miopenStatusNotInitialized);

return ConvolutionAlgoToDirectionalString(algo, dir);
}

// ================== FinInterface ==================
namespace {

template <class Solver>
struct SolverToPrimitive;

template <>
struct SolverToPrimitive<ConvSolver>
{
static auto GetPrimitive() { return miopen::solver::Primitive::Convolution; }
};

template <>
struct SolverToPrimitive<BatchNormSolver>
{
static auto GetPrimitive() { return miopen::solver::Primitive::Batchnorm; }
};

} // namespace

template <class Solver>
const std::vector<Solver>& GetAllSolvers()
{
static const auto solvers = [] {
const auto& ids = GetSolversByPrimitive(SolverToPrimitive<Solver>::GetPrimitive());
std::vector<Solver> solvers;
solvers.reserve(ids.size());

for(const auto& id : ids)
{
if(!id.IsValid())
MIOPEN_THROW(miopenStatusInternalError);

if constexpr(std::is_same_v<Solver, ConvSolver>)
solvers.emplace_back(Solver{id.GetSolverBase(), id.Value(), id.GetAlgo()});
else
solvers.emplace_back(Solver{id.GetSolverBase(), id.Value()});
}

return solvers;
}();
return solvers;
}

template <class Solver>
Solver GetSolver(const std::string& name)
{
const auto id = miopen::solver::Id{name};
if(!id.IsValid())
return {name};

if constexpr(std::is_same_v<Solver, ConvSolver>)
return {id.GetSolverBase(), id.Value(), id.GetAlgo()};
else
return {id.GetSolverBase(), id.Value()};
}

namespace {

template <class Solver>
std::vector<Solver> GetSolvers(const std::vector<std::string>& names)
{
std::vector<Solver> solvers;
solvers.reserve(names.size());
for(const auto& name : names)
solvers.emplace_back(GetSolver<Solver>(name));
return solvers;
}

} // namespace

const std::vector<ConvSolver>& GetAllConvSolvers() { return GetAllSolvers<ConvSolver>(); }

std::vector<ConvSolver> GetConvSolvers(const std::vector<std::string>& names)
{
return GetSolvers<ConvSolver>(names);
}

ConvSolver GetConvSolver(const std::string& name) { return GetSolver<ConvSolver>(name); }

const std::vector<BatchNormSolver>& GetAllBatchNormSolvers()
{
return GetAllSolvers<BatchNormSolver>();
}

std::vector<BatchNormSolver> GetBatchNormSolvers(const std::vector<std::string>& names)
{
return GetSolvers<BatchNormSolver>(names);
}

BatchNormSolver GetBatchNormSolver(const std::string& name)
{
return GetSolver<BatchNormSolver>(name);
}

} // namespace fin_interface
} // namespace miopen
Loading