Modify reshapes #2099

Merged 61 commits on Sep 27, 2023
Commits
612ad42
Add reshape_lazy.hpp to perform aliasing
TedThemistokleous Jun 30, 2023
979d205
Add reshape lazy to arguments
TedThemistokleous Jul 28, 2023
71e3023
Simplify reshape now that we assume reshape_lazy to perform aliasing
TedThemistokleous Jul 28, 2023
86ec62c
Remove can_strides_merge from reshape
TedThemistokleous Aug 4, 2023
9d55b0f
Always add contiguous to reshapes
TedThemistokleous Aug 4, 2023
a6cfcff
Add matcher for reshape_alias
TedThemistokleous Aug 4, 2023
17a0e9d
auto_contiguous always add contiguous after reshapes
TedThemistokleous Aug 9, 2023
f75d0fd
Backup of changes.
TedThemistokleous Aug 11, 2023
7cff9cf
Add copy in reshape to make reshape by default perform a copy
TedThemistokleous Aug 15, 2023
c7ed62c
Remove auto contiguous for reshape.
TedThemistokleous Aug 15, 2023
cea2f20
Remove contiguous being added with onnx/tf parsing of reshape()
TedThemistokleous Aug 15, 2023
a6565c8
Work in progress. Getting segfaults now with eval
TedThemistokleous Aug 15, 2023
717f733
Update lowering and reshape to work for reshape_copy
TedThemistokleous Aug 16, 2023
7527b13
Add reshape lazy after contiguous branch for try/catch in lowering.
TedThemistokleous Aug 17, 2023
67e7842
Remove contiguous from auto_contiguous_test
TedThemistokleous Aug 18, 2023
de2040b
Add test for reshape_lazy
TedThemistokleous Aug 18, 2023
f2eee04
Update op_shape_test for reshape->reshape_lazy rename
TedThemistokleous Aug 18, 2023
81a679a
Update rewrite_pooling_test for reshape->reshape_lazy
TedThemistokleous Aug 18, 2023
5c1b156
Replace reshape with reshape_lazy in simplify_qdq_test
TedThemistokleous Aug 18, 2023
e7ce5fa
Update ref_ops_test with reshape->reshape_lazy
TedThemistokleous Aug 18, 2023
f744bf9
remove reshape_lazy from argument
TedThemistokleous Aug 22, 2023
6039e18
Set gather_elements to use aliased reshape instead of copy reshape
TedThemistokleous Aug 24, 2023
ee84972
Fix resize/upsample ops to use reshape_lazy
TedThemistokleous Aug 24, 2023
4780ef8
Fix rewrite_pooling_test to include reshape_lazy
TedThemistokleous Aug 24, 2023
c2bf3e9
Fix spacetodeath and depth to space ops for reshape lazy
TedThemistokleous Aug 24, 2023
c14598f
Cleanup unused reshaper_op_names()
TedThemistokleous Aug 25, 2023
ac300ab
Refactor Reshape
TedThemistokleous Aug 25, 2023
20fa6c9
Remove Try-catch for reshape_lazy in lowering
TedThemistokleous Aug 25, 2023
7d37922
Revert "Fix spacetodeath and depth to space ops for reshape lazy"
TedThemistokleous Aug 25, 2023
75c6fec
Revert "Fix resize/upsample ops to use reshape_lazy"
TedThemistokleous Aug 25, 2023
e9bcc23
Revert "Set gather_elements to use aliased reshape instead of copy re…
TedThemistokleous Aug 25, 2023
3916377
Revert "Replace reshape with reshape_lazy in simplify_qdq_test"
TedThemistokleous Aug 25, 2023
a727dbb
Add reshape tests into ref_ops_test
TedThemistokleous Aug 25, 2023
dcac528
Fix #define guard for MIGRAPHX_GUARD_OPERATORS_RESHAPE_LAZY
TedThemistokleous Aug 26, 2023
628580f
clang-Tidy & clang-format cleanup
TedThemistokleous Aug 26, 2023
8a0aeed
Change comment in reshape MIGRAPHX_THROW
TedThemistokleous Aug 28, 2023
65ffb72
Review updates; add contiguous before and after reshape
CharlieL7 Sep 8, 2023
7b46693
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX…
CharlieL7 Sep 11, 2023
2a8407a
Fix insert instruction
CharlieL7 Sep 11, 2023
b84ca27
add comment
CharlieL7 Sep 11, 2023
73910d0
Add back ref_ops tests for reshape
CharlieL7 Sep 11, 2023
949d46b
change return to the last contiguous
CharlieL7 Sep 11, 2023
20a59c6
Merge branch 'develop' into modify_reshapes_develop
TedThemistokleous Sep 12, 2023
e917c77
Fix Format
TedThemistokleous Sep 13, 2023
995f238
Add reshape tests back into op_shape_test
TedThemistokleous Sep 14, 2023
19461d3
Remove extra find_reshape_cont{}
TedThemistokleous Sep 19, 2023
af99bec
Fix how reshape_lazy and contiguous blocks replace reshape operator.
TedThemistokleous Sep 19, 2023
6d19c16
Fix input to lazy reshape to be from contiguous before.
TedThemistokleous Sep 20, 2023
b4bdb71
Add back assertions for reshape and fix tests
TedThemistokleous Sep 20, 2023
fece843
fixup! Fix input to lazy reshape to be from contiguous before.
TedThemistokleous Sep 20, 2023
0a92928
Add reshape_dims() into reshape
TedThemistokleous Sep 20, 2023
aaa76f9
Add test to verify modifying memory correctly via the copy for reshape
TedThemistokleous Sep 20, 2023
f820966
fixup! Add reshape_dims() into reshape
TedThemistokleous Sep 20, 2023
31d3055
[license] Update date in lowering.cpp
TedThemistokleous Sep 20, 2023
bf31df9
Fix format for reshape.hpp
TedThemistokleous Sep 20, 2023
75407f9
Change assert condition in reshape.hpp to >= instead of ==
TedThemistokleous Sep 20, 2023
1a84aa2
Fix comment in reshape
TedThemistokleous Sep 22, 2023
df9b7c0
Change shape on reshape_broadcast_squeeze_memlayout_change in op_shap…
TedThemistokleous Sep 22, 2023
0625bb8
Remove size assert since outputs can change mem layout with nonstanda…
TedThemistokleous Sep 22, 2023
35b67f7
Merge branch 'develop' into modify_reshapes_develop
causten Sep 26, 2023
bed80c2
Merge branch 'develop' into modify_reshapes_develop
causten Sep 27, 2023
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -192,6 +192,7 @@ register_migraphx_ops(
reduce_sum
relu
reshape
reshape_lazy
reverse
rnn
rnn_last_cell_output
1 change: 0 additions & 1 deletion src/auto_contiguous.cpp
@@ -25,7 +25,6 @@
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>

#include <migraphx/iterator_for.hpp>

namespace migraphx {
1 change: 1 addition & 0 deletions src/include/migraphx/instruction.hpp
@@ -81,6 +81,7 @@ struct MIGRAPHX_EXPORT instruction

const std::vector<module_ref>& module_inputs() const;

/// Where this instruction is used as an input to another instruction
const std::vector<instruction_ref>& outputs() const;

friend bool operator==(const instruction& x, const instruction& y);
143 changes: 20 additions & 123 deletions src/include/migraphx/op/reshape.hpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -29,7 +29,8 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp>

#include <algorithm>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -45,8 +46,6 @@ struct reshape
return pack(f(self.dims, "dims"));
}

value attributes() const { return {{"require_std_shape", true}}; }

std::string name() const { return "reshape"; }

shape dyn_compute_shape(shape s0) const
@@ -97,112 +96,6 @@ struct reshape
return {s0.type(), output_dyn_dims};
}

