Skip to content

Commit

Permalink
add env & env_from spec; add missing dependencies for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
shipengcheng1230 committed Oct 26, 2023
1 parent d2e311f commit 9a02b0e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
1 change: 1 addition & 0 deletions cmd/suggestion/optuna/v1beta1/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ grpcio>=1.41.1
protobuf>=3.19.5, <=3.20.3
googleapis-common-protos==1.53.0
optuna>=3.0.0
cmaes>=0.10.0
13 changes: 13 additions & 0 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from kubeflow.katib.constants import constants
from kubeflow.katib.utils import utils
from kubernetes import client, config
from kubernetes.client.models import V1EnvVar, V1EnvFromSource, V1ConfigMapEnvSource


class KatibClient(object):
Expand Down Expand Up @@ -140,6 +141,8 @@ def tune(
parameters: Dict[str, Any],
base_image: str = constants.BASE_IMAGE_TENSORFLOW,
namespace: Optional[str] = None,
env: Union[Dict[str, str], List[V1EnvVar], None] = None,
env_from: Union[V1EnvFromSource, List[V1EnvFromSource], None] = None,
algorithm_name: str = "random",
algorithm_settings: Union[dict, List[models.V1beta1AlgorithmSetting], None] = None,
objective_metric_name: str = None,
Expand Down Expand Up @@ -172,6 +175,8 @@ def tune(
objective function.
base_image: Image to use when executing the objective function.
namespace: Namespace for the Experiment.
env: Environment variable(s) to be attached to each trial container.
env_from: Source(s) of environment variables to be populated in each trial container.
algorithm_name: Search algorithm for the HyperParameter tuning.
algorithm_settings: Settings for the search algorithm given.
For available fields, check this doc: https://www.kubeflow.org/docs/components/katib/experiment/#search-algorithms-in-detail.
Expand Down Expand Up @@ -319,6 +324,12 @@ def tune(
limits=resources_per_trial,
)

if isinstance(env, dict):
env = [V1EnvVar(name=str(k), value=str(v)) for k, v in env.items()]

if isinstance(env_from, V1EnvFromSource):
env_from = [env_from]

# Create Trial specification.
trial_spec = client.V1Job(
api_version="batch/v1",
Expand All @@ -336,6 +347,8 @@ def tune(
image=base_image,
command=["bash", "-c"],
args=[exec_script],
env=env,
env_from=env_from,
resources=resources_per_trial,
)
],
Expand Down

0 comments on commit 9a02b0e

Please sign in to comment.