Skip to content

Commit

Permalink
set logger url
Browse files Browse the repository at this point in the history
  • Loading branch information
khintz committed Dec 9, 2024
1 parent 010f716 commit 2620bd1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
25 changes: 18 additions & 7 deletions neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@


class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
"""
Custom MLFlow logger that adds functionality not present in the default
"""

def __init__(self, experiment_name, tracking_uri):
super().__init__(
experiment_name=experiment_name, tracking_uri=tracking_uri
Expand All @@ -38,9 +42,22 @@ def __init__(self, experiment_name, tracking_uri):

@property
def save_dir(self):
"""
Returns the directory where the MLFlow artifacts are saved
"""
return "mlruns"

def log_image(self, key, images, step=None):
"""
Log a matplotlib figure as an image to MLFlow
key: str
Key to log the image under
images: list
List of matplotlib figures to log
step: Union[int, None]
Step to log the image under. If None, logs under the key directly
"""
# Third-party
from PIL import Image

Expand All @@ -61,7 +78,7 @@ def log_model(self, data_module, model):
with torch.no_grad():
model_output = model.common_step(input_example)[
0
] # expects batch, returns tuple (prediction, target, pred_std, _)
] # common_step returns tuple (prediction, target, pred_std, _)

log_model_input_example = {
name: tensor.cpu().numpy()
Expand All @@ -81,8 +98,6 @@ def log_model(self, data_module, model):
signature=signature,
)

# validate_serving_input(model_uri, validate_example)

def create_input_example(self, data_module):

if data_module.val_dataset is None:
Expand Down Expand Up @@ -405,9 +420,5 @@ def main(input_args=None):
else:
trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)

# Log model. TODO: only log for mlflow
training_logger.log_model(data_module, model)


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def test_config_serialization(state_weighting_config):
kind: mdp
config_path: ""
training:
logger: wandb
logger_url: https://wandb.ai
state_feature_weighting:
__config_class__: ManualStateFeatureWeighting
weights:
Expand Down

0 comments on commit 2620bd1

Please sign in to comment.