Skip to content

Commit

Permalink
Merge branch 'develop' into fix_assumption
Browse files Browse the repository at this point in the history
  • Loading branch information
causten authored Sep 5, 2023
2 parents 9f6bacb + 1403517 commit 09f35e7
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/targets/gpu/mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -647,8 +647,8 @@ struct mlir_program
void set_gpu_properties(const context& migraphx_ctx)
{
const auto& device = migraphx_ctx.get_current_device();
target_arch = device.get_device_name();
num_cu = device.get_cu_count();
target_arch = device.get_device_name();
num_cu = device.get_cu_count();
}

std::pair<std::size_t, std::size_t> get_launch_params() const
Expand Down Expand Up @@ -869,15 +869,22 @@ code_object_op compile_mlir(const context& migraphx_ctx,
adjust_param_shapes(m, to_shapes(inputs));
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});

static std::mutex mutex;
if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << m << std::endl;
}

mlir_program mp;
mp.set_gpu_properties(migraphx_ctx);
mp.parse(m);
auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}
auto co = mp.compile(solution);
co.expected_inputs = to_shapes(inputs);
co.output = m.get_output_shapes().front();
Expand Down

0 comments on commit 09f35e7

Please sign in to comment.