Add support for mlflow #77
Open: khintz wants to merge 305 commits into mllam:main from khintz:feat/mlflow
Commits (showing changes from 250 of 305 commits):
6685e94 bugfixes (sadamov)
6423fdf pre_commits (sadamov)
59c4947 Merge remote-tracking branch 'origin/main' into feature_dataset_yaml (sadamov)
4e457ed config.py is ready for danra (sadamov)
adc592f streamlined multi-zarr workflow (sadamov)
a7bea6b xarray zarr based data normalization (sadamov)
1f7cbe8 adjusted pre-processing scripts to new data config workflow (sadamov)
e328152 plotting update with latest get_xy() function (sadamov)
cb85cda making data config more modular (sadamov)
eb8c6fb removing boundaries for now (sadamov)
0cfbb33 small updates (sadamov)
59d0c8a improved stats and units retrieval (sadamov)
2f6a87a add GPU-based runner on cirun.io (leifdenby)
668dd81 improved zarr-based normalization (sadamov)
143cf2a pdm install with cpu torch (leifdenby)
b760915 ensure exec in pdm venv (leifdenby)
7797cef ensure exec in pdm venv (leifdenby)
e689650 check version #2 (leifdenby)
fb8ef23 check version no 3 (leifdenby)
51b0a0b check versions (leifdenby)
374d032 merge main (sadamov)
8fa3ca7 Introduced datetime forcing calculation as seperate script (sadamov)
a748903 Fixed order of y and x dims to adhere to #52 (sadamov)
70425ee fix for pip install (leifdenby)
60110f6 switch cirun instance type (leifdenby)
6fff3fc install py39 on cirun runner (leifdenby)
74b4a10 cleanup: boundary_mask, zarr-opening, utils (sadamov)
0a041d1 Merge remote-tracking branch 'origin/main' into feature_dataset_yaml (sadamov)
8054e9e change ami image to gpu (leifdenby)
39fbf3a Merge remote-tracking branch 'upstream/main' into maint/deps-in-pypro… (leifdenby)
97aeb2e use cheaper gpu instance (leifdenby)
425123c adapted tests for zarr-analysis data (sadamov)
4dcf671 Readme adapted for yaml zarr analysis workflow (sadamov)
6d384f0 samller bugfixes and improvements (sadamov)
12ff4f2 Added fixed data config file for testing on Danra (sadamov)
03f7769 reducing runtime of tests with smaller sample (sadamov)
26f069c download danra data for test and example (streaming not possible) (sadamov)
1f1cbcc bugfixes after real-life testcase (sadamov)
b369306 Merge remote-tracking branch 'origin/main' into feature_dataset_yaml (sadamov)
0cdc361 organize .zarr in /data (sadamov)
23ca7b3 cleanup (sadamov)
81422f1 linter (sadamov)
124541b static dataset doesn't have time dim (sadamov)
6140fdb making two complex functions more modular (sadamov)
db6a912 chunk dataset by time (sadamov)
1aaa8dc create list first for performance (sadamov)
81856b2 converting to_array is very slow (sadamov)
b3da818 allow for forcings to not be normalized (sadamov)
7ee5398 allow non_normalized_vars to be null (sadamov)
4782103 fixed coastlines using new xy_extent function (sadamov)
e0ffc5b Some projections return inverted axes (rotatedPole) (sadamov)
c1f43b7 Docstrings added (sadamov)
21fd929 wip (leifdenby)
c52f98e npy mllam nearly done (leifdenby)
80f3639 minor adjustment (leifdenby)
048f8c6 Merge branch 'main' of https://github.com/mllam/neural-lam into maint… (leifdenby)
5aaa239 add pooch and tweak pip cicd testing (leifdenby)
66c3b03 combine cicd tests with caching (leifdenby)
8566b8f linting (leifdenby)
29bd9e5 add pyg dep (leifdenby)
bc7f028 set cirun aws region to frankfurt (leifdenby)
2070166 adapt image (leifdenby)
e4e86e5 set image (leifdenby)
1fba8fe try different image (leifdenby)
02b77cf add pooch to cicd (leifdenby)
b481929 add pdm gpu test (leifdenby)
bcec472 start work on readme (leifdenby)
c5beec9 Merge branch 'maint/deps-in-pyproject-toml' into datastore (leifdenby)
e89facc Merge branch 'main' into maint/refactor-as-package (leifdenby)
0b5687a Merge branch 'main' of https://github.com/mllam/neural-lam into maint… (leifdenby)
095fdbc turn meps testdata download into pytest fixture (leifdenby)
49e9bfe adapt README for package (leifdenby)
12cc02b remove pdm cicd test (will be in separate PR) (leifdenby)
b47f50b remove pdm in gitignore (leifdenby)
90d99ca remove pdm and pyproject files (will be sep PR) (leifdenby)
a91eaaa add pyproject.toml from main (leifdenby)
5508cea clean out tests (leifdenby)
5c623c3 fix linting (leifdenby)
08ec168 add cli entrypoints import test (leifdenby)
d9cf7ba Merge branch 'maint/refactor-as-package' into datastore (leifdenby)
3954f04 tweak cicd pytest execution (leifdenby)
f99fdce Merge branch 'maint/refactor-as-package' into datastore (leifdenby)
db9d96f Update tests/test_mllam_dataset.py (leifdenby)
3c864b2 grid-shape ok (leifdenby)
1f54b0e get_vars_names and units (leifdenby)
9b88160 get_vars_names and units 2 (leifdenby)
a9fdad5 test for stats (leifdenby)
555154f get_dataarray test (leifdenby)
8b8a77e get_dataarray test (leifdenby)
41f11cd boundary_mask (leifdenby)
a17de0f get_xy (leifdenby)
0a38a7d remove TrainingSample dataclass (leifdenby)
f65f6b5 test for WeatherDataset.__getitem__ (leifdenby)
a35100e test for graph creation (leifdenby)
cfb0618 more graph creation tests (leifdenby)
8698719 check for consistency of num features across splits (leifdenby)
3381404 test for single batch from mllam through model (leifdenby)
2a6796c Add init files to expose classes in editable package (joeloskarsson)
8f4e0e0 Linting (joeloskarsson)
e657abb working training_step with datastores!
effc99b remove superfluous tests
a047026 fix for dataset length
d2c62ed step length should be int
58f5d99 step length should be int
64d43a6 training working with mllam datastore!
07444f8 adapt neural_lam.train_model for datastores
d1b6fc1 fixes for npy
6fe19ac npyfiles datastore complete (leifdenby)
fe65a4d cleanup for datastore examples (leifdenby)
e533794 training on ohm with danra!
640ac05 use mllam-data-prep v0.2.0
0f16f13 remove py3.12 from pre-commit
724548e cleanup
a1b2037 all tests passing!
e35958f use mllam-data-prep v0.3.0
8b92318 delete requirements.txt
658836a remove .DS_Store
421efed use tmate in gpu pdm cicd
05f1e9f remove requirements
3afe0e4 update pdm gpu cicd setup to pdm venv on nvme drive
f3d028b don't try to use pdm venv in-project
2c35662 remove tmate
5f30255 update README with install instructions
b2b5631 changelog
c8ae829 update ci/cd badges to include gpu + gpu
e7cf2c0 Merge pull request #1 from mllam/package_inits (leifdenby)
0b72e9d add pyproject-flake8 to precommit config
190d1de use Flake8-pyproject instead
791af0a update README
58fab84 Merge branch 'maint/deps-in-pyproject-toml' into feat/datastores
dbe2e6d Merge branch 'maint/refactor-as-package' into maint/deps-in-pyproject…
eac6e35 Merge branch 'maint/deps-in-pyproject-toml' into feat/datastores
799d55e linting fixes
57bbb81 train only 1 epoch in cicd and print to stdout
a955cee log datastore config
0a79c74 cleanup doctrings
9f3c014 Merge branch 'maint/refactor-as-package' into datastore (leifdenby)
41364a8 Merge branch 'main' of https://github.com/mllam/neural-lam into maint… (leifdenby)
3422298 update changelog (leifdenby)
689ef69 move dev deps optional dependencies group (leifdenby)
9a0d538 update cicd tests to install dev deps (leifdenby)
bddfcaf update readme with new dev deps group (leifdenby)
b96cfdc quote the skip step the install readme (leifdenby)
2600dee remove unused files (leifdenby)
65a8074 Merge branch 'feat/datastores' of https://github.com/leifdenby/neural… (leifdenby)
6adf6cc revert to line length of 80 (leifdenby)
46b37f8 revert docstring formatting changes (leifdenby)
3cd0f8b pin numpy to <2.0.0 (leifdenby)
826270a Merge branch 'maint/deps-in-pyproject-toml' into feat/datastores (leifdenby)
4ba22ea Merge branch 'main' into feat/datastores (leifdenby)
1f661c6 fix flake8 linting errors (leifdenby)
4838872 Update neural_lam/weather_dataset.py (leifdenby)
b59e7e5 Update neural_lam/datastore/multizarr/create_normalization_stats.py (leifdenby)
75b1fe7 Update neural_lam/datastore/npyfiles/store.py (leifdenby)
7e736cb Update neural_lam/datastore/npyfiles/store.py (leifdenby)
613a7e2 Update neural_lam/datastore/npyfiles/store.py (leifdenby)
65e199b Update tests/test_training.py (leifdenby)
4435e26 Update tests/test_datasets.py (leifdenby)
4693408 Update README.md (leifdenby)
2dfed2c update README (leifdenby)
c3d033d Merge branch 'main' of https://github.com/mllam/neural-lam into feat/… (leifdenby)
4a70268 Merge branch 'feat/datastores' of https://github.com/leifdenby/neural… (leifdenby)
66c663f column_water -> open_water_fraction (leifdenby)
11a7978 fix linting (leifdenby)
a41c314 static data same for all splits (leifdenby)
6f1efd6 forcing_window_size from args (leifdenby)
bacb9ec Update neural_lam/datastore/base.py (leifdenby)
4a9db4e only use first ensemble member in datastores (leifdenby)
4fc2448 Merge branch 'feat/datastores' of https://github.com/leifdenby/neural… (leifdenby)
bcaa919 Update neural_lam/datastore/base.py (leifdenby)
90bc594 Update neural_lam/datastore/base.py (leifdenby)
5bda935 Update neural_lam/datastore/base.py (leifdenby)
8e7931d remove all multizarr functionality (leifdenby)
6998683 cleanup and test fixes for recent changes (leifdenby)
c415008 Merge branch 'feat/datastores' of https://github.com/leifdenby/neural… (leifdenby)
735d324 fix linting (leifdenby)
5f2d919 remove multizar example files (leifdenby)
5263d2c normalization -> standardization (leifdenby)
ba1bec3 fix import for tests (leifdenby)
d04d15e Update neural_lam/datastore/base.py (leifdenby)
743d7a1 fix coord issues and add datastore example plotting cli (leifdenby)
ac10d7d add lru_cache to get_xy_extent (leifdenby)
bf8172a MLLAMDatastore -> MDPDatastore (leifdenby)
90ca400 missed renames for MDPDatastore (leifdenby)
154139d update graph plot for datastores (leifdenby)
50ee0b0 use relative import (leifdenby)
7dfd570 add long_names and refactor npyfiles create weights (leifdenby)
2b45b5a Update neural_lam/weather_dataset.py (leifdenby)
aee0b1c Update neural_lam/weather_dataset.py (leifdenby)
8453c2b Update neural_lam/models/ar_model.py (leifdenby)
7f32557 Update neural_lam/weather_dataset.py (leifdenby)
67998b8 read projection from datastore config extra section (leifdenby)
ac7e46a NpyFilesDatastore -> NpyFilesDatastoreMEPS (leifdenby)
b7bf506 revert tp training with 1 AR step by default (leifdenby)
5df2ecf add missing kwarg to BaseHiGraphModel.__init__ (leifdenby)
d4d438f add missing kwarg to HiLAM.__init__ (leifdenby)
1889771 add missing kwarg to HiLAMParallel (leifdenby)
2c3bbde check that for enough forecast steps given ar_steps (leifdenby)
f0a151b remove numpy<2.0.0 version cap (leifdenby)
f3566b0 tweak print statement working in mdp
dba94b3 fix missed removed argument from cli
bca1482 remove wandb config log comment, we log now
fc973c4 ensure loading from checkpoint during train possible
9fcf06e get step_length from datastore in plot_error_map (leifdenby)
2bbe666 remove step_legnth attr in ARModel (leifdenby)
b41ed2f remove unused obs_mask arg for vis.plot_prediction (leifdenby)
7e46194 ensure no reference to multizarr "data_config" (leifdenby)
b57bc7a introduce neural-lam config (leifdenby)
2b30715 include meps neural-lam config example (leifdenby)
8e7b2e6 fix extra space typo in BaseDatastore (leifdenby)
e0300fb add check and print of train/test/val split in MDPDatastore (leifdenby)
a921e35 add experimental mlflow server support (leifdenby)
0f30259 more fixes for mlflow logging support (leifdenby)
3fbe2d0 Make wandb work again with pytorch_lightning.logger (khintz)
e0284a8 upload of artifact to mlflow works, but instantiates a new experiment (khintz)
7eed79b make mlflow use same experiment run id as pl.logger.MLFlowLogger (khintz)
27408f2 logger artifact working for both wandb and mlflow (khintz)
e61a9e7 support mlflow system metrics logging (khintz)
b53bab5 support model logging for mlflow (khintz)
de27e9a log model (khintz)
89d8cde test system metrics (khintz)
54c7ca7 make mlflow work also for eval mode (khintz)
a47de0c dummy prints to identify workflow (khintz)
10a4494 update mlflow on eval mode (khintz)
427a4b1 Merge branch 'main' into feat/mlflow (khintz)
78e874d inspect plot routines (khintz)
5904cbe identified issue, cleanup next (leifdenby)
efe0302 use xarray plot only (leifdenby)
a489c2e don't reraise (leifdenby)
242d08b remove debug plot (leifdenby)
c1f706c remove extent calc used in diagnosing issue (leifdenby)
88ec9dc Test order of dimension in eval plots (khintz)
d367cdb Merge branch 'fix/eval-vis-plots' into feat/mlflow (khintz)
90f8918 fix tensors on cpu and plot time index (khintz)
53f0ea4 restore tests/test_datasets.py (khintz)
cfc249f cleaning up with focus on linting (khintz)
b218c8b update tests (khintz)
1f1aed8 use correct data module for input example (khintz)
f3abd47 Merge branch 'main' into feat/mlflow (khintz)
47932b5 clean log model function (khintz)
98dc5c4 Merge branch 'main' into feat/mlflow (khintz)
64971ae revert bad merge (khintz)
010f716 remove unused init for datastore (khintz)
2620bd1 set logger url (khintz)
75a39e6 change type of default logger_url in config (khintz)
9d27a4c linting (khintz)
8f42cd1 fix log_image issue in tests (khintz)
b5ebe6f add entry to changelog (khintz)
ae69f3f remove artifacts from earlier merging/rebase (khintz)
6e16035 catch error when aws credentials not set (khintz)
Diff:

```diff
@@ -5,10 +5,15 @@
 from argparse import ArgumentParser
 
 # Third-party
+import mlflow
+
+# for logging the model:
+import mlflow.pytorch
 import pytorch_lightning as pl
 import torch
 from lightning_fabric.utilities import seed
+from loguru import logger
+from mlflow.models import infer_signature
 
 # Local
 from . import utils
```
```diff
@@ -23,6 +28,110 @@
 }
 
 
+class CustomMLFlowLogger(pl.loggers.MLFlowLogger):
+    """
+    Custom MLFlow logger that adds functionality not present in the default
+    """
+
+    def __init__(self, experiment_name, tracking_uri):
+        super().__init__(
+            experiment_name=experiment_name, tracking_uri=tracking_uri
+        )
+        mlflow.start_run(run_id=self.run_id, log_system_metrics=True)
+        mlflow.log_param("run_id", self.run_id)
+
+    @property
+    def save_dir(self):
+        """
+        Returns the directory where the MLFlow artifacts are saved
+        """
+        return "mlruns"
+
+    def log_image(self, key, images, step=None):
+        """
+        Log a matplotlib figure as an image to MLFlow
+
+        key: str
+            Key to log the image under
+        images: list
+            List of matplotlib figures to log
+        step: Union[int, None]
+            Step to log the image under. If None, logs under the key directly
+        """
+        # Third-party
+        from PIL import Image
+
+        if step is not None:
+            key = f"{key}_{step}"
+
+        # Need to save the image to a temporary file, then log that file;
+        # mlflow.log_image should do this automatically, but is buggy
+        temporary_image = f"{key}.png"
+        images[0].savefig(temporary_image)
+
+        img = Image.open(temporary_image)
+        mlflow.log_image(img, f"{key}.png")
+
+    def log_model(self, data_module, model):
+        input_example = self.create_input_example(data_module)
+
+        with torch.no_grad():
+            model_output = model.common_step(input_example)[
+                0
+            ]  # common_step returns tuple (prediction, target, pred_std, _)
+
+        log_model_input_example = {
+            name: tensor.cpu().numpy()
+            for name, tensor in zip(
+                ["init_states", "target_states", "forcing", "target_times"],
+                input_example,
+            )
+        }
+
+        signature = infer_signature(
+            log_model_input_example, model_output.cpu().numpy()
+        )
+
+        mlflow.pytorch.log_model(
+            model,
+            "model",
+            signature=signature,
+        )
+
+    def create_input_example(self, data_module):
+
+        if data_module.val_dataset is None:
+            data_module.setup(stage="fit")
+
+        data_loader = data_module.train_dataloader()
+        batch_sample = next(iter(data_loader))
+        return batch_sample
```
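For context on the `log_model` method above: `mlflow.models.infer_signature` derives an input/output schema from example data, which MLflow stores alongside the logged model. A standalone sketch of that step with made-up tensor shapes (the input names mirror the diff; shapes, dtypes, and values are illustrative only):

```python
# Sketch of the signature inference that log_model performs; shapes are invented.
import numpy as np
from mlflow.models import infer_signature

example_input = {
    "init_states": np.zeros((1, 2, 100, 8), dtype=np.float32),
    "target_states": np.zeros((1, 4, 100, 8), dtype=np.float32),
    "forcing": np.zeros((1, 4, 100, 3), dtype=np.float32),
    "target_times": np.zeros((1, 4), dtype=np.int64),
}
example_output = np.zeros((1, 4, 100, 8), dtype=np.float32)

# The resulting signature records each named input's dtype and shape, plus the
# output tensor spec; log_model passes it to mlflow.pytorch.log_model.
signature = infer_signature(example_input, example_output)
print(signature)
```

The hunk continues with the helper that chooses between the W&B and MLflow loggers: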
```diff
+
+
+def _setup_training_logger(config, datastore, args, run_name):
+    if config.training.logger == "wandb":
+        logger = pl.loggers.WandbLogger(
+            project=args.wandb_project,
+            name=run_name,
+            config=dict(training=vars(args), datastore=datastore._config),
+        )
+    elif config.training.logger == "mlflow":
+        url = config.training.logger_url
+        if url is None:
+            raise ValueError(
+                "MLFlow logger requires a URL to the MLFlow server"
+            )
+        logger = CustomMLFlowLogger(
+            experiment_name=args.wandb_project,
+            tracking_uri=url,
+        )
+        logger.log_hyperparams(
+            dict(training=vars(args), datastore=datastore._config)
+        )
+
+    return logger
+
+
+@logger.catch
 def main(input_args=None):
     """Main function for training and evaluating models."""
```
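`_setup_training_logger` reads the backend choice from `config.training.logger` and, for MLflow, requires `config.training.logger_url`. A small sketch of the error path using a stand-in config object (`SimpleNamespace` substitutes for the real neural-lam config class, which is not shown in this diff):

```python
# Sketch: MLflow selected but no tracking URL configured -> ValueError.
from types import SimpleNamespace

config = SimpleNamespace(
    training=SimpleNamespace(logger="mlflow", logger_url=None)
)
try:
    # datastore/args are never touched before the URL check, so None is fine here
    _setup_training_logger(config=config, datastore=None, args=None, run_name="demo")
except ValueError as err:
    print(err)  # MLFlow logger requires a URL to the MLFlow server
```

With a valid URL, the same branch constructs `CustomMLFlowLogger` and logs the training arguments and datastore config as hyperparameters through it.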
```diff
@@ -163,6 +272,12 @@ def main(input_args=None):
         help="Number of example predictions to plot during evaluation "
         "(default: 1)",
     )
+    parser.add_argument(
+        "--save_predictions",
+        action="store_true",
+        help="If predictions should be saved to disk as a zarr dataset "
+        "(default: false)",
+    )
 
     # Logger Settings
     parser.add_argument(
```
```diff
@@ -261,24 +376,30 @@
         f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"
         f"{time.strftime('%m_%d_%H')}-{random_run_id:04d}"
     )
+
+    training_logger = _setup_training_logger(
+        config=config, datastore=datastore, args=args, run_name=run_name
+    )
+
     checkpoint_callback = pl.callbacks.ModelCheckpoint(
         dirpath=f"saved_models/{run_name}",
         filename="min_val_loss",
         monitor="val_mean_loss",
         mode="min",
         save_last=True,
     )
-    logger = pl.loggers.WandbLogger(
-        project=args.wandb_project,
-        name=run_name,
-        config=dict(training=vars(args), datastore=datastore._config),
-    )
     trainer = pl.Trainer(
         max_epochs=args.epochs,
         deterministic=True,
         strategy="ddp",
+        devices=4,
+        # devices=[1,2],
+        # devices=[0, 1, 2],
+        # strategy="auto",
+        # devices=1,  # For eval mode
+        # num_nodes=1,  # For eval mode
         accelerator=device_name,
-        logger=logger,
+        logger=training_logger,
         log_every_n_steps=1,
         callbacks=[checkpoint_callback],
         check_val_every_n_epoch=args.val_interval,
```
```diff
@@ -287,11 +408,15 @@
 
     # Only init once, on rank 0 only
     if trainer.global_rank == 0:
-        utils.init_wandb_metrics(
-            logger, val_steps=args.val_steps_to_log
-        )  # Do after wandb.init
+        utils.init_training_logger_metrics(
+            training_logger, val_steps=args.val_steps_to_log
+        )  # Do after initializing logger
     if args.eval:
-        trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
+        trainer.test(
+            model=model,
+            datamodule=data_module,
+            ckpt_path=args.load,
+        )
     else:
         trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
```
Review comment:
What do we actually want the config to contain, vs. what should be command-line arguments? I would have thought that the choice of logger would be an argparse flag, similar to the plotting choices. My thought process is that logging/plotting does not affect the end product (the trained model), whereas all the current options in the config do. But we are not really consistent with this divide either, as there are plenty of argparse options currently that change the model training.
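A minimal sketch of the alternative the reviewer describes, using hypothetical `--logger`/`--logger_url` flags (neither exists in this PR):

```python
# Hypothetical sketch: selecting the logger backend on the command line
# instead of in the config file, as the review comment suggests.
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument(
    "--logger",
    type=str,
    default="wandb",
    choices=["wandb", "mlflow"],
    help="Experiment logger backend (default: wandb)",
)
parser.add_argument(
    "--logger_url",
    type=str,
    default=None,
    help="Tracking server URL, required when --logger=mlflow",
)
args = parser.parse_args(["--logger", "mlflow", "--logger_url", "http://localhost:5000"])
print(args.logger, args.logger_url)
```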