Skip to content

Commit

Permalink
fix: Move sagemaker-mlflow to extras (#4903)
Browse files Browse the repository at this point in the history
Co-authored-by: Erick Benitez-Ramos <[email protected]>
  • Loading branch information
ryansteakley and benieric authored Nov 5, 2024
1 parent f45e600 commit 292a00d
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 8 deletions.
2 changes: 1 addition & 1 deletion hatch_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def read_feature_deps(feature):

optional_dependencies = {"all": []}

for feature in ("feature-processor", "huggingface", "local", "scipy"):
for feature in ("feature-processor", "huggingface", "local", "scipy", "sagemaker-mlflow"):
dependencies = read_feature_deps(feature)
optional_dependencies[feature] = dependencies
optional_dependencies["all"].extend(dependencies)
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ dependencies = [
"PyYAML~=6.0",
"requests",
"sagemaker-core>=1.0.0,<2.0.0",
"sagemaker-mlflow",
"schema",
"smdebug_rulesconfig==1.0.1",
"tblib>=1.7.0,<4",
Expand Down
1 change: 1 addition & 0 deletions requirements/extras/sagemaker-mlflow_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sagemaker-mlflow>=0.1.0
1 change: 1 addition & 0 deletions requirements/extras/test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ huggingface_hub>=0.23.4
uvicorn>=0.30.1
fastapi>=0.111.0
nest-asyncio
sagemaker-mlflow>=0.1.0
11 changes: 8 additions & 3 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline

from sagemaker.mlflow.forward_sagemaker_metrics import log_sagemaker_job_to_mlflow

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1374,8 +1373,14 @@ def fit(
forward_to_mlflow_tracking_server = True
if wait:
self.latest_training_job.wait(logs=logs)
if forward_to_mlflow_tracking_server:
log_sagemaker_job_to_mlflow(self.latest_training_job.name)
try:
if forward_to_mlflow_tracking_server:
from sagemaker.mlflow.forward_sagemaker_metrics import log_sagemaker_job_to_mlflow

log_sagemaker_job_to_mlflow(self.latest_training_job.name)
except ImportError:
if forward_to_mlflow_tracking_server:
raise ValueError("Unable to import mlflow, check if sagemaker-mlflow is installed")

def _compilation_job_name(self):
"""Placeholder docstring"""
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/mlflow/forward_sagemaker_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
import re
from typing import Set, Tuple, List, Dict, Generator
import boto3
import mlflow

try:
import mlflow
except ImportError:
raise ValueError("Unable to import mlflow, check if sagemaker-mlflow is installed.")
from mlflow import MlflowClient
from mlflow.entities import Metric, Param, RunTag

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5924,7 +5924,7 @@ def test_estimator_get_app_url_fail(sagemaker_session):
assert "does not support URL retrieval." in str(error)


@patch("sagemaker.estimator.log_sagemaker_job_to_mlflow")
@patch("sagemaker.mlflow.forward_sagemaker_metrics.log_sagemaker_job_to_mlflow")
def test_forward_sagemaker_metrics(mock_log_to_mlflow, sagemaker_session):
f = DummyFramework(
entry_point=SCRIPT_PATH,
Expand All @@ -5943,7 +5943,7 @@ def test_forward_sagemaker_metrics(mock_log_to_mlflow, sagemaker_session):
mock_log_to_mlflow.assert_called_once()


@patch("sagemaker.estimator.log_sagemaker_job_to_mlflow")
@patch("sagemaker.mlflow.forward_sagemaker_metrics.log_sagemaker_job_to_mlflow")
def test_no_forward_sagemaker_metrics(mock_log_to_mlflow, sagemaker_session):
f = DummyFramework(
entry_point=SCRIPT_PATH,
Expand Down

0 comments on commit 292a00d

Please sign in to comment.