-
Notifications
You must be signed in to change notification settings - Fork 422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added support for QONNX Resize
node ingestion and tested with tiny UNet model
#1122
Merged
Changes from 9 commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
a761da9
Added support for `Resize` node from QONNX model
nghielme b62468a
Added a test on tiny UNet model in order to test `Resize` node
nghielme 40a431f
pre-commit restyling
nghielme be55945
Aesthetic fix
nghielme 743831f
Second aesthetic fix
nghielme aa46dbe
Merge branch 'main' into resize_pr
nghielme 4f82810
Added one test on a simpler model extracted from UNet model `branched…
nghielme 7e6b9af
Example models commit updated
nghielme 5757ac6
An empty list is now appended to the shape of all the inputs of the c…
nghielme 5e13800
Merge branch 'main' into resize_pr
nghielme cf80f64
Cleaned some code and added the removal of RoI input from `Resize` node
nghielme c7f6983
Merge branch 'resize_pr' of https://github.com/fastmachinelearning/hl…
nghielme a5e32c5
Merge branch 'main' into resize_pr
nghielme b07e998
revert some unneeded changes
jmitrevs 354b535
Added some minor checks related to sizes parameter
nghielme 3b5f8db
Merge branch 'resize_pr' of https://github.com/fastmachinelearning/hl…
nghielme 3254942
Merge branch 'main' into resize_pr
nghielme 9943350
Minor fix
nghielme 5ff517b
Minor modification of the error msg
nghielme 6a10129
Minor fixes
nghielme 20ab44f
Merge branch 'main' into resize_pr
nghielme File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Submodule example-models
updated
from d40894 to 6a82da
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from hls4ml.model.layers import Constant, Resize | ||
from hls4ml.model.optimizer import OptimizerPass | ||
|
||
|
||
class ResizeConstant(OptimizerPass): | ||
""" | ||
To compute the output shape of resize is necessary to access the scales, that | ||
are stored as initilizer, later on converted as constant inputs. | ||
""" | ||
|
||
def match(self, node): | ||
is_match = isinstance(node, Resize) and len(node.inputs) > 1 and node.get_input_node(node.inputs[-1]) | ||
return is_match | ||
|
||
def transform(self, model, node): | ||
""" | ||
Remove Constant from new shape input. Note, input shape node is already used on initialize | ||
""" | ||
scales_node = node.get_input_node(node.inputs[-1]) | ||
node.inputs[-1] = '' | ||
scales_values = scales_node.get_attr('value') | ||
node.set_attr('out_width', int(node.get_attr('in_width') * scales_values[1])) | ||
node.set_attr('out_height', int(node.get_attr('in_height') * scales_values[2])) | ||
if not isinstance(scales_node, Constant): | ||
raise RuntimeError("Non-constant shape inputs are not supported") | ||
model.remove_node(scales_node, rewire=False) | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this identical?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here the output of the tests if you run the code with the original line:
Solved if you update it as proposed.