template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x >= dim;
});
if(x != dim)
return start;
return it;
}

template <class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start,
DimIterator dim_last,
StrideIterator stride_start,
StrideIterator stride_last)
{
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
}

// This will reshape the dimensions of the input shape to use the lens of
// `rdims`. If this can't be done without changing memory layout then it
// will return nullopt
static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims)
{
if(input.standard())
return shape{input.type(), rdims};

const auto& idims = input.lens();
const auto& istrides = input.strides();

std::vector<std::size_t> rstrides;
std::size_t i = 0;
std::size_t r = 0;
while(i < idims.size() and r < rdims.size())
{
auto idim = idims[i];
auto rdim = rdims[r];
if(rdim == idim)
{
rstrides.push_back(istrides[i]);
}
// squeeze
else if(rdim > idim)
{
auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim);
if(it == start)
return nullopt;
auto n = it - start;
assert((i + n) <= istrides.size());
if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n + 1))
return nullopt;
i += n;
rstrides.push_back(istrides[i]);
}
// unsqueeze
else // if(rdim < idim)
{
auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim);
if(it == start)
return nullopt;
auto n = it - start;
assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim;
std::for_each(start, it + 1, [&](auto dim) {
stride /= dim;
rstrides.push_back(stride);
});
r += n;
}
i++;
r++;
}

// Handle trailing 1s
if(rstrides.size() < rdims.size() and not rstrides.empty())
{
auto stride = rstrides.back();
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{
if(d != 1)
return nullopt;
rstrides.push_back(stride);
}
}

if(rdims.size() != rstrides.size())
return nullopt;

return shape{input.type(), rdims, rstrides};
}

shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
{
check_shapes{inputs, *this}.has(1);
@@ -232,26 +125,26 @@ struct reshape
}
}

auto s = reshape_dims(inputs.front(), rdims);
if(not s.has_value())
MIGRAPHX_THROW("Reshape on axis that is not packed.");
shape s{inputs.front().type(), rdims};

if(s->elements() != inputs.front().elements())
MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " +
std::to_string(s->elements()) + " elements whereas the input has " +
if(s.elements() != inputs.front().elements())
MIGRAPHX_THROW("reshape: Wrong number of elements for reshape: reshape has " +
std::to_string(s.elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements()));

assert(s->bytes() == inputs.front().bytes());
return *s;
assert(s.bytes() == inputs.front().bytes());
return s;
}

shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);

auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
MIGRAPHX_THROW("reshape: Dimensions for reshape can only have one -1 dim");

auto s0 = inputs.front();
if(s0.dynamic())
{
return dyn_compute_shape(s0);
@@ -264,10 +157,14 @@

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(dyn_out.computed_shape);
}
assert(dyn_out.computed_shape.standard());
argument result{dyn_out.computed_shape};

std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
visit_all(result, args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return result;
}
};

} // namespace op