Skip to content

Commit

Permalink
address reviews
Browse files Browse the repository at this point in the history
Signed-off-by: Sandipan Panda <[email protected]>
  • Loading branch information
sandipanpanda committed Sep 10, 2024
1 parent cc8f250 commit fa377fa
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 51 deletions.
7 changes: 1 addition & 6 deletions examples/jax/cpu-demo/demo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,13 @@ spec:
jaxReplicaSpecs:
Worker:
replicas: 2
restartPolicy: Never
restartPolicy: OnFailure
template:
spec:
containers:
- name: jax-worker
image: sandipanify/jaxgoogle
command: ["python", "train.py"]
args:
- --num_processes="2"
- --job_name=jaxjob-simple
- --sub_domain=training-operator
- --coordinator_port="6666"
ports:
- containerPort: 6666
imagePullPolicy: Always
6 changes: 5 additions & 1 deletion examples/jax/cpu-demo/train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# example ref:
# https://jax.readthedocs.io/en/latest/multi_process.html#running-multi-process-computations
# https://github.com/GoogleCloudPlatform/ai-on-gke/blob/main/tutorials-and-examples/gpu-examples/a100-jax/train.py # noqa

import os
import socket
import time
Expand Down Expand Up @@ -40,7 +44,7 @@ def _get_coordinator_ip_address(job_name, sub_domain):

def _main(argv):

process_id = int(os.getenv("JOB_COMPLETION_INDEX"))
process_id = int(os.getenv("PROCESS_ID"))
num_processes = FLAGS.num_processes
coordinator_address = _get_coordinator_ip_address(FLAGS.job_name, FLAGS.sub_domain)
coordinator_address = f"{coordinator_address}:{FLAGS.coordinator_port}"
Expand Down
17 changes: 5 additions & 12 deletions pkg/controller.v1/jax/envvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strings"

corev1 "k8s.io/api/core/v1"
"k8s.io/utils/ptr"

kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
)
Expand All @@ -28,11 +29,7 @@ type EnvVarGenerator interface {
Generate(job *kubeflowv1.JAXJob) ([]corev1.EnvVar, error)
}

func setPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype, index string) error {
jaxjob, ok := obj.(*kubeflowv1.JAXJob)
if !ok {
return fmt.Errorf("%+v is not a type of JAXJob", obj)
}
func setPodEnv(jaxjob *kubeflowv1.JAXJob, podTemplateSpec *corev1.PodTemplateSpec, rtype, index string) error {

coordinatorAddr := replicaName(jaxjob.Name, kubeflowv1.JAXJobReplicaTypeWorker, 0)

Expand All @@ -44,10 +41,6 @@ func setPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype,
totalReplicas := getTotalReplicas(jaxjob)

for i := range podTemplateSpec.Spec.Containers {
// Initialize the environment variables.
if len(podTemplateSpec.Spec.Containers[i].Env) == 0 {
podTemplateSpec.Spec.Containers[i].Env = make([]corev1.EnvVar, 0)
}

rank, err := strconv.Atoi(index)
if err != nil {
Expand Down Expand Up @@ -80,10 +73,10 @@ func setPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype,
return nil
}

func getTotalReplicas(job *kubeflowv1.JAXJob) int32 {
jobReplicas := int32(0)
func getTotalReplicas(job *kubeflowv1.JAXJob) int {
jobReplicas := 0
for _, r := range job.Spec.JAXReplicaSpecs {
jobReplicas += *r.Replicas
jobReplicas += int(ptr.Deref[int32](r.Replicas, 0))
}
return jobReplicas
}
Expand Down
51 changes: 27 additions & 24 deletions pkg/controller.v1/jax/jaxjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ import (
"sigs.k8s.io/controller-runtime/pkg/controller"
"sigs.k8s.io/controller-runtime/pkg/event"
"sigs.k8s.io/controller-runtime/pkg/handler"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/manager"
"sigs.k8s.io/controller-runtime/pkg/predicate"
"sigs.k8s.io/controller-runtime/pkg/source"
Expand All @@ -62,11 +61,11 @@ const (
// NewReconciler creates a JAXJob Reconciler
func NewReconciler(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc) *JAXJobReconciler {
r := &JAXJobReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
client: mgr.GetClient(),
scheme: mgr.GetScheme(),
recorder: mgr.GetEventRecorderFor(controllerName),
apiReader: mgr.GetAPIReader(),
Log: log.Log,
log: ctrl.Log.WithName(controllerName),
}

// Create clients
Expand Down Expand Up @@ -96,9 +95,9 @@ func NewReconciler(mgr manager.Manager, gangSchedulingSetupFunc common.GangSched
// JAXJobReconciler reconciles a JAXJob object
type JAXJobReconciler struct {
common.JobController
client.Client
Scheme *runtime.Scheme
Log logr.Logger
client client.Client
scheme *runtime.Scheme
log logr.Logger
recorder record.EventRecorder
apiReader client.Reader
}
Expand All @@ -108,7 +107,6 @@ type JAXJobReconciler struct {
//+kubebuilder:rbac:groups=kubeflow.org,resources=jaxjobs/finalizers,verbs=update
//+kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups="",resources=services,verbs=get;list;watch;create;delete
//+kubebuilder:rbac:groups=autoscaling,resources=horizontalpodautoscalers,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups=scheduling.volcano.sh,resources=podgroups,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups=scheduling.x-k8s.io,resources=podgroups,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups="",resources=events,verbs=get;list;watch;create;update;patch;delete
Expand All @@ -122,16 +120,17 @@ type JAXJobReconciler struct {
// For more details, check Reconcile and its Result here:
// - https://pkg.go.dev/sigs.k8s.io/[email protected]/pkg/reconcile
func (r *JAXJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
_ = log.FromContext(ctx)
logger := r.Log.WithValues(kubeflowv1.JAXJobSingular, req.NamespacedName)

jaxjob := &kubeflowv1.JAXJob{}
err := r.Get(ctx, req.NamespacedName, jaxjob)
err := r.client.Get(ctx, req.NamespacedName, jaxjob)
if err != nil {
logger.Info(err.Error(), "unable to fetch JAXJob", req.NamespacedName.String())
return ctrl.Result{}, client.IgnoreNotFound(err)
}

// log := ctrl.LoggerFrom(ctx).WithValues("jaxjob", klog.KObj(&jaxjob))
// ctrl.LoggerInto(ctx, log)
// log.V(2).Info("Reconciling JAXJob")

// Check if reconciliation is needed
jobKey, err := common.KeyFunc(jaxjob)
if err != nil {
Expand All @@ -142,18 +141,18 @@ func (r *JAXJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctr
needReconcile := util.SatisfiedExpectations(r.Expectations, jobKey, replicaTypes)

if !needReconcile || jaxjob.GetDeletionTimestamp() != nil {
logger.Info("reconcile cancelled, job does not need to do reconcile or has been deleted",
r.log.Info("reconcile cancelled, job does not need to do reconcile or has been deleted",
"sync", needReconcile, "deleted", jaxjob.GetDeletionTimestamp() != nil)
return ctrl.Result{}, nil
}

// Set default priorities to jax job
r.Scheme.Default(jaxjob)
r.scheme.Default(jaxjob)

// Use common to reconcile the job related pod and service
err = r.ReconcileJobs(jaxjob, jaxjob.Spec.JAXReplicaSpecs, jaxjob.Status, &jaxjob.Spec.RunPolicy)
if err != nil {
logger.Error(err, "Reconcile JAXJob error")
r.log.Error(err, "Reconcile JAXJob error")
return ctrl.Result{}, err
}
t, err := util.DurationUntilExpireTime(&jaxjob.Spec.RunPolicy, jaxjob.Status)
Expand Down Expand Up @@ -247,7 +246,7 @@ func (r *JAXJobReconciler) GetFrameworkName() string {

func (r *JAXJobReconciler) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) {
job := &kubeflowv1.JAXJob{}
err := r.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
err := r.client.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
if err != nil {
if errors.IsNotFound(err) {
logrus.Error(err, "jax job not found", "namespace", namespace, "name", name)
Expand Down Expand Up @@ -283,7 +282,7 @@ func (r *JAXJobReconciler) GetPodsForJob(obj interface{}) ([]*corev1.Pod, error)
// List all pods to include those that don't match the selector anymore
// but have a ControllerRef pointing to this controller.
podlist := &corev1.PodList{}
err = r.List(context.Background(), podlist, client.MatchingLabels(r.GenLabels(job.GetName())), client.InNamespace(job.GetNamespace()))
err = r.client.List(context.Background(), podlist, client.MatchingLabels(r.GenLabels(job.GetName())), client.InNamespace(job.GetNamespace()))
if err != nil {
return nil, err
}
Expand All @@ -300,7 +299,7 @@ func (r *JAXJobReconciler) GetServicesForJob(obj interface{}) ([]*corev1.Service
// List all pods to include those that don't match the selector anymore
// but have a ControllerRef pointing to this controller.
serviceList := &corev1.ServiceList{}
err = r.List(context.Background(), serviceList, client.MatchingLabels(r.GenLabels(job.GetName())), client.InNamespace(job.GetNamespace()))
err = r.client.List(context.Background(), serviceList, client.MatchingLabels(r.GenLabels(job.GetName())), client.InNamespace(job.GetNamespace()))
if err != nil {
return nil, err
}
Expand All @@ -314,7 +313,7 @@ func (r *JAXJobReconciler) DeleteJob(job interface{}) error {
if !ok {
return fmt.Errorf("%+v is not a type of JAXJob", job)
}
if err := r.Delete(context.Background(), jaxjob); err != nil {
if err := r.client.Delete(context.Background(), jaxjob); err != nil {
r.recorder.Eventf(jaxjob, corev1.EventTypeWarning, control.FailedDeletePodReason, "Error deleting: %v", err)
logrus.Error(err, "failed to delete job", "namespace", jaxjob.Namespace, "name", jaxjob.Name)
return err
Expand Down Expand Up @@ -434,10 +433,10 @@ func (r *JAXJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobStatus
jaxjob.Status = *jobStatus.DeepCopy()
}

result := r.Status().Update(context.Background(), jaxjob)
result := r.client.Status().Update(context.Background(), jaxjob)

if result != nil {
r.Log.WithValues("jaxjob", types.NamespacedName{
r.log.WithValues("jaxjob", types.NamespacedName{
Namespace: jaxjob.GetNamespace(),
Name: jaxjob.GetName(),
})
Expand All @@ -449,7 +448,11 @@ func (r *JAXJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobStatus

// SetClusterSpec sets the cluster spec and init container for the pod
func (r *JAXJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
if err := setPodEnv(job, podTemplate, rtype, index); err != nil {
jaxjob, ok := job.(*kubeflowv1.JAXJob)
if !ok {
return fmt.Errorf("%+v is not a type of JAXJob", job)
}
if err := setPodEnv(jaxjob, podTemplate, rtype, index); err != nil {
return err
}
return nil
Expand All @@ -465,7 +468,7 @@ func (r *JAXJobReconciler) GetDefaultContainerPortName() string {

func (r *JAXJobReconciler) IsMasterRole(replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec,
rtype kubeflowv1.ReplicaType, index int) bool {
return false
return index == 0
}

// onOwnerCreateFunc modify creation condition.
Expand All @@ -475,7 +478,7 @@ func (r *JAXJobReconciler) onOwnerCreateFunc() func(event.CreateEvent) bool {
if !ok {
return true
}
r.Scheme.Default(jaxjob)
r.scheme.Default(jaxjob)
msg := fmt.Sprintf("JAXJob %s is created.", e.Object.GetName())
logrus.Info(msg)
trainingoperatorcommon.CreatedJobsCounterInc(jaxjob.Namespace, r.GetFrameworkName())
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller.v1/jax/jaxjob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ var _ = Describe("JAXJob controller", func() {
Type: kubeflowv1.JobCreated,
Status: corev1.ConditionTrue,
Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobCreatedReason),
Message: fmt.Sprintf("JAXJob %s is created.", name),
Message: fmt.Sprintf("JAXJob %s is created.", ns.Name+"/"+name),
},
{
Type: kubeflowv1.JobRunning,
Expand Down
3 changes: 2 additions & 1 deletion sdk/python/kubeflow/training/constants/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
JAXJOB_PLURAL = "jaxjobs"
JAXJOB_CONTAINER = "jax"
JAXJOB_REPLICA_TYPES = REPLICA_TYPE_WORKER.lower()
JAXJOB_BASE_IMAGE = "docker.io/sandipanify/jaxgloo:latest"

# Dictionary to get plural, model, and container for each Job kind.
JOB_PARAMETERS = {
Expand Down Expand Up @@ -181,7 +182,7 @@
"model": JAXJOB_MODEL,
"plural": JAXJOB_PLURAL,
"container": JAXJOB_CONTAINER,
"base_image": "TODO",
"base_image": "JAXJOB_BASE_IMAGE",
},
}

Expand Down
6 changes: 0 additions & 6 deletions sdk/python/test/e2e/test_e2e_jaxjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,5 @@ def generate_container() -> V1Container:
name=CONTAINER_NAME,
image="docker.io/sandipanify/jaxgoogle:latest",
command=["python", "train.py"],
args=[
"--num_processes=2",
"--job_name=example-job",
"--sub_domain=training-operator",
"--cooordinator_port=6666",
],
resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}),
)

0 comments on commit fa377fa

Please sign in to comment.