Commit

initial commit
mist demo
anoushka2000 authored and abhutani2000 committed Jul 16, 2024
0 parents commit b99047d
Showing 26 changed files with 1,369 additions and 0 deletions.
17 changes: 17 additions & 0 deletions .github/workflows/CI.yaml
@@ -0,0 +1,17 @@
on:
push:
pull_request:
workflow_dispatch:

jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- run: pipx install poetry
- uses: actions/setup-python@v4
with:
python-version: 3.11.8
cache: 'poetry'
- run: poetry install
- run: source activate && pytest
18 changes: 18 additions & 0 deletions .github/workflows/pre-commit.yml
@@ -0,0 +1,18 @@
on:
workflow_dispatch:
pull_request:
push:
branches: [master]

jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- run: pipx install poetry
- uses: actions/setup-python@v4
with:
python-version: 3.11.8
cache: 'poetry'
- run: poetry install
- uses: pre-commit/[email protected]
30 changes: 30 additions & 0 deletions .gitignore
@@ -0,0 +1,30 @@
__pycache__/
*.sif
wandb/
.conda/
electrolyte-fm/
lightning_logs/
.venv/

# Log files
*.o*
*.e*

# Checkpoints
*.ckpt/
*.ckpt

# NSight
*.nsys-rep
*.sqlite
*.qdstrm

# Dataset
.cache/

# Spack
.spack-env/
spack.lock

# Notebook checkpoints
.ipynb_checkpoints/
38 changes: 38 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,38 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
- id: check-yaml
- id: end-of-file-fixer
- id: check-merge-conflict
- id: check-case-conflict
- id: check-added-large-files
- id: check-shebang-scripts-are-executable
exclude_types: [jinja]
- id: check-executables-have-shebangs
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.5.5
hooks:
- id: remove-crlf
- id: remove-tabs
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 24.3.0
hooks:
- id: black
- repo: https://github.com/FeryET/pre-commit-rust
rev: v1.1.0
hooks:
- id: fmt
args: [--manifest-path, smirk/Cargo.toml, --]
- id: cargo-check
args: [--manifest-path, smirk/Cargo.toml, --]
- repo: https://github.com/python-poetry/poetry
rev: 1.8.0
hooks:
- id: poetry-check
89 changes: 89 additions & 0 deletions README.md
@@ -0,0 +1,89 @@
# Pre-training MIST (Molecular Insight SMILES Transformer)
This repository demonstrates the pre-training workflow for a transformer trained on molecular datasets.

# Installation

MIST is trained primarily on Polaris, so the installation instructions below target that system.
Installation may differ slightly on other systems.

## Polaris

1. Load conda
```shell
module purge
module use /soft/modulefiles/
module --ignore_cache load conda/2024-04-29
conda activate base
```

2. Install pipx and poetry
```shell
python -m pip install pipx
python -m pipx ensurepath
python -m pipx install --python $(which python) poetry
```

3. Install the environment:
```shell
cd mist
poetry install
```

4. Install `ipykernel` and add a kernel for the environment.
```shell
source ./activate
python -m pip install ipykernel
python -m ipykernel install --user --name mist_demo
```

## Using the notebooks

The notebooks demonstrating the MIST pre-training workflow are in the `notebooks` directory. To run them:
1. Request an interactive session with one GPU node.
```
qsub -I -l select=1 -l filesystems=[home:filesystem] -l walltime=01:00:00 -q debug -A [AccountName]
```
2. Activate the environment
```shell
# Instructions for Polaris
module purge
module use /soft/modulefiles/
module --ignore_cache load conda/2024-04-29 gcc-native/12.3 PrgEnv-nvhpc
export CC=gcc-12
export CXX=g++-12
cd mist
source ./activate
```
3. Launch a `jupyter notebook` server and select the `mist_demo` kernel.
```
jupyter notebook --ip $(hostname) --no-browser
```

## Data

The pre-training data is available on [Dropbox](https://www.dropbox.com/scl/fo/3z1lklbper07ojtp5t4iu/AHUEJ_3j5_CRVpWmcGLW3kQ?rlkey=2818imymvf5mk5byz0c7ei1ij&dl=0).
Download and extract the archive into the `sample_data` folder; it requires ~2.2 GB of disk space.

```
sample_data
├── data
│   ├── train
│   │   ├── xaaa.txt
│   │   ├── xaab.txt
│   │   ├── ...
│   ├── test
│   │   ├── xaaa.txt
│   │   ├── xaab.txt
│   │   ├── ...
│   ├── val
│   │   ├── xaaa.txt
│   │   ├── xaab.txt
│   │   ├── ...
```

The data is pre-shuffled and split into training, validation, and test sets.
The training set has `~0.25B` molecules, while the validation and test sets have `62M` molecules each.
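
For a quick sanity check, the extracted splits can be streamed with the `datasets` library, mirroring how `RobertaDataSet` loads them in `mist/data_modules/roberta_dataset.py`. This is a minimal sketch; it assumes the shards are plain-text files with one SMILES string per line.

```python
# Minimal sketch: stream the extracted sample_data splits the same way
# RobertaDataSet does (splits are inferred from the train/test/val folders).
from datasets import load_dataset

ds = load_dataset("sample_data/data", streaming=True)
print(next(iter(ds["train"])))  # e.g. {"text": "..."}
```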

## Checkpoint

A sample checkpoint is also available on Dropbox; download it and place it in the `sample_checkpoint` folder.
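
As a rough sketch, the checkpoint can be restored through the `DeepSpeedMixin.load` helper defined in `mist/models/model_utils.py`; what `SaveConfigWithCkpts.load` returns is determined by code not shown in this commit, so treat the snippet as an assumption.

```python
# Hedged sketch: restore the sample checkpoint via DeepSpeedMixin.load.
# The returned object depends on SaveConfigWithCkpts.load (not shown here).
from mist.models.model_utils import DeepSpeedMixin

model = DeepSpeedMixin.load("sample_checkpoint")
```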
11 changes: 11 additions & 0 deletions activate
@@ -0,0 +1,11 @@
#!/bin/bash
# Source this to activate the environment

# Activate virtual environment
source .venv/bin/activate

# Add NVIDIA libraries
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:$(realpath .venv/lib64):$(realpath .venv/lib)"

# Disable the DeepSpeed CUDA compatibility check
export DS_SKIP_CUDA_CHECK=1
Empty file added mist/__init__.py
Empty file.
Empty file added mist/data_modules/__init__.py
Empty file.
10 changes: 10 additions & 0 deletions mist/data_modules/data_utils.py
@@ -0,0 +1,10 @@
from transformers import AutoTokenizer, PreTrainedTokenizerBase

class DataSetupMixin:

def setup_tokenizer(self, tokenizer: str):
# Load the tokenizer (trust_remote_code allows custom tokenizer classes)
self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
tokenizer,
trust_remote_code=True
)
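
A brief usage sketch of `DataSetupMixin`; the data-module class and tokenizer path below are illustrative placeholders, not part of this commit.

```python
# Hypothetical data module that mixes in DataSetupMixin to resolve its
# tokenizer; "path/to/tokenizer" is a placeholder.
import pytorch_lightning as pl

class ExampleDataModule(DataSetupMixin, pl.LightningDataModule):
    def __init__(self, tokenizer: str = "path/to/tokenizer"):
        super().__init__()
        self.setup_tokenizer(tokenizer)  # populates self.tokenizer
```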
128 changes: 128 additions & 0 deletions mist/data_modules/roberta_dataset.py
@@ -0,0 +1,128 @@
from pathlib import Path

import pytorch_lightning as pl
from datasets import IterableDataset, IterableDatasetDict, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling

from ..utils.tokenizer import load_tokenizer


class RobertaDataSet(pl.LightningDataModule):
def __init__(
self,
path: str,
tokenizer: str,
mlm_probability=0.15,
batch_size: int = 64,
val_batch_size=None,
num_workers=0,
prefetch_factor=None,
persistent_workers=False,
):
super().__init__()

# Locate Tokeniser and dataset
self.tokenizer = load_tokenizer(tokenizer)
self.vocab_size = self.tokenizer.vocab_size
self.path: Path = Path(path)
assert self.path.is_dir() or self.path.is_file()

self.mlm_probability = mlm_probability
self.batch_size = batch_size
self.val_batch_size = val_batch_size if val_batch_size else batch_size
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
self.persistent_workers = persistent_workers
self.save_hyperparameters()

def prepare_data(self):
self.__load_dataset()

def __load_dataset(self):
if not hasattr(self, "dataset"):
dataset = load_dataset(
str(self.path),
keep_in_memory=False,
streaming=True,
)
assert isinstance(dataset, IterableDatasetDict)
self.dataset = dataset

return self.dataset

def setup(self, stage: str) -> None:
self.data_collator = DataCollatorForLanguageModeling(
tokenizer=self.tokenizer,
mlm_probability=self.mlm_probability,
mlm=True,
)
ds = self.__load_dataset().map(
lambda batch: self.tokenizer(batch["text"]),
batched=True,
remove_columns="text",
)

# Setup to partition datasets over ranks
if self.trainer is not None:
rank = self.trainer.global_rank
world_size = self.trainer.world_size
ds_train: IterableDataset = ds["train"].shuffle(seed=42)
assert ds_train.n_shards % world_size == 0
assert ds["validation"].n_shards % world_size == 0
assert ds["test"].n_shards % world_size == 0
else:
rank = 0
world_size = 1
ds_train: IterableDataset = ds["train"].shuffle(seed=42)

# Partition Datasets
self.train_dataset: IterableDataset = split_dataset_by_node(
ds_train,
rank=rank,
world_size=world_size,
)
self.val_dataset: IterableDataset = split_dataset_by_node(
ds["validation"],
rank=rank,
world_size=world_size,
)
self.test_dataset: IterableDataset = split_dataset_by_node(
ds["test"],
rank=rank,
world_size=world_size,
)

def train_dataloader(self):
# Epoch-to-epoch reshuffling is driven by set_epoch (see LoggingMixin.on_train_epoch_start)
return DataLoader(
self.train_dataset,
collate_fn=self.data_collator,
batch_size=self.batch_size,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
pin_memory=True,
persistent_workers=self.persistent_workers,
)

def val_dataloader(self):
return DataLoader(
self.val_dataset,
collate_fn=self.data_collator,
batch_size=self.val_batch_size,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
pin_memory=True,
persistent_workers=self.persistent_workers,
shuffle=False,
)

def test_dataloader(self):
return DataLoader(
self.test_dataset,
collate_fn=self.data_collator,
batch_size=self.val_batch_size,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
)
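
A hedged sketch of how this data module might be driven from a Lightning `Trainer`; the data path and tokenizer are assumptions based on the README, and the model object is left abstract.

```python
# Sketch only: the model class is defined elsewhere in the repository.
import pytorch_lightning as pl

datamodule = RobertaDataSet(
    path="sample_data/data",        # train/validation/test text shards
    tokenizer="path/to/tokenizer",  # anything load_tokenizer accepts
    batch_size=64,
    num_workers=4,
)
# trainer = pl.Trainer(max_steps=1_000)
# trainer.fit(model, datamodule=datamodule)
```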
Empty file added mist/models/__init__.py
Empty file.
31 changes: 31 additions & 0 deletions mist/models/model_utils.py
@@ -0,0 +1,31 @@
import json
from pathlib import Path

import pytorch_lightning as pl

from ..utils.ckpt import SaveConfigWithCkpts


class DeepSpeedMixin:
@staticmethod
def load(checkpoint_dir, **kwargs):
print(checkpoint_dir)
return SaveConfigWithCkpts.load(checkpoint_dir, **kwargs)

def get_encoder(self):
raise NotImplementedError


class LoggingMixin(pl.LightningModule):

def on_train_epoch_start(self) -> None:
# Update the dataset's internal epoch counter
self.trainer.train_dataloader.dataset.set_epoch(self.trainer.current_epoch)
self.log(
"train/dataloader_epoch",
self.trainer.train_dataloader.dataset._epoch,
rank_zero_only=True,
sync_dist=True,
)
return super().on_train_epoch_start()
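
A minimal sketch of a model picking up the epoch-propagation hook from `LoggingMixin`; the subclass name and training step are hypothetical.

```python
# Hypothetical subclass: inherits on_train_epoch_start from LoggingMixin so
# the streaming dataset is reshuffled each epoch via set_epoch.
class ExampleMaskedLM(LoggingMixin):
    def training_step(self, batch, batch_idx):
        ...  # forward pass + masked-LM loss
```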

(Remaining changed files not expanded.)