Skip to content

Commit

Permalink
MO support of RDFT and IRDFT (openvinotoolkit#11690)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mateusz Bencer authored May 21, 2022
1 parent 714601c commit a859024
Show file tree
Hide file tree
Showing 22 changed files with 1,131 additions and 261 deletions.
6 changes: 5 additions & 1 deletion tools/mo/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -506,12 +506,14 @@ openvino/tools/mo/front/tf/Cast_ext.py
openvino/tools/mo/front/tf/ClipByValue_ext.py
openvino/tools/mo/front/tf/ClipByValueTFTransformation.py
openvino/tools/mo/front/tf/common.py
openvino/tools/mo/front/tf/complex_ext.py
openvino/tools/mo/front/tf/ComplexAbs.py
openvino/tools/mo/front/tf/ComplexAbsAfterComplex.py
openvino/tools/mo/front/tf/concat.py
openvino/tools/mo/front/tf/concat_ext.py
openvino/tools/mo/front/tf/const_ext.py
openvino/tools/mo/front/tf/conv_ext.py
openvino/tools/mo/front/tf/CorrectPaddingsForPadAfterComplex.py
openvino/tools/mo/front/tf/CorrectRollAxes.py
openvino/tools/mo/front/tf/crop_and_resize_ext.py
openvino/tools/mo/front/tf/CropAndResizeReplacement.py
Expand Down Expand Up @@ -621,6 +623,7 @@ openvino/tools/mo/front/tf/rfcn_support.json
openvino/tools/mo/front/tf/rfcn_support_api_v1.10.json
openvino/tools/mo/front/tf/rfcn_support_api_v1.13.json
openvino/tools/mo/front/tf/rfcn_support_api_v1.14.json
openvino/tools/mo/front/tf/RFFTRealImagToRFFTSplit.py
openvino/tools/mo/front/tf/roll_ext.py
openvino/tools/mo/front/tf/RollRealImagPack.py
openvino/tools/mo/front/tf/scatter_nd_ext.py
Expand All @@ -647,7 +650,6 @@ openvino/tools/mo/front/tf/ssd_toolbox_detection_output.json
openvino/tools/mo/front/tf/ssd_toolbox_multihead_detection_output.json
openvino/tools/mo/front/tf/ssd_v2_support.json
openvino/tools/mo/front/tf/SSDToolboxDetectionOutput.py
openvino/tools/mo/front/tf/SSliceComplex.py
openvino/tools/mo/front/tf/swap_deconv_inputs.py
openvino/tools/mo/front/tf/swish_ext.py
openvino/tools/mo/front/tf/SwitchMergeOptimization.py
Expand Down Expand Up @@ -813,6 +815,7 @@ openvino/tools/mo/middle/SliceLikeToStridedSlice.py
openvino/tools/mo/middle/sparse_reshape.py
openvino/tools/mo/middle/split_tdnn_memoryoffset.py
openvino/tools/mo/middle/SplitConcatPairToInterpolate.py
openvino/tools/mo/middle/SSliceComplex.py
openvino/tools/mo/middle/StridedSliceNormalizer.py
openvino/tools/mo/middle/SwapAxesMiddleReplacer.py
openvino/tools/mo/middle/TensorIterator_utils.py
Expand Down Expand Up @@ -862,6 +865,7 @@ openvino/tools/mo/ops/bucketize.py
openvino/tools/mo/ops/Cast.py
openvino/tools/mo/ops/clamp.py
openvino/tools/mo/ops/ClipByValueTF.py
openvino/tools/mo/ops/Complex.py
openvino/tools/mo/ops/concat.py
openvino/tools/mo/ops/const.py
openvino/tools/mo/ops/constant_fill.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
# SPDX-License-Identifier: Apache-2.0


from openvino.tools.mo.ops.elementwise import Add, Pow
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph
from openvino.tools.mo.front.subgraph_matcher import SubgraphMatch
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
from openvino.tools.mo.graph.graph import Graph, rename_nodes
from openvino.tools.mo.middle.passes.convert_data_type import data_type_str_to_np
from openvino.tools.mo.ops.elementwise import Add, Pow


class ComplexAbsAfterComplex(FrontReplacementSubgraph):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph
from openvino.tools.mo.front.subgraph_matcher import SubgraphMatch
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
from openvino.tools.mo.graph.graph import Graph
from openvino.tools.mo.ops.concat import Concat


class CorrectPaddingsForPadAfterComplex(FrontReplacementSubgraph):
"""
There are TF models with the TF operation Complex that has two real tensors as arguments and returns the complex
tensor with real and imaginary parts given as arguments in port 0 and 1 respectively.
Although TF has a native support of complex numbers, OpenVINO doesn't have such support and emulates a complex
tensor with the shape [N_0, ..., N_{r - 1}] as a real tensor of the shape [N_0, ..., N_{r - 1}, 2] interpreting
any complex number as a tuple of the form
(real part, imaginary part)
That is, the emulated complex tensor has the rank r + 1, not r as in the TF model.
Hence, when we convert a subgraph of the form
Complex
|
|
Pad
we should correct pads_begin and pads_end adding zero at the end of pads_begin and pads_end.
The transformation performs such corrections.
"""
enabled = True

def run_after(self):
from openvino.tools.mo.front.tf.pad_tf_to_pad import PadTFToPad
return [PadTFToPad]

def pattern(self):
return dict(
nodes=[
('complex', dict(op='Complex')),
('pad', dict(op='Pad')),
],
edges=[
('complex', 'pad', {'in': 0}),
])

def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
pad_node = match['pad']
pads_begin_node = pad_node.in_port(1).get_source().node
pads_end_node = pad_node.in_port(2).get_source().node

pads_begin_node_name = pads_begin_node.soft_get('name', pads_begin_node.id)
pads_end_node_name = pads_end_node.soft_get('name', pads_end_node.id)

concat_for_pads_begin = create_op_with_const_inputs(graph, Concat,
{1: int64_array([0])},
{
'name': pads_begin_node_name + '/additional',
'in_ports_count': 2,
'axis': 0,
})
concat_for_pads_end = create_op_with_const_inputs(graph, Concat,
{1: int64_array([0])},
{
'name': pads_end_node_name + '/additional',
'in_ports_count': 2,
'axis': 0,
})
pad_node.in_port(1).get_connection().insert_node(concat_for_pads_begin)
pad_node.in_port(2).get_connection().insert_node(concat_for_pads_end)
12 changes: 4 additions & 8 deletions tools/mo/openvino/tools/mo/front/tf/CorrectRollAxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,13 @@

class CorrectRollAxes(FrontReplacementSubgraph):
"""
The transformation SSliceComplex removes 2 StridedSlice and Complex operation. If the Roll node is a consumer
of Complex node in the original TF model, then we have a real input tensor for Roll instead of a complex.
Negative axes values for the Roll operation should be updated to reflect the fact that the rank of input tensor was
increased by one (a new trailing dimension of size 2 containing real and imaginary part of complex number is added).
If the Roll node is a consumer of Complex node in the original TF model, then we have a real input tensor for Roll
instead of a complex. Negative axes values for the Roll operation should be updated to reflect the fact that the
rank of input tensor was increased by one (a new trailing dimension of size 2 containing real and imaginary part
of complex number is added).
"""
enabled = True

def run_after(self):
from openvino.tools.mo.front.tf.SSliceComplex import SSliceComplex
return [SSliceComplex]

def find_and_replace_pattern(self, graph: Graph):
for roll in graph.get_op_nodes(op='Roll', input_rank_changed=True):
add_constant_to_negative_values(roll, 2, int64_array(-1))
Expand Down
66 changes: 66 additions & 0 deletions tools/mo/openvino/tools/mo/front/tf/RFFTRealImagToRFFTSplit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph
from openvino.tools.mo.front.subgraph_matcher import SubgraphMatch
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
from openvino.tools.mo.graph.graph import Graph, rename_nodes
from openvino.tools.mo.ops.split import Split
from openvino.tools.mo.ops.squeeze import Squeeze


class RFFTRealImagToRDFTSplit(FrontReplacementSubgraph):
"""
This transformation converts the operation TFRFFT into OpenVINO RDFT.
"""
enabled = True

def run_before(self):
from openvino.tools.mo.front.tf.TFFFTToDFT import TFFFTToDFT
return [TFFFTToDFT]

def pattern(self):
return dict(
nodes=[
('rfft', dict(op='TFFFT', fft_kind='RDFT')),
('real', dict(op='Real')),
('imag', dict(op='Imag')),
],
edges=[
('rfft', 'real', {'in': 0}),
('rfft', 'imag', {'in': 0}),
])

def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
rfft_node = match['rfft']
real_node = match['real']
imag_node = match['imag']

rfft_name = rfft_node.soft_get('name', rfft_node.id)
real_name = rfft_node.soft_get('name', real_node.id)
imag_name = rfft_node.soft_get('name', imag_node.id)
split_node = create_op_with_const_inputs(graph, Split, {1: int64_array(-1)},
{
'name': rfft_name + '/split',
'num_splits': 2,
'out_ports_count': 2
})
squeeze_real = create_op_with_const_inputs(graph, Squeeze, {1: int64_array(-1)},
{'name': rfft_name + '/squeeze_real'})
squeeze_imag = create_op_with_const_inputs(graph, Squeeze, {1: int64_array(-1)},
{'name': rfft_name + '/squeeze_imag'})

split_node.out_port(0).connect(squeeze_real.in_port(0))
split_node.out_port(1).connect(squeeze_imag.in_port(0))
real_node.out_port(0).get_connection().set_source(squeeze_real.out_port(0))
imag_node.out_port(0).get_connection().set_source(squeeze_imag.out_port(0))

rfft_node.out_port(0).connect(split_node.in_port(0))

rename_nodes([(real_node, real_name + '/to_be_removed'), (squeeze_real, real_name)])
rename_nodes([(imag_node, imag_name + '/to_be_removed'), (squeeze_imag, imag_name)])

real_node.in_port(0).disconnect()
imag_node.in_port(0).disconnect()
4 changes: 0 additions & 4 deletions tools/mo/openvino/tools/mo/front/tf/RollRealImagPack.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ class RollRealImagPack(FrontReplacementSubgraph):
"""
enabled = True

def run_after(self):
from openvino.tools.mo.front.tf.SSliceComplex import SSliceComplex
return [SSliceComplex]

def run_before(self):
from openvino.tools.mo.front.Pack import Pack
return [Pack]
Expand Down
70 changes: 0 additions & 70 deletions tools/mo/openvino/tools/mo/front/tf/SSliceComplex.py

This file was deleted.

28 changes: 22 additions & 6 deletions tools/mo/openvino/tools/mo/front/tf/TFFFTToDFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@
# SPDX-License-Identifier: Apache-2.0


from openvino.tools.mo.ops.dft import DFT, IDFT
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.front.common.replacement import FrontReplacementSubgraph
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
from openvino.tools.mo.graph.graph import Graph, rename_nodes
from openvino.tools.mo.ops.dft import DFT, IDFT, IRDFT, RDFT


class TFFFTToDFT(FrontReplacementSubgraph):
"""
This transformation converts the operation TFFFT into OpenVINO DFT (if the attribute 'is_inverse' is False),
or into OpenVINO IDFT (otherwise).
This transformation converts the operation TFFFT into OpenVINO operations DFT, RDFT, IDFT, or IRDFT,
according to the following rules:
1) FFT, FFT2D, FFT3D are converted into DFT;
2) IFFT, IFFT2D, IFFT3D are converted into IDFT;
3) RFFT, RFFT2D, RFFT3D are converted into RDFT;
4) IRFFT, IRFFT2D, IRFFT3D are converted into IRDFT.
"""
enabled = True

Expand All @@ -26,9 +30,21 @@ def find_and_replace_pattern(self, graph: Graph):

num_of_dims = tf_fft.soft_get('num_of_dimensions', 1)
axes = int64_array(range(-num_of_dims, 0))
op = IDFT if tf_fft.soft_get('is_inverse', False) else DFT
dft_node = create_op_with_const_inputs(graph, op, {1: axes}, {'in_ports_count': 2},
tf_fft.in_port(0).get_source().node)

fft_kind = tf_fft['fft_kind']
assert fft_kind in ['DFT', 'IDFT', 'RDFT', 'IRDFT'], \
'Node {} with the operation TFFFT supports only the following FFT-like operations: ' \
'DFT, IDFT, RDFT, IRDFT. Got: {}'.format(tf_fft_name, fft_kind)

op = {'DFT': DFT, 'IDFT': IDFT, 'RDFT': RDFT, 'IRDFT': IRDFT}[fft_kind]

if fft_kind in ['DFT', 'IDFT'] or not tf_fft.is_in_port_connected(1):
dft_node = create_op_with_const_inputs(graph, op, {1: axes}, {'in_ports_count': 2},
tf_fft.in_port(0).get_source().node)
else:
dft_node = create_op_with_const_inputs(graph, op, {1: axes}, {'in_ports_count': 3},
tf_fft.in_port(0).get_source().node)
tf_fft.in_port(1).get_source().connect(dft_node.in_port(2))

tf_fft.out_port(0).get_connection().set_source(dft_node.out_port(0))

Expand Down
15 changes: 15 additions & 0 deletions tools/mo/openvino/tools/mo/front/tf/complex_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from openvino.tools.mo.ops.Complex import Complex
from openvino.tools.mo.front.extractor import FrontExtractorOp


class ComplexOpFrontExtractor(FrontExtractorOp):
op = 'Complex'
enabled = True

@classmethod
def extract(cls, node):
Complex.update_node_stat(node, {})
return cls.enabled
Loading

0 comments on commit a859024

Please sign in to comment.