Mitchish mosaic run on its own branch #350

Merged 28 commits on Jan 5, 2024

Commits
2f9e9f5
Run from a branch
dirkgr Oct 29, 2023
40195c2
No time limit, and new load path
dirkgr Oct 29, 2023
3e80561
Time limit, and new load path
dirkgr Oct 30, 2023
088daa9
fix merge conflicts
epwalsh Oct 31, 2023
5904517
Add kempner run script for mitch-ish
epwalsh Nov 1, 2023
844a441
log to my own dir
epwalsh Nov 2, 2023
9659f04
try "by_block" wrapping strategy
epwalsh Nov 2, 2023
85800c1
try "sized_based" wrapping strategy
epwalsh Nov 2, 2023
22aadc6
try microbatch size of 1
epwalsh Nov 2, 2023
39c3bcf
shorten script name
epwalsh Nov 2, 2023
54810f8
enable flash attention
epwalsh Nov 2, 2023
4ed386a
fix
epwalsh Nov 2, 2023
c481165
Add mitch-ish 256 for LUMI (#351)
epwalsh Nov 6, 2023
53f16e4
'rocm_smi' -> 'rocm-smi'
epwalsh Nov 7, 2023
2fffcd3
save final sharded checkpoint
epwalsh Nov 24, 2023
06caccd
Add ability to restart on new epoch
epwalsh Nov 24, 2023
3d26ae3
Start new epoch for mitchish
epwalsh Nov 24, 2023
148ca06
start decay LR again
epwalsh Dec 2, 2023
569381a
fix merge conflicts
epwalsh Dec 7, 2023
a44217a
clean up
epwalsh Dec 7, 2023
eb1f523
add beaker script
epwalsh Dec 7, 2023
87a8b70
fix-tune down to 0 LR
epwalsh Dec 11, 2023
63d4c38
Merge branch 'main' into mitchish
epwalsh Dec 15, 2023
a3b36e4
update kempner script
epwalsh Dec 15, 2023
bd81f46
Merge branch 'main' into mitchish
epwalsh Dec 15, 2023
466dba6
fix compile without act chkpting
epwalsh Dec 15, 2023
7f356ac
comment out extra flags
epwalsh Jan 5, 2024
f3a73dd
Merge branch 'main' into mitchish
epwalsh Jan 5, 2024
19 changes: 16 additions & 3 deletions configs/mcli/v1_5-mix-medium-mitch-ish.yaml
@@ -6,7 +6,8 @@ gpu_type: a100_40gb
integrations:
- integration_type: git_repo
git_repo: allenai/LLM
git_branch: main # make sure to update this!
# git_branch: mitchish
git_commit: 148ca062e7f1f7667d7fc0f4346e97467e66ce87
pip_install: -e .
ssh_clone: true
command: |-
@@ -28,6 +29,18 @@ command: |-
--nproc_per_node 8 \
scripts/train.py configs/v1_5-mix-medium-mitch-ish-s3.yaml \
--run_name=v1_5-mix-mitch-ish \
--wandb.name=v1_5-mix-mitch-ish-mcli \
--wandb.name=v1_5-mix-mitch-ish-mcli-final \
--global_train_batch_size=2160 \
--load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish}'
--time_limit=169200

# We added these flags in order to get a final checkpoint where we decayed the LR down to 0.
# --eval_interval=100 \
# --save_interval=500 \
# --load_path=s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish/step556000 \
# --remote_save_folder=s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish-final \
# --epoch=1 \
# --optimizer.learning_rate=0.000023 \
# --scheduler.t_warmup=556000 \
# --scheduler.t_max=557000 \
# --scheduler.alpha_f=0.001 \
# --stop_at=557001
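
For reference, a minimal sketch of what those commented-out flags amount to, assuming `linear_with_warmup` ramps to the peak LR at `t_warmup` and then decays linearly to `alpha_f` times the peak at `t_max` (illustrative only, not code from this repo):

```python
# Hypothetical helper, not from this repo: the LR trajectory implied by the
# flags above, under the assumed linear_with_warmup behavior.
def lr_at_step(
    step: int,
    peak_lr: float = 0.000023,
    t_warmup: int = 556_000,
    t_max: int = 557_000,
    alpha_f: float = 0.001,
) -> float:
    if step < t_warmup:  # warmup: linear ramp from 0 to the peak
        return peak_lr * step / t_warmup
    if step >= t_max:  # past t_max: hold the final floor
        return peak_lr * alpha_f
    frac_left = (t_max - step) / (t_max - t_warmup)  # 1 -> 0 over the decay window
    return peak_lr * (alpha_f + (1 - alpha_f) * frac_left)

# Resuming from step 556000 gives 2.3e-5 at step 556000, decaying to
# 2.3e-8 (= alpha_f * peak) by step 557000; --stop_at=557001 then ends the run.
```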
2 changes: 1 addition & 1 deletion configs/v1_5-mix-medium-mitch-ish.yaml
@@ -49,7 +49,7 @@ optimizer:
metrics_log_interval: 10

scheduler:
name: cosine_with_warmup
name: linear_with_warmup
Member:

I thought this was linear the whole time?

Member Author:

I guess when we first made this config, we were thinking cosine. We've only run it with linear, though.

t_warmup: 5000
alpha_f: 0.1

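To make the thread above concrete: a sketch of the two schedules being compared, both expressed as a multiplier on the peak LR after warmup, with `alpha_f` as the final multiplier (assumed shapes, not this repo's implementation):

```python
import math

# frac_done is the fraction of post-warmup training completed, in [0, 1].
def linear_multiplier(frac_done: float, alpha_f: float = 0.1) -> float:
    return alpha_f + (1.0 - alpha_f) * (1.0 - frac_done)

def cosine_multiplier(frac_done: float, alpha_f: float = 0.1) -> float:
    return alpha_f + (1.0 - alpha_f) * 0.5 * (1.0 + math.cos(math.pi * frac_done))

# Both start at 1.0 and end at alpha_f; cosine stays near the peak longer,
# then drops faster through the middle of training.
```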
75 changes: 53 additions & 22 deletions olmo/model.py
@@ -45,7 +45,6 @@
from .exceptions import OlmoConfigurationError
from .initialization import ModuleType, init_weights
from .torch_util import ensure_finite_
from .util import pass_through_fn

__all__ = [
"LayerNormBase",
@@ -430,7 +429,7 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
self.__cache = cache
assert config.d_model % config.n_heads == 0

self._activation_checkpoint_fn = pass_through_fn
self._activation_checkpoint_fn = None
Member:

I find this confusing. Doesn't that mean it will compile only if we don't use checkpointing? As far as I know, compile never likes function pointers?

Member:

Or did compile + checkpointing never work anyway?

Member Author:

I'm not sure if we ever got compile to work with checkpointing, but I needed to make this change in order for compile to work without checkpointing.
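
A toy sketch of the pattern this diff adopts (simplified, and assuming recent PyTorch): store `None` when checkpointing is off and branch at the call site, so `torch.compile` traces a direct module call instead of a pass-through function pointer.

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.LayerNorm(8)
        self._activation_checkpoint_fn = None  # None means "no checkpointing"

    def enable_activation_checkpointing(self):
        self._activation_checkpoint_fn = lambda fn, *args: checkpoint(
            fn, *args, use_reentrant=False
        )

    def forward(self, x):
        if self._activation_checkpoint_fn is not None:
            return self._activation_checkpoint_fn(self.norm, x)
        return self.norm(x)  # direct call: friendly to torch.compile

compiled = torch.compile(Block())  # compiles cleanly with checkpointing off
y = compiled(torch.randn(2, 8))
```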


# Dropout.
self.dropout = Dropout(config.residual_dropout)
@@ -492,7 +491,7 @@ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointin
if strategy == ActivationCheckpointingStrategy.fine_grained:
self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
else:
self._activation_checkpoint_fn = pass_through_fn
self._activation_checkpoint_fn = None

@classmethod
def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
@@ -673,12 +672,20 @@ def forward(
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
# - for multi-query attn q: (batch_size, seq_len, d_model)
# k, v: (batch_size, seq_len, d_model // n_heads)
q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(self.fused_dims, dim=-1)
if self._activation_checkpoint_fn is not None:
q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
self.fused_dims, dim=-1
)
else:
q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)

# Get attention scores.
att, cache = self._activation_checkpoint_fn(
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
)
if self._activation_checkpoint_fn is not None:
att, cache = self._activation_checkpoint_fn( # type: ignore
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
)
else:
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

# Add attention scores.
# shape: (B, T, C)
@@ -687,9 +694,15 @@
# Add feed-forward projection.
# shape: (batch_size, seq_len, d_model)
og_x = x
x = self._activation_checkpoint_fn(self.ff_norm, x)
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
else:
x = self.ff_norm(x)
x = self.ff_proj(x)
x = self._activation_checkpoint_fn(self.act, x)
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
else:
x = self.act(x)
x = self.ff_out(x)
x = self.dropout(x)
x = og_x + x
@@ -753,23 +766,35 @@ def forward(
# - for multi-query attn q: (batch_size, seq_len, d_model)
# k, v: (batch_size, seq_len, d_model // n_heads)
# shape of ff: (batch_size, seq_len, hidden_size)
q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
self.fused_dims, dim=-1
)
if self._activation_checkpoint_fn is not None:
q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
self.fused_dims, dim=-1
)
else:
q, k, v, ff = self.fused_attn_ff_proj(self.norm(x)).split(self.fused_dims, dim=-1)

# Get attention scores.
# shape: (B, T, C)
att, cache = self._activation_checkpoint_fn(
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
)
if self._activation_checkpoint_fn is not None:
att, cache = self._activation_checkpoint_fn( # type: ignore
self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
)
else:
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

# Apply output projections (and activation function) and sum the results.
# We keep these projections separate because we found that we got better throughput this
# way compared to fusing them.
return (
x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
cache,
)
if self._activation_checkpoint_fn is not None:
return (
x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
cache,
)
else:
return (
x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att),
cache,
)


class OlmoLlamaBlock(OlmoBlock):
@@ -874,9 +899,15 @@ def forward(
# Add feed-forward projection.
# shape: (batch_size, seq_len, d_model)
og_x = x
x = self._activation_checkpoint_fn(self.ff_norm, x)
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
else:
x = self.ff_norm(x)
x = self.ff_proj(x)
x = self._activation_checkpoint_fn(self.act, x)
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
else:
x = self.act(x)
x = self.ff_out(x)
x = self.dropout(x)
x = og_x + x
@@ -945,7 +976,7 @@ def forward(
)
):
# shape: (batch_size, seq_len, d_model)
x, cache = self._activation_checkpoint_fn(
x, cache = self._activation_checkpoint_fn( # type: ignore
block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
)
else:
30 changes: 30 additions & 0 deletions scripts/beaker/mitch-ish-7b.sh
@@ -0,0 +1,30 @@
#!/usr/bin/env bash

set -ex

CONFIG_PATH=configs/v1_5-mix-medium-mitch-ish-s3.yaml
NUM_NODES=4
ARGS='--activation_checkpointing=fine_grained --wandb.name=v1_5-mix-mitch-ish-mcli-final --epoch=1 --optimizer.learning_rate=0.000023 --scheduler.t_warmup=556000 --scheduler.t_max=557000 --scheduler.alpha_f=0.001 --stop_at=557000'

gantry run \
--allow-dirty \
--workspace ai2/llm-testing \
--task-name mitchish-mcli-final \
--description mitchish-mcli-final \
--priority high \
--beaker-image olmo-torch2-gantry \
--cluster ai2/general-cirrascale-a100-80g-ib \
--gpus 8 \
--replicas "${NUM_NODES}" \
--nfs \
--mount /net/nfs.cirrascale/allennlp/petew/cache:/root/.cache \
--env LOG_FILTER_TYPE=local_rank0_only \
--env OMP_NUM_THREADS=8 \
--env OLMO_TASK=model \
--env-secret WANDB_API_KEY=WANDB_API_KEY \
--env-secret AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID \
--env-secret AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY \
--shared-memory 10GiB \
--venv base \
--yes \
-- /bin/bash -c "torchrun --nnodes ${NUM_NODES}:${NUM_NODES} --nproc-per-node 8 --rdzv_id=101 --rdzv_backend=c10d --rdzv_endpoint=\$BEAKER_LEADER_REPLICA_HOSTNAME:29400 scripts/train.py ${CONFIG_PATH} ${ARGS}"
52 changes: 52 additions & 0 deletions scripts/kempner/mitch-ish-7b.sh
@@ -0,0 +1,52 @@
#!/bin/bash
#SBATCH --job-name=v1.5-mix-medium-mitch-ish
#SBATCH --account=kempner_lab
#SBATCH --output=/n/holyscratch01/kempner_lab/Lab/logs-petew/%j.log
#SBATCH --nodes=8 # Total number of nodes
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-node=4 # Allocate one gpu per MPI rank
#SBATCH --cpus-per-task=16
#SBATCH --time=167:00:00
#SBATCH --mem=0 # All memory on the node
#SBATCH --partition=kempner_project

export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
export MPICH_GPU_SUPPORT_ENABLED=1
export MIOPEN_USER_DB_PATH=/tmp/${USER}-miopen-cache-${SLURM_JOB_ID}
export MIOPEN_CUSTOM_CACHE_DIR=${MIOPEN_USER_DB_PATH}

export PYTHONPATH=.:${PYTHONPATH}

# Try playing with max_split_size_mb if you run into OOM errors.
# export PYTORCH_HIP_ALLOC_CONF=max_split_size_mb:512

export DATA_PATH=/n/home06/dgroeneveld/data/preprocessed/olmo-mix
export EVAL_DATA_PATH=/n/home06/dgroeneveld/data/eval-data
export CHECKPOINTS_PATH=/n/home06/dgroeneveld/checkpoints

export PYTORCH_KERNEL_CACHE_PATH=/tmp/pytorch_kernel_cache/
mkdir -p $PYTORCH_KERNEL_CACHE_PATH

LOAD_PATH=s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish/step556000-unsharded
# SAVE_PATH=s3://ai2-llm/checkpoints/7b/v1_5-mix-mitch-ish-final-tulu

srun \
"--cpus-per-task=$SLURM_CPUS_PER_TASK" \
--distribution=block:block \
--kill-on-bad-exit \
scripts/run_with_environment.sh \
$HOME/miniconda3/envs/LLM/bin/python -u scripts/train.py configs/v1_5-mix-medium-mitch-ish-s3.yaml \
"--run_name=kempner_${SLURM_JOB_ID}" \
--wandb.name=v1_5-mix-mitch-ish-final-tulu \
'--data.paths=[s3://ai2-llm/preprocessed/tulu-v2-sft-mixture/gpt-neox-20b-pii-special/data.npy,s3://ai2-llm/preprocessed/olmo-mix/v1_5-sample-9B/gpt-neox-20b-pii-special/data.npy]' \
--eval_interval=100 \
--save_interval=500 \
"--load_path=${LOAD_PATH}" \
--restore_dataloader=false \
--optimizer.learning_rate=0.000023 \
--scheduler.t_warmup=556000 \
--scheduler.alpha_f=0.001 \
--scheduler.t_max=558223 \
--stop_at=558223 \
--time_limit=$((167 * 60 * 60)) \
"--save_folder=/n/holyscratch01/kempner_lab/Lab/checkpoints/${SLURM_JOB_ID}/"
60 changes: 60 additions & 0 deletions scripts/lumi/mitch-ish-7b.sh
@@ -0,0 +1,60 @@
#!/bin/bash
#SBATCH --job-name=v1.5-mix-medium-mitch-ish
#SBATCH --account=project_462000229
#SBATCH --output=/pfs/lustref1/flash/project_462000229/logs/%j.log
#SBATCH --nodes=256 # Total number of nodes
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node=8 # Allocate one gpu per MPI rank
#SBATCH --cpus-per-task=6
#SBATCH --time=48:00:00
#SBATCH --time-min=24:00:00
#SBATCH --mem=0 # All memory on the node
#SBATCH --partition=standard-g

module load LUMI/22.08 partition/G

# export OLMO_CONTAINER=llm-lumi_latest.sif
export OLMO_CONTAINER=llm-lumi-torch21_latest.sif

export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
export MPICH_GPU_SUPPORT_ENABLED=1
export NCCL_SOCKET_IFNAME=hsn
export NCCL_NET_GDR_LEVEL=3
export MIOPEN_USER_DB_PATH=/tmp/${USER}-miopen-cache-${SLURM_JOB_ID}
export MIOPEN_CUSTOM_CACHE_DIR=${MIOPEN_USER_DB_PATH}
export CXI_FORK_SAFE=1
export CXI_FORK_SAFE_HP=1
export FI_CXI_DISABLE_CQ_HUGETLB=1

# We need to set this to avoid "Cassini Event Queue overflow detected." errors.
export FI_CXI_DEFAULT_CQ_SIZE=131072

#export NCCL_DEBUG=INFO
export PYTHONPATH=.:${PYTHONPATH}
export ROCM_PATH=/opt/rocm
export SINGULARITYENV_LD_LIBRARY_PATH=/usr/local/lib:/opt/cray/libfabric/1.15.2.0/lib64

# Try playing with max_split_size_mb if you run into OOM errors.
#export PYTORCH_HIP_ALLOC_CONF=max_split_size_mb:128

export DATA_PATH=$FLASH_DIR/preprocessed/olmo-mix
export CHECKPOINTS_PATH=$FLASH_DIR/checkpoints
export EVAL_DATA_PATH=$SCRATCH_DIR/eval-data

srun \
--cpus-per-task=$SLURM_CPUS_PER_TASK \
--distribution=block:block \
--kill-on-bad-exit \
scripts/run_with_environment.sh \
singularity exec \
-B"$PROJECT_DIR:$PROJECT_DIR" \
-B"$FLASH_DIR:$FLASH_DIR" \
-B"$SCRATCH_DIR:$SCRATCH_DIR" \
-B /opt/cray:/opt/cray \
-B /usr/lib64/libcxi.so.1:/usr/lib64/libcxi.so.1 \
-B /usr/lib64/libjson-c.so.3:/usr/lib64/libjson-c.so.3 \
$PROJECT_DIR/containers/$OLMO_CONTAINER \
python scripts/train.py configs/v1_5-mix-medium-mitch-ish.yaml ${@} \
--run_name=${SLURM_JOB_ID} \
--global_train_batch_size=4096 \
--max_duration=238418