Skip to content

Commit

Permalink
operators: Add a cond operator.
Browse files Browse the repository at this point in the history
This adds a cond operator. An example of using it:

```python
x0 = Tensor(data_layout=types_pb2.N, tensor_data=np.array([2])
x1 = Tensor(data_layout=types_pb2.N, tensor_data=np.array([5])
y = Tensor(data_layout=types_pb2.N, tensor_data=np.array([10])
z = Tensor(data_layout=types_pb2.N, tensor_data=np.array([20])
res = control_flow_ops.cond(math_ops.less(x0, x1),
    lambda: math_ops.add(y, z), lambda: math_ops.mul(y, z))
```

Nesting is also supported.

TESTED=unit

Change-Id: I51e9983093a49a49d269a6785b38842007d4fd29
  • Loading branch information
yaoyuannnn committed Jul 29, 2020
1 parent b2c294c commit 1050e19
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 46 deletions.
1 change: 1 addition & 0 deletions make/Makefile.common
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ PY_TESTS = smaug/python/tensor_test.py \
smaug/python/ops/fp_precision_test.py \
smaug/python/ops/data_op_test.py \
smaug/python/ops/activation_ops_test.py \
smaug/python/ops/control_flow_ops_test.py \
smaug/python/ops/recurrent_test.py \
smaug/python/ops/attention_test.py

