Skip to content

Commit

Permalink
Merge pull request #107 from valohai/memona/fix/include_pipeline_para…
Browse files Browse the repository at this point in the history
…meters_in_payload

Pipeline Parameter Conversion
  • Loading branch information
tokkoro authored Aug 17, 2023
2 parents 408458a + de5b798 commit 6de715b
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
1 change: 1 addition & 0 deletions examples/pipeline-with-parameters-example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
targets:
- train.parameter.id
- train_parallel.parameter.id
default: 123
nodes:
- name: train
step: train_step
Expand Down
21 changes: 21 additions & 0 deletions tests/test_pipeline_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,24 @@ def test_pipeline_conversion_override_inputs(pipeline_overridden_config: Config)
assert overridden["template"]["inputs"].get('training-images', [])[0] == 'overridden node image'
assert len(overridden["template"]["inputs"].get('training-images', [])) == 1
assert len(overridden["template"]["parameters"].items()) == 3


def test_pipeline_parameter_conversion(pipeline_with_parameters_config):
parameter_name = "id"
for _name, pipe in pipeline_with_parameters_config.pipelines.items():
result = PipelineConverter(
config=pipeline_with_parameters_config,
commit_identifier="latest",
).convert_pipeline(pipe)
assert isinstance(result["parameters"], dict)
assert result["parameters"][parameter_name]

parameter = result["parameters"][parameter_name]
assert parameter["config"]["targets"]
assert "target" not in parameter["config"]
assert isinstance(parameter["config"]["targets"], list)

# When pipeline parameter has no default value, the expression should be empty
parameter_config = next(param for param in pipe.parameters if param.name == parameter_name)
expression_value = parameter_config.default if parameter_config.default else ""
assert parameter["expression"] == expression_value
28 changes: 26 additions & 2 deletions valohai_yaml/pipelines/conversion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

import sys
from typing import Any, Dict, List, Union

from valohai_yaml.objs import (
Expand All @@ -7,13 +7,26 @@
ExecutionNode,
Node,
Pipeline,
PipelineParameter,
TaskNode,
)
from valohai_yaml.objs.pipelines.override import Override

ConvertedObject = Dict[str, Any]


if sys.version_info >= (3, 8):
from typing import TypedDict
class ConvertedPipeline(TypedDict):
"""TypedDict for converted Pipeline object."""

edges: List[ConvertedObject]
nodes: List[ConvertedObject]
parameters: Dict[str, ConvertedObject]
else:
ConvertedPipeline = ConvertedObject


class PipelineConverter:
"""Converts pipeline objects to Valohai API payloads."""

Expand All @@ -26,10 +39,21 @@ def __init__(
self.config = config
self.commit_identifier = commit_identifier

def convert_pipeline(self, pipeline: Pipeline) -> Dict[str, List[ConvertedObject]]:
def convert_pipeline(self, pipeline: Pipeline) -> ConvertedPipeline:
return {
"edges": [edge.get_expanded() for edge in pipeline.edges],
"nodes": [self.convert_node(node) for node in pipeline.nodes],
"parameters": {
parameter.name: self.convert_parameter(parameter)
for parameter in pipeline.parameters
},
}

def convert_parameter(self, parameter: PipelineParameter) -> ConvertedObject:
"""Convert a pipeline parameter to a config-expression payload."""
return {
"config": {**parameter.serialize()},
"expression": parameter.default if parameter.default else "",
}

def convert_node(self, node: Node) -> ConvertedObject:
Expand Down

0 comments on commit 6de715b

Please sign in to comment.