Commit b99047d (0 parents): mist demo

Showing 26 changed files with 1,369 additions and 0 deletions.
@@ -0,0 +1,17 @@
on:
  push:
  pull_request:
  workflow_dispatch:

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - run: pipx install poetry
      - uses: actions/setup-python@v4
        with:
          python-version: 3.11.8
          cache: 'poetry'
      - run: poetry install
      - run: source activate && pytest
@@ -0,0 +1,18 @@
on:
  workflow_dispatch:
  pull_request:
  push:
    branches: [master]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - run: pipx install poetry
      - uses: actions/setup-python@v4
        with:
          python-version: 3.11.8
          cache: 'poetry'
      - run: poetry install
      - uses: pre-commit/[email protected]
@@ -0,0 +1,30 @@
__pycache__/
*.sif
wandb/
.conda/
electrolyte-fm/
lightning_logs/
.venv/

# Log files
*.o*
*.e*

# Checkpoints
*.ckpt/
*.ckpt

# NSight
*.nsys-rep
*.sqlite
*.qdstrm

# Dataset
.cache/

# Spack
.spack-env/
spack.lock

# Notebook checkpoints
.ipynb_checkpoints/
@@ -0,0 +1,38 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: trailing-whitespace
        args: [--markdown-linebreak-ext=md]
      - id: check-yaml
      - id: end-of-file-fixer
      - id: check-merge-conflict
      - id: check-case-conflict
      - id: check-added-large-files
      - id: check-shebang-scripts-are-executable
        exclude_types: [jinja]
      - id: check-executables-have-shebangs
  - repo: https://github.com/Lucas-C/pre-commit-hooks
    rev: v1.5.5
    hooks:
      - id: remove-crlf
      - id: remove-tabs
  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2
    hooks:
      - id: isort
  - repo: https://github.com/psf/black
    rev: 24.3.0
    hooks:
      - id: black
  - repo: https://github.com/FeryET/pre-commit-rust
    rev: v1.1.0
    hooks:
      - id: fmt
        args: [--manifest-path, smirk/Cargo.toml, --]
      - id: cargo-check
        args: [--manifest-path, smirk/Cargo.toml, --]
  - repo: https://github.com/python-poetry/poetry
    rev: 1.8.0
    hooks:
      - id: poetry-check
@@ -0,0 +1,89 @@
# Pre-training MIST (Molecular Insight SMILES Transformer)
This repository is an example of the pre-training workflow for a transformer trained on molecular datasets.

# Installation

MIST is trained primarily on Polaris; installation instructions for that system are provided here.
Installation may differ slightly on other systems.

## Polaris

1. Load conda
```shell
module purge
module use /soft/modulefiles/
module --ignore_cache load conda/2024-04-29
conda activate base
```

2. Install poetry + pipx
```shell
python -m pip install pipx
python -m pipx ensurepath
python -m pipx install --python $(which python) poetry
```

3. Install the environment:
```shell
cd mist
poetry install
```

4. Install `ipykernel` and add a kernel for the environment.
```shell
source ./activate
python -m pip install ipykernel
python -m ipykernel install --user --name mist_demo
```

## Using the notebooks

The notebooks demonstrating the MIST pre-training workflow are in the `notebooks` directory. To run them:
1. Request an interactive session with one GPU node.
```
qsub -I -l select=1 -l filesystems=[home:filesystem] -l walltime=01:00:00 -q debug -A [AccountName]
```
2. Activate the environment.
```shell
# Instructions for Polaris
module purge
module use /soft/modulefiles/
module --ignore_cache load conda/2024-04-29 gcc-native/12.3 PrgEnv-nvhpc
export CC=gcc-12
export CXX=g++-12
cd mist
source ./activate
```
3. Launch a `jupyter notebook` server and select the `mist_demo` kernel (a quick sanity check to run in the first notebook cell is sketched after this list).
```
jupyter notebook --ip $(hostname) --no-browser
```
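Once the kernel is selected, a minimal sanity check can confirm that the GPU on the compute node is visible to the environment (this assumes PyTorch is pulled in by `poetry install`, which the PyTorch Lightning data module in this repository requires):

```python
# Minimal notebook sanity check: verify the environment and GPU visibility.
import torch

print("torch", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
```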

## Data

The pre-training data is available on [Dropbox](https://www.dropbox.com/scl/fo/3z1lklbper07ojtp5t4iu/AHUEJ_3j5_CRVpWmcGLW3kQ?rlkey=2818imymvf5mk5byz0c7ei1ij&dl=0).
This data should be downloaded and extracted into the `sample_data` folder. It requires ~2.2GB of disk space.

```
sample_data
├── data
│   ├── train
│   │   ├── xaaa.txt
│   │   ├── xaab.txt
│   │   ├── ...
│   ├── test
│   │   ├── xaaa.txt
│   │   ├── xaab.txt
│   │   ├── ...
│   ├── val
│   │   ├── xaaa.txt
│   │   ├── xaab.txt
│   │   ├── ...
```

The data is pre-shuffled and split into training, validation, and test sets with an 80:20:20 ratio.
The training dataset has `~0.25B` molecules, while the test and validation sets have `62M` molecules each.
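These splits can be streamed directly with Hugging Face `datasets`, which is how the data module later in this commit loads them. A minimal sketch, assuming the archive has been extracted to `sample_data/data` as shown above and that each `.txt` file holds one SMILES string per line:

```python
# Minimal sketch: stream a few molecules from the extracted sample data.
from datasets import load_dataset

ds = load_dataset("sample_data/data", streaming=True)  # infers the text builder and splits
for i, example in enumerate(ds["train"]):
    print(example["text"])
    if i >= 4:
        break
```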

## Checkpoint

A sample checkpoint is also available on Dropbox. It should be downloaded and placed in the `sample_checkpoint` folder.
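A quick way to confirm the checkpoint landed in the right place before loading it (a minimal sketch; the model classes in this commit expose a `load(checkpoint_dir)` classmethod through `DeepSpeedMixin` for the actual restore):

```python
# Minimal sketch: verify the sample checkpoint was downloaded and extracted.
from pathlib import Path

ckpt_dir = Path("sample_checkpoint")
assert ckpt_dir.exists(), "Download the sample checkpoint from Dropbox first"
for path in sorted(ckpt_dir.rglob("*")):
    print(path)
```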
@@ -0,0 +1,11 @@
#!/bin/bash
# Source this to activate the environment

# Activate virtual environment
source .venv/bin/activate

# Add NVIDIA libraries
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:$(realpath .venv/lib64):$(realpath .venv/lib)"

# DISABLE Compat Check
export DS_SKIP_CUDA_CHECK=1
Empty file.
Empty file.
@@ -0,0 +1,10 @@
from transformers import AutoTokenizer, PreTrainedTokenizerBase

class DataSetupMixin:

    def setup_tokenizer(self, tokenizer: str):
        # Locate Tokeniser and dataset
        self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
            tokenizer,
            trust_remote_code=True
        )
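A minimal usage sketch of this mixin (using `roberta-base` as a stand-in identifier; the project would pass its own tokenizer path or hub name):

```python
# Hypothetical usage of DataSetupMixin: any class gains tokenizer setup.
class MyDataModule(DataSetupMixin):
    def __init__(self, tokenizer: str):
        self.setup_tokenizer(tokenizer)

dm = MyDataModule("roberta-base")  # stand-in for the project's tokenizer
print(dm.tokenizer.vocab_size)
```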
@@ -0,0 +1,128 @@
from pathlib import Path

import pytorch_lightning as pl
from datasets import IterableDataset, IterableDatasetDict, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling

from ..utils.tokenizer import load_tokenizer


class RobertaDataSet(pl.LightningDataModule):
    def __init__(
        self,
        path: str,
        tokenizer: str,
        mlm_probability=0.15,
        batch_size: int = 64,
        val_batch_size=None,
        num_workers=0,
        prefetch_factor=None,
        persistent_workers=False,
    ):
        super().__init__()

        # Locate Tokeniser and dataset
        self.tokenizer = load_tokenizer(tokenizer)
        self.vocab_size = self.tokenizer.vocab_size
        self.path: Path = Path(path)
        assert self.path.is_dir() or self.path.is_file()

        self.mlm_probability = mlm_probability
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size if val_batch_size else batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.persistent_workers = persistent_workers
        self.save_hyperparameters()

    def prepare_data(self):
        self.__load_dataset()

    def __load_dataset(self):
        if not hasattr(self, "dataset"):
            dataset = load_dataset(
                str(self.path),
                keep_in_memory=False,
                streaming=True,
            )
            assert isinstance(dataset, IterableDatasetDict)
            self.dataset = dataset

        return self.dataset

    def setup(self, stage: str) -> None:
        self.data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm_probability=self.mlm_probability,
            mlm=True,
        )
        ds = self.__load_dataset().map(
            lambda batch: self.tokenizer(batch["text"]),
            batched=True,
            remove_columns="text",
        )

        # Setup to partition datasets over ranks
        if self.trainer is not None:
            rank = self.trainer.global_rank
            world_size = self.trainer.world_size
            ds_train: IterableDataset = ds["train"].shuffle(seed=42)
            assert ds_train.n_shards % world_size == 0
            assert ds["validation"].n_shards % world_size == 0
            assert ds["test"].n_shards % world_size == 0
        else:
            rank = 0
            world_size = 1
            ds_train: IterableDataset = ds["train"].shuffle(seed=42)

        # Partition Datasets
        self.train_dataset: IterableDataset = split_dataset_by_node(
            ds_train,
            rank=rank,
            world_size=world_size,
        )
        self.val_dataset: IterableDataset = split_dataset_by_node(
            ds["validation"],
            rank=rank,
            world_size=world_size,
        )
        self.test_dataset: IterableDataset = split_dataset_by_node(
            ds["test"],
            rank=rank,
            world_size=world_size,
        )

    def train_dataloader(self):
        # Increment epoch to replicate shuffling
        return DataLoader(
            self.train_dataset,
            collate_fn=self.data_collator,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor,
            pin_memory=True,
            persistent_workers=self.persistent_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            collate_fn=self.data_collator,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor,
            pin_memory=True,
            persistent_workers=self.persistent_workers,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            collate_fn=self.data_collator,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor,
        )
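A minimal sketch of driving this data module outside a `Trainer` (the tokenizer identifier is a placeholder for whatever `load_tokenizer` expects; with no trainer attached, the module partitions for rank 0 and world size 1):

```python
# Hypothetical usage of RobertaDataSet without a pl.Trainer.
dm = RobertaDataSet(
    path="sample_data/data",   # extracted sample data, laid out as in the README
    tokenizer="roberta-base",  # placeholder; substitute the project's tokenizer
    batch_size=8,
)
dm.prepare_data()
dm.setup(stage="fit")
batch = next(iter(dm.train_dataloader()))
print(batch["input_ids"].shape)  # (batch_size, padded sequence length)
```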
Empty file.
@@ -0,0 +1,31 @@
import json
from pathlib import Path

import pytorch_lightning as pl

from ..utils.ckpt import SaveConfigWithCkpts


class DeepSpeedMixin:
    @staticmethod
    def load(checkpoint_dir, **kwargs):
        print(checkpoint_dir)
        return SaveConfigWithCkpts.load(checkpoint_dir, **kwargs)

    def get_encoder(self):
        raise NotImplementedError


class LoggingMixin(pl.LightningModule):

    def on_train_epoch_start(self) -> None:
        # Update the dataset's internal epoch counter
        self.trainer.train_dataloader.dataset.set_epoch(self.trainer.current_epoch)
        self.log(
            "train/dataloader_epoch",
            self.trainer.train_dataloader.dataset._epoch,
            rank_zero_only=True,
            sync_dist=True,
        )
        return super().on_train_epoch_start()