Expand Down
38 changes: 26 additions & 12 deletions smaug/python/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,35 +118,49 @@ def add_node(

return output_tensors

def get_node(self, node_name):
"""Return a node in the graph proto by its name."""
def get_node(self, node_name, recursive=False):
"""Return a node in the graph proto by its name.
Args:
node_name: Node name.
recursive: If true, recursively search the node in the parent graphs.
Returns:
A NodeProto if we find the node.
"""
for i in range(len(self.graph.nodes)):
if self.graph.nodes[i].name == node_name:
return self.graph.nodes[i]
return None
if recursive and self._parent_graph is not None:
return self._parent_graph.get_node(node_name, True)

def get_nodes(self):
"""Return nodes in the graph proto."""
return self.graph.nodes

def create_unique_name(self, name, mark_as_used=True):
def get_root_graph(self):
"""Return the root graph."""
root = self
while root._parent_graph is not None:
root = root._parent_graph
return root

def create_unique_name(self, name):
""" Create a unique name for the node.
Args:
name: The base name used to create the unique name.
mark_as_used: Mark the unique name as used so if someone wants to call
create_unique_name(unique_name), a different name will be created.
"""
root = self.get_root_graph()
new_name = name
if name in self._node_names:
if name in root._node_names:
while True:
self._node_names[name] += 1
new_name = "%s_%d" % (name, self._node_names[name])
root._node_names[name] += 1
new_name = "%s_%d" % (name, root._node_names[name])
# Make sure the new name is not already used.
if new_name not in self._node_names:
if new_name not in root._node_names:
break
if mark_as_used:
self._node_names[new_name] = 0
root._node_names[new_name] = 0
return new_name

def disable_layout_transform(self):
Expand Down
110 changes: 99 additions & 11 deletions smaug/python/ops/control_flow_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from smaug.core.types_pb2 import *
from smaug.python.global_vars import *
from smaug.python.ops.common import *
from smaug.core import types_pb2
from smaug.python import global_vars
from smaug.python import tensor_utils
from smaug.python.graph import Graph
from smaug.python.ops import common

switch_op_output_ports = {"true": 1, "false": 0}

def switch(input_tensor, pred, name="switch"):
"""Forward the input to output port determined by the given predication.
Expand All @@ -14,10 +18,8 @@ def switch(input_tensor, pred, name="switch"):
output_false, output_true: Two tensors representing the two branches of the
switch. Input will only be forwarded to the taken branch.
"""
return add_node(
name=name,
op=Switch,
input_tensors=[input_tensor, pred],
return common.add_node(
name=name, op=types_pb2.Switch, input_tensors=[input_tensor, pred],
output_tensors_dims=[input_tensor.shape.dims] * 2,
output_tensor_layout=input_tensor.shape.layout)

Expand All @@ -30,9 +32,95 @@ def merge(input_tensors, name="merge"):
Returns:
A tensor that the available input tensor forwards to.
"""
return add_node(
name=name,
op=Merge,
input_tensors=input_tensors,
return common.add_node(
name=name, op=types_pb2.Merge, input_tensors=input_tensors,
output_tensors_dims=[input_tensors[0].shape.dims],
output_tensor_layout=input_tensors[0].shape.layout)[0]

def cond(predication, true_fn, false_fn, name="cond"):
"""A conditional operator.
This operator provides the capability of doing if-else statement. Depending on
the predication value, either the True or the False body of the operator will
be executed.
Args:
predication: A predication tensor of value 0 or 1, determining which path to
execute.
true_fn: The callable to be performed if `predication` is 1.
false_fn: The callable to be performed if `predication` is 0.
Returns:
The tensors returned by either true_fn or false_fn.
"""

def _insert_switch_nodes(predication, branch_result, graph):
"""Insert switch nodes for external tensors in the subgraph.
An external tensor is a tensor that comes from a node outside this graph,
this adds switch nodes for every external tensor in `graph`.
Args:
predication: The predication tensor used for determining the deadness of
switch node results.
branch_result: String value of "true" or "false", representing which
result of the switch nodes to use.
graph: A `GraphProto` that represents a branch of the conditional.
"""
if branch_result not in ["true", "false"]:
raise ValueError(
"Use either 'true' or 'false' to indicate the output of the switch "
"nodes.")
nodes = [node for node in graph.get_nodes() if node.op != types_pb2.Data]
# This keeps track of all the tensors that come from nodes in the graph.
internal_tensors = set()
for node in nodes:
internal_tensors.update(
set([tensor.name for tensor in node.output_tensors]))
for node in nodes:
for i, tensor_proto in enumerate(node.input_tensors):
# If any input tensor of the graph appear in the graph workspace, then
# this tensor is an external to the graph and we create a switch node
# for it.
# Don't create switch node for an existing one.
if node.op == types_pb2.Switch:
continue
if tensor_proto.name not in internal_tensors:
source_node = graph.get_node(node.parents[i], True)
tensor = tensor_utils.from_tensor_proto(tensor_proto)
if source_node is not None:
tensor.source = (source_node, node.src_tensors_indices[i])
switch_result = switch(
tensor, predication)[switch_op_output_ports[branch_result]]
# Update the node with the switch node as its new parent.
switch_result.to_tensor_proto(node.input_tensors[i])
switch_node = switch_result.source[0]
node.parents[i] = switch_node.name
node.src_tensors_indices[i] = switch_op_output_ports[branch_result]

cur_graph = global_vars.get_graph()
backend = cur_graph.backend
mem_policy = cur_graph.mem_policy
name = cur_graph.create_unique_name(name)

# Build the subgraph for the true branch.
with Graph(name="%s_true_branch" % name, backend=backend,
mem_policy=mem_policy) as subgraph_t:
res_t = true_fn()
if not isinstance(res_t, (list, tuple)):
res_t = [res_t]
_insert_switch_nodes(predication, "true", subgraph_t)
cur_graph.merge(subgraph_t)

# Build the subgraph for the false branch.
with Graph(name="%s_false_branch" % name, backend=backend,
mem_policy=mem_policy) as subgraph_f:
res_f = false_fn()
if not isinstance(res_f, (list, tuple)):
res_f = [res_f]
_insert_switch_nodes(predication, "false", subgraph_f)
cur_graph.merge(subgraph_f)

# Add the merge nodes for the outputs.
merges = [merge([t, f]) for (t, f) in zip(res_t, res_f)]
return merges
145 changes: 145 additions & 0 deletions smaug/python/ops/control_flow_ops_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#!/usr/bin/env python

import unittest
import numpy as np

from smaug.python.smaug_test import SmaugTest
from smaug.python import global_vars
from smaug.python.graph import Graph
from smaug.python.tensor import Tensor
from smaug.python.ops import math_ops
from smaug.python.ops import data_op
from smaug.python.ops import control_flow_ops
from smaug.core import types_pb2

class ControlFlowOpsTest(SmaugTest):
def test_cond_op_simple_func(self):
with Graph(name=self.graph_name, backend=self.backend) as graph:
x0 = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([2], dtype=self.dtype))
x1 = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([5], dtype=self.dtype))
y = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([10], dtype=self.dtype))
z = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([20], dtype=self.dtype))
expected_res = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([30], dtype=self.dtype))
# res = y + z if x0 < x1 else y * z
res = control_flow_ops.cond(
math_ops.less(x0, x1), lambda: math_ops.add(y, z),
lambda: math_ops.mul(y, z))
self.runAndValidate(graph, expected_res.tensor_data)

def test_cond_op_func_call(self):
def func(a, b):
minus_three = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([-3], dtype=self.dtype))
return math_ops.add(a, math_ops.mul(b, minus_three))

with Graph(name=self.graph_name, backend=self.backend) as graph:
x0 = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([2], dtype=self.dtype))
x1 = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([5], dtype=self.dtype))
y = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([10], dtype=self.dtype))
z = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([20], dtype=self.dtype))
expected_res = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([-50],
dtype=self.dtype))
# res = y - 3z if x0 < x1 else y * z
res = control_flow_ops.cond(
math_ops.less(x0, x1), lambda: func(y, z), lambda: math_ops.mul(y, z))
self.runAndValidate(graph, expected_res.tensor_data)

def test_nested_cond_ops(self):
def func_true(a, b):
minus_one = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([-1], dtype=self.dtype))
return control_flow_ops.cond(
math_ops.less(a, b),
lambda: math_ops.add(a, math_ops.mul(b, minus_one)),
lambda: math_ops.add(a, b))

def func_false(a, b):
two = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([2], dtype=self.dtype))
return control_flow_ops.cond(
math_ops.greater(a, b), lambda: math_ops.mul(a, two),
lambda: math_ops.mul(b, two))

with Graph(name=self.graph_name, backend=self.backend) as graph:
x0 = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([2], dtype=self.dtype))
x1 = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([5], dtype=self.dtype))
y = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([10], dtype=self.dtype))
z = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([20], dtype=self.dtype))
expected_res = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([40], dtype=self.dtype))
# if x0 > x1:
# if y < z:
# res = y - z
# else:
# res = y + z
# else:
# if y > z:
# res = 2y
# else:
# res = 2z
res = control_flow_ops.cond(
math_ops.greater(x0, x1), lambda: func_true(y, z),
lambda: func_false(y, z))
self.runAndValidate(graph, expected_res.tensor_data)

def test_use_nested_op_result(self):
def func_true(a, b):
minus_one = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([-1], dtype=self.dtype))
res = control_flow_ops.cond(
math_ops.less(a, b),
lambda: math_ops.add(a, math_ops.mul(b, minus_one)),
lambda: math_ops.add(a, b))[0]
# Use the cond results before returning.
return math_ops.mul(res, res)

def func_false(a, b):
two = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([2], dtype=self.dtype))
return control_flow_ops.cond(
math_ops.greater(a, b), lambda: math_ops.mul(a, two),
lambda: math_ops.mul(b, two))

with Graph(name=self.graph_name, backend=self.backend) as graph:
x0 = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([2], dtype=self.dtype))
x1 = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([5], dtype=self.dtype))
y = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([10], dtype=self.dtype))
z = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([20], dtype=self.dtype))
expected_res = Tensor(
data_layout=types_pb2.N, tensor_data=np.array([100],
dtype=self.dtype))
# if x0 < x1:
# if y < z:
# res = (y - z) ^ 2
# else:
# res = y + z
# else:
# if y > z:
# res = 2y
# else:
# res = 2z
res = control_flow_ops.cond(
math_ops.less(x0, x1), lambda: func_true(y, z),
lambda: func_false(y, z))
self.runAndValidate(graph, expected_res.tensor_data)

if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions smaug/python/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def calc_padding(self, value):
return 0
return (self.shape.alignment - (value % self.shape.alignment))

def to_tensor_proto(self, tensor_proto, tensor_data_array):
def to_tensor_proto(self, tensor_proto, tensor_data_array=None):
"""Serialize the tensor into a tensor proto.
Args:
Expand All @@ -90,7 +90,7 @@ def to_tensor_proto(self, tensor_proto, tensor_data_array):
tensor_proto.shape.CopyFrom(self.shape)
tensor_proto.data_type = self.data_type
tensor_proto.data_format = self.data_format
if self.tensor_data is not None:
if self.tensor_data is not None and tensor_data_array is not None:

# Since Protobuf doesn't support float16 data type, we pack two float16
# elements into one int32.
Expand Down
21 changes: 0 additions & 21 deletions smaug/python/unique_name_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,26 +62,5 @@ def test_user_supplied_names2(self):
res = math_ops.add(res, res, name="add_1")
self.assertEqual(get_node_names(test_graph), {"add", "add_1", "add_1_1"})

def test_user_supplied_names3(self):
with Graph(graph_name, backend) as test_graph:
res = math_ops.add(x, y, name="add")
unique_name = test_graph.create_unique_name("add", mark_as_used=False)
res = math_ops.add(res, res, name=unique_name)
self.assertEqual(get_node_names(test_graph), {"add", "add_1"})

def test_user_supplied_names4(self):
with Graph(graph_name, backend) as test_graph:
res = math_ops.add(x, y, name="add")
unique_name = test_graph.create_unique_name("add", mark_as_used=False)
res = math_ops.add(res, res, name=unique_name)
res = math_ops.add(res, res, name="add")
self.assertEqual(get_node_names(test_graph), {"add", "add_1", "add_2"})

def test_user_supplied_names5(self):
with Graph(graph_name, backend) as test_graph:
unique_name = test_graph.create_unique_name("add", mark_as_used=False)
res = math_ops.add(x, y, name=unique_name)
self.assertEqual(get_node_names(test_graph), {"add"})

if __name__ == "__main__":
unittest.main()

0 comments on commit 1050e19

Please sign in to comment.