forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MO support of RDFT and IRDFT (openvinotoolkit#11690)
- Loading branch information
Mateusz Bencer
authored
May 21, 2022
1 parent
714601c
commit a859024
Showing
22 changed files
with
1,131 additions
and
261 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
74 changes: 74 additions & 0 deletions
74
tools/mo/openvino/tools/mo/front/tf/CorrectPaddingsForPadAfterComplex.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright (C) 2018-2022 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array | ||
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph | ||
from openvino.tools.mo.front.subgraph_matcher import SubgraphMatch | ||
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs | ||
from openvino.tools.mo.graph.graph import Graph | ||
from openvino.tools.mo.ops.concat import Concat | ||
|
||
|
||
class CorrectPaddingsForPadAfterComplex(FrontReplacementSubgraph): | ||
""" | ||
There are TF models with the TF operation Complex that has two real tensors as arguments and returns the complex | ||
tensor with real and imaginary parts given as arguments in port 0 and 1 respectively. | ||
Although TF has a native support of complex numbers, OpenVINO doesn't have such support and emulates a complex | ||
tensor with the shape [N_0, ..., N_{r - 1}] as a real tensor of the shape [N_0, ..., N_{r - 1}, 2] interpreting | ||
any complex number as a tuple of the form | ||
(real part, imaginary part) | ||
That is, the emulated complex tensor has the rank r + 1, not r as in the TF model. | ||
Hence, when we convert a subgraph of the form | ||
Complex | ||
| | ||
| | ||
Pad | ||
we should correct pads_begin and pads_end adding zero at the end of pads_begin and pads_end. | ||
The transformation performs such corrections. | ||
""" | ||
enabled = True | ||
|
||
def run_after(self): | ||
from openvino.tools.mo.front.tf.pad_tf_to_pad import PadTFToPad | ||
return [PadTFToPad] | ||
|
||
def pattern(self): | ||
return dict( | ||
nodes=[ | ||
('complex', dict(op='Complex')), | ||
('pad', dict(op='Pad')), | ||
], | ||
edges=[ | ||
('complex', 'pad', {'in': 0}), | ||
]) | ||
|
||
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): | ||
pad_node = match['pad'] | ||
pads_begin_node = pad_node.in_port(1).get_source().node | ||
pads_end_node = pad_node.in_port(2).get_source().node | ||
|
||
pads_begin_node_name = pads_begin_node.soft_get('name', pads_begin_node.id) | ||
pads_end_node_name = pads_end_node.soft_get('name', pads_end_node.id) | ||
|
||
concat_for_pads_begin = create_op_with_const_inputs(graph, Concat, | ||
{1: int64_array([0])}, | ||
{ | ||
'name': pads_begin_node_name + '/additional', | ||
'in_ports_count': 2, | ||
'axis': 0, | ||
}) | ||
concat_for_pads_end = create_op_with_const_inputs(graph, Concat, | ||
{1: int64_array([0])}, | ||
{ | ||
'name': pads_end_node_name + '/additional', | ||
'in_ports_count': 2, | ||
'axis': 0, | ||
}) | ||
pad_node.in_port(1).get_connection().insert_node(concat_for_pads_begin) | ||
pad_node.in_port(2).get_connection().insert_node(concat_for_pads_end) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
66 changes: 66 additions & 0 deletions
66
tools/mo/openvino/tools/mo/front/tf/RFFTRealImagToRFFTSplit.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright (C) 2018-2022 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array | ||
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph | ||
from openvino.tools.mo.front.subgraph_matcher import SubgraphMatch | ||
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs | ||
from openvino.tools.mo.graph.graph import Graph, rename_nodes | ||
from openvino.tools.mo.ops.split import Split | ||
from openvino.tools.mo.ops.squeeze import Squeeze | ||
|
||
|
||
class RFFTRealImagToRDFTSplit(FrontReplacementSubgraph): | ||
""" | ||
This transformation converts the operation TFRFFT into OpenVINO RDFT. | ||
""" | ||
enabled = True | ||
|
||
def run_before(self): | ||
from openvino.tools.mo.front.tf.TFFFTToDFT import TFFFTToDFT | ||
return [TFFFTToDFT] | ||
|
||
def pattern(self): | ||
return dict( | ||
nodes=[ | ||
('rfft', dict(op='TFFFT', fft_kind='RDFT')), | ||
('real', dict(op='Real')), | ||
('imag', dict(op='Imag')), | ||
], | ||
edges=[ | ||
('rfft', 'real', {'in': 0}), | ||
('rfft', 'imag', {'in': 0}), | ||
]) | ||
|
||
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): | ||
rfft_node = match['rfft'] | ||
real_node = match['real'] | ||
imag_node = match['imag'] | ||
|
||
rfft_name = rfft_node.soft_get('name', rfft_node.id) | ||
real_name = rfft_node.soft_get('name', real_node.id) | ||
imag_name = rfft_node.soft_get('name', imag_node.id) | ||
split_node = create_op_with_const_inputs(graph, Split, {1: int64_array(-1)}, | ||
{ | ||
'name': rfft_name + '/split', | ||
'num_splits': 2, | ||
'out_ports_count': 2 | ||
}) | ||
squeeze_real = create_op_with_const_inputs(graph, Squeeze, {1: int64_array(-1)}, | ||
{'name': rfft_name + '/squeeze_real'}) | ||
squeeze_imag = create_op_with_const_inputs(graph, Squeeze, {1: int64_array(-1)}, | ||
{'name': rfft_name + '/squeeze_imag'}) | ||
|
||
split_node.out_port(0).connect(squeeze_real.in_port(0)) | ||
split_node.out_port(1).connect(squeeze_imag.in_port(0)) | ||
real_node.out_port(0).get_connection().set_source(squeeze_real.out_port(0)) | ||
imag_node.out_port(0).get_connection().set_source(squeeze_imag.out_port(0)) | ||
|
||
rfft_node.out_port(0).connect(split_node.in_port(0)) | ||
|
||
rename_nodes([(real_node, real_name + '/to_be_removed'), (squeeze_real, real_name)]) | ||
rename_nodes([(imag_node, imag_name + '/to_be_removed'), (squeeze_imag, imag_name)]) | ||
|
||
real_node.in_port(0).disconnect() | ||
imag_node.in_port(0).disconnect() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright (C) 2018-2022 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from openvino.tools.mo.ops.Complex import Complex | ||
from openvino.tools.mo.front.extractor import FrontExtractorOp | ||
|
||
|
||
class ComplexOpFrontExtractor(FrontExtractorOp): | ||
op = 'Complex' | ||
enabled = True | ||
|
||
@classmethod | ||
def extract(cls, node): | ||
Complex.update_node_stat(node, {}) | ||
return cls.enabled |
Oops, something went wrong.