diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py index ecbeeb7b042..9d7e582c971 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client.py @@ -141,8 +141,7 @@ def tune( parameters: Dict[str, Any], base_image: str = constants.BASE_IMAGE_TENSORFLOW, namespace: Optional[str] = None, - env: Union[Dict[str, str], List[V1EnvVar], None] = None, - env_from: Union[V1EnvFromSource, List[V1EnvFromSource], None] = None, + env_per_trial: Union[Dict[str, str], List[Union[V1EnvVar, V1EnvFromSource]], None] = None, algorithm_name: str = "random", algorithm_settings: Union[dict, List[models.V1beta1AlgorithmSetting], None] = None, objective_metric_name: str = None, @@ -175,13 +174,12 @@ def tune( objective function. base_image: Image to use when executing the objective function. namespace: Namespace for the Experiment. - env: Environment variable(s) to be attached to each trial container. You can either specifiy - a list of kubernetes.client.models.V1EnvVar (documented here: - https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md) or a dictionary - corresponding to the environment variable name and value pair(s). - env_from: Source(s) of environment variables to be populated in each trial container. You can either specify - a kubernetes.client.models.V1EnvFromSource (documented here: - https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md) or a list of such a type. + env_per_trial: Environment variable(s) to be attached to each trial container. + You can specify a dictionary as a mapping object representing the environment variables. + Otherwise, you can specify a list, in which the element can either be a kubernetes.client.models.V1EnvVar (documented here: + https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md) + or a kubernetes.client.models.V1EnvFromSource (documented here: + https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md) algorithm_name: Search algorithm for the HyperParameter tuning. algorithm_settings: Settings for the search algorithm given. For available fields, check this doc: https://www.kubeflow.org/docs/components/katib/experiment/#search-algorithms-in-detail. @@ -328,12 +326,15 @@ def tune( requests=resources_per_trial, limits=resources_per_trial, ) - - if isinstance(env, dict): - env = [V1EnvVar(name=str(k), value=str(v)) for k, v in env.items()] - - if isinstance(env_from, V1EnvFromSource): - env_from = [env_from] + + if isinstance(env_per_trial, dict): + env, env_from = [V1EnvVar(name=str(k), value=str(v)) for k, v in env_per_trial.items()] or None, None + + if env_per_trial: + env = [x for x in env_per_trial if isinstance(x, V1EnvVar)] or None + env_from = [x for x in env_per_trial if isinstance(x, V1EnvFromSource)] or None + else: + env, env_from = None, None # Create Trial specification. trial_spec = client.V1Job(