Skip to content

Commit

Permalink
Add is none function (#1757)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Aug 29, 2023
1 parent 4047d99 commit 25f9863
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 9 deletions.
2 changes: 1 addition & 1 deletion doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ flask==2.3.2
# via mlflow
flatbuffers==23.5.26
# via tensorflow
flyteidl==1.5.14
flyteidl==1.5.16
# via flytekit
fonttools==4.42.0
# via matplotlib
Expand Down
2 changes: 2 additions & 0 deletions flytekit/core/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,8 @@ def transform_to_conj_expr(
def transform_to_operand(v: Union[Promise, Literal]) -> Tuple[_core_cond.Operand, Optional[Promise]]:
if isinstance(v, Promise):
return _core_cond.Operand(var=create_branch_node_promise_var(v.ref.node_id, v.var)), v
if v.scalar.none_type:
return _core_cond.Operand(scalar=v.scalar), None
return _core_cond.Operand(primitive=v.scalar.primitive), None


Expand Down
27 changes: 24 additions & 3 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,26 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr
self._lhs = lhs
if lhs.is_ready:
if lhs.val.scalar is None or lhs.val.scalar.primitive is None:
raise ValueError("Only primitive values can be used in comparison")
union = lhs.val.scalar.union
if union and union.value.scalar:
if union.value.scalar.primitive or union.value.scalar.none_type:
self._lhs = union.value
else:
raise ValueError("Only primitive values can be used in comparison")
else:
raise ValueError("Only primitive values can be used in comparison")
if isinstance(rhs, Promise):
self._rhs = rhs
if rhs.is_ready:
if rhs.val.scalar is None or rhs.val.scalar.primitive is None:
raise ValueError("Only primitive values can be used in comparison")
union = rhs.val.scalar.union
if union and union.value.scalar:
if union.value.scalar.primitive or union.value.scalar.none_type:
self._rhs = union.value
else:
raise ValueError("Only primitive values can be used in comparison")
else:
raise ValueError("Only primitive values can be used in comparison")
if self._lhs is None:
self._lhs = type_engine.TypeEngine.to_literal(FlyteContextManager.current_context(), lhs, type(lhs), None)
if self._rhs is None:
Expand All @@ -163,11 +177,15 @@ def op(self) -> ComparisonOps:
def eval(self) -> bool:
if isinstance(self.lhs, Promise):
lhs = self.lhs.eval()
elif self.lhs.scalar.none_type:
lhs = None
else:
lhs = get_primitive_val(self.lhs.scalar.primitive)

if isinstance(self.rhs, Promise):
rhs = self.rhs.eval()
elif self.rhs.scalar.none_type:
rhs = None
else:
rhs = get_primitive_val(self.rhs.scalar.primitive)

Expand Down Expand Up @@ -351,9 +369,12 @@ def is_(self, v: bool) -> ComparisonExpression:
def is_false(self) -> ComparisonExpression:
return self.is_(False)

def is_true(self):
def is_true(self) -> ComparisonExpression:
return self.is_(True)

def is_none(self) -> ComparisonExpression:
return ComparisonExpression(self, ComparisonOps.EQ, None)

def __eq__(self, other) -> ComparisonExpression: # type: ignore
return ComparisonExpression(self, ComparisonOps.EQ, other)

Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type
"actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
"return v.x, instead of v, even if this has a single element"
)
if python_val is None and expected.union_type is None:
if python_val is None and expected and expected.union_type is None:
raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")
transformer = cls.get_transformer(python_type)
if transformer.type_assertions_enabled:
Expand Down
18 changes: 15 additions & 3 deletions flytekit/models/core/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,17 @@ def from_flyte_idl(cls, pb2_object):


class Operand(_common.FlyteIdlEntity):
def __init__(self, primitive=None, var=None):
def __init__(self, primitive=None, var=None, scalar=None):
"""
Defines an operand to a comparison expression.
:param flytekit.models.literals.Primitive primitive:
:param Text var:
:param flytekit.models.literals.Primitive primitive: A primitive value
:param Text var: A variable name
:param flytekit.models.literals.Scalar scalar: A scalar value
"""

self._primitive = primitive
self._var = var
self._scalar = scalar

@property
def primitive(self):
Expand All @@ -160,13 +162,22 @@ def var(self):

return self._var

@property
def scalar(self):
"""
:rtype: flytekit.models.literals.Scalar
"""

return self._scalar

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.condition_pb2.Operand
"""
return _condition.Operand(
primitive=self.primitive.to_flyte_idl() if self.primitive else None,
var=self.var if self.var else None,
scalar=self.scalar.to_flyte_idl() if self.scalar else None,
)

@classmethod
Expand All @@ -176,6 +187,7 @@ def from_flyte_idl(cls, pb2_object):
if pb2_object.HasField("primitive")
else None,
var=pb2_object.var if pb2_object.HasField("var") else None,
scalar=_literals.Scalar.from_flyte_idl(pb2_object.scalar) if pb2_object.HasField("scalar") else None,
)


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
},
install_requires=[
"googleapis-common-protos>=1.57",
"flyteidl>=1.5.14",
"flyteidl>=1.5.16",
"wheel>=0.30.0,<1.0.0",
"pandas>=1.0.0,<2.0.0",
"pyarrow>=4.0.0,<11.0.0",
Expand Down
19 changes: 19 additions & 0 deletions tests/flytekit/unit/core/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,25 @@ def decompose() -> int:
assert decompose() == 20


def test_condition_is_none():
@task
def return_true() -> typing.Optional[None]:
return None

@workflow
def failed() -> int:
return 10

@workflow
def success() -> int:
return 20

@workflow
def decompose_unary() -> int:
result = return_true()
return conditional("test").if_(result.is_none()).then(success()).else_().then(failed())


def test_subworkflow_condition_serialization():
"""Test that subworkflows are correctly extracted from serialized workflows with condiationals."""

Expand Down
67 changes: 67 additions & 0 deletions tests/flytekit/unit/models/core/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,73 @@ def test_branch_node():
assert bn.if_else.case.then_node == obj


def test_branch_node_with_none():
nm = _get_sample_node_metadata()
task = _workflow.TaskNode(reference_id=_generic_id)
bd = _literals.BindingData(scalar=_literals.Scalar(none_type=_literals.Void()))
lt = _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=99)))
bd2 = _literals.BindingData(
scalar=_literals.Scalar(
union=_literals.Union(value=lt, stored_type=_types.LiteralType(_types.SimpleType.INTEGER))
)
)
binding = _literals.Binding(var="myvar", binding=bd)
binding2 = _literals.Binding(var="myothervar", binding=bd2)

obj = _workflow.Node(
id="some:node:id",
metadata=nm,
inputs=[binding, binding2],
upstream_node_ids=[],
output_aliases=[],
task_node=task,
)

bn = _workflow.BranchNode(
_workflow.IfElseBlock(
case=_workflow.IfBlock(
condition=_condition.BooleanExpression(
comparison=_condition.ComparisonExpression(
_condition.ComparisonExpression.Operator.EQ,
_condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())),
_condition.Operand(primitive=_literals.Primitive(integer=2)),
)
),
then_node=obj,
),
other=[
_workflow.IfBlock(
condition=_condition.BooleanExpression(
conjunction=_condition.ConjunctionExpression(
_condition.ConjunctionExpression.LogicalOperator.AND,
_condition.BooleanExpression(
comparison=_condition.ComparisonExpression(
_condition.ComparisonExpression.Operator.EQ,
_condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())),
_condition.Operand(primitive=_literals.Primitive(integer=2)),
)
),
_condition.BooleanExpression(
comparison=_condition.ComparisonExpression(
_condition.ComparisonExpression.Operator.EQ,
_condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())),
_condition.Operand(primitive=_literals.Primitive(integer=2)),
)
),
)
),
then_node=obj,
)
],
else_node=obj,
)
)

bn2 = _workflow.BranchNode.from_flyte_idl(bn.to_flyte_idl())
assert bn == bn2
assert bn.if_else.case.then_node == obj


def test_task_node_overrides():
overrides = _workflow.TaskNodeOverrides(
Resources(
Expand Down

0 comments on commit 25f9863

Please sign in to comment.