Skip to content

Commit

Permalink
Add Task.parameter-sets
Browse files Browse the repository at this point in the history
Fixes #134
  • Loading branch information
akx committed Nov 1, 2023
1 parent 4e16a0f commit 47c256e
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 0 deletions.
18 changes: 18 additions & 0 deletions examples/task-example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,21 @@
base: 10
distribution: uniform
integerify: False
- task:
step: run training
name: parameter sets
type: manual_search
parameters:
- name: A
style: multiple
rules: {}
- name: B
style: multiple
rules: {}
parameter-sets:
- A: 5
B: 6
- A: 8
B: 9
- A: 72
B: 42
9 changes: 9 additions & 0 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,12 @@ def test_task_additional_fields(task_config: Config):
assert task.optimization_target_metric == "goodness"
assert task.optimization_target_value == 7.2
assert task.type == TaskType.RANDOM_SEARCH


def test_task_parameter_sets(task_config: Config):
task = task_config.tasks["parameter sets"]
assert task.parameter_sets == [
{"A": 5, "B": 6},
{"A": 8, "B": 9},
{"A": 72, "B": 42},
]
7 changes: 7 additions & 0 deletions valohai_yaml/objs/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class Task(Item):
step: str
type: TaskType
parameters: list[VariantParameter]
parameter_sets: list[dict[str, Any]]
name: str
execution_count: int | None
execution_batch_size: int | None
Expand All @@ -72,6 +73,7 @@ def __init__(
step: str,
type: TaskType | str | None = None,
parameters: list[VariantParameter] | None = None,
parameter_sets: list[dict[str, Any]] | None = None,
execution_count: int | None = None,
execution_batch_size: int | None = None,
maximum_queued_executions: int | None = None,
Expand All @@ -85,6 +87,9 @@ def __init__(
self.step = step
self.type = TaskType.cast(type)
self.parameters = check_type_and_listify(parameters, VariantParameter)
self.parameter_sets = [
ps for ps in check_type_and_listify(parameter_sets, dict) if ps
]
self.execution_count = execution_count
self.execution_batch_size = execution_batch_size
self.maximum_queued_executions = maximum_queued_executions
Expand All @@ -105,3 +110,5 @@ def parse(cls, data: Any) -> Task:
def lint(self, lint_result: LintResult, context: LintContext) -> None:
context = dict(context, task=self, object_type="task")
lint_expression(lint_result, context, "stop-condition", self.stop_condition)
if self.parameter_sets and self.type != TaskType.MANUAL_SEARCH:
lint_result.add_warning("Parameter sets only make sense with manual search")
5 changes: 5 additions & 0 deletions valohai_yaml/schema/task.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,8 @@ properties:
"$ref": "./variant-param.json"
stop-condition:
type: string
parameter-sets:
type: array
description: Parameter sets for manual search mode.
items:
type: object

0 comments on commit 47c256e

Please sign in to comment.