Skip to content

Commit

Permalink
Fix input shape for matmulnbits to fold input via reshape to 1d input
Browse files Browse the repository at this point in the history
  • Loading branch information
TedThemistokleous committed Dec 9, 2024
1 parent 4b15b6c commit 88472af
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/onnx/parse_matmulnbits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,15 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits>
". Actual dims: " + to_string_range(args[1]->get_shape().lens()));

std::vector<size_t> expected_scales_lens{n * n_blocks_per_col};
if(args[2]->get_shape().lens() != expected_scales_lens)

// Reshape anything larger than 1 dimension into a 1d tensor so we can check if we have the right amount of elements.
auto scale_input = args[2];
if(scale_input->get_shape().lens().size() > 1)
{
scale_input = info.add_instruction(make_op("reshape", {{"dims", {scale_input->get_shape().elements()}}}), scale_input);
}

if(scale_input->get_shape().lens() != expected_scales_lens)
MIGRAPHX_THROW("MatMulNBits: Input scales does not match expected dims: " +
to_string_range(expected_scales_lens) +
". Actual dims: " + to_string_range(args[2]->get_shape().lens()));
Expand Down

0 comments on commit 88472af

Please sign in to comment.