Unskip Thunder FSDP test (#1391)
carmocca authored May 6, 2024
1 parent b1a43cd commit 6fd737d
Showing 18 changed files with 34 additions and 32 deletions.
4 changes: 2 additions & 2 deletions .github/azure-gpu-test.yml
@@ -13,7 +13,7 @@ pr:
 
 jobs:
 - job: testing
-  timeoutInMinutes: "20"
+  timeoutInMinutes: "30"
   cancelTimeoutInMinutes: "2"
   pool: "lit-rtx-3090"
   variables:
@@ -67,4 +67,4 @@ jobs:
     env:
       PL_RUN_CUDA_TESTS: "1"
    displayName: "Standalone tests"
-   timeoutInMinutes: "5"
+   timeoutInMinutes: "10"
7 changes: 7 additions & 0 deletions tests/__init__.py
@@ -0,0 +1,7 @@
+# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+
+import warnings
+
+import pytest
+
+warnings.filterwarnings("ignore", category=pytest.PytestWarning, message=r".*\(rm_rf\) error removing.*")
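For context, this moves the cleanup-warning filter out of conftest.py (see the next file) so it is installed as soon as the tests package is imported. A minimal sketch of what the filter does, not part of the commit: PytestWarning messages matching the regex are dropped, anything else still surfaces.

import warnings

import pytest

# same filter as tests/__init__.py above
warnings.filterwarnings("ignore", category=pytest.PytestWarning, message=r".*\(rm_rf\) error removing.*")

with warnings.catch_warnings(record=True) as caught:
    warnings.warn("(rm_rf) error removing /tmp/pytest-of-ci/tmpdir", pytest.PytestWarning)  # suppressed
    warnings.warn("some other pytest warning", pytest.PytestWarning)  # recorded

assert [str(w.message) for w in caught] == ["some other pytest warning"]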
13 changes: 8 additions & 5 deletions tests/conftest.py
@@ -2,7 +2,6 @@
 
 import os
 import shutil
-import warnings
 from pathlib import Path
 from typing import List, Optional
 
@@ -50,6 +49,14 @@ def restore_default_dtype():
     torch.set_default_dtype(torch.float32)
 
 
+@pytest.fixture(autouse=True)
+def destroy_process_group():
+    import torch.distributed
+
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()
+
+
 class MockTokenizer:
     """A dummy tokenizer that encodes each character as its ASCII code."""
 
@@ -149,7 +156,3 @@ def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.C
             bold=True,
             purple=True,  # oh yeah, branded pytest messages
         )
-
-
-# Ignore cleanup warnings from pytest (rarely happens due to a race condition when executing pytest in parallel)
-warnings.filterwarnings("ignore", category=pytest.PytestWarning, message=r".*\(rm_rf\) error removing.*")
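The new autouse destroy_process_group fixture guards against a torch.distributed default process group leaking from one test into another. A minimal sketch of the scenario it handles, assuming a single-rank CPU-only gloo group (illustrative, not part of the commit):

import os

import torch
import torch.distributed

def test_something_distributed():
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1)
    t = torch.ones(2)
    torch.distributed.all_reduce(t)  # no-op at world_size=1
    assert torch.equal(t, torch.ones(2))
    # no explicit destroy_process_group() here: the autouse fixture destroys any
    # still-initialized default group, so the next test that calls
    # init_process_group() starts from a clean state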
2 changes: 1 addition & 1 deletion tests/test_adapter.py
@@ -9,7 +9,7 @@
 import pytest
 import torch
 import yaml
-from conftest import RunIf
+from tests.conftest import RunIf
 from lightning import Fabric
 from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision
 from lightning.fabric.wrappers import _FabricOptimizer
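The same one-line change (from conftest import RunIf becoming from tests.conftest import RunIf) repeats across most test modules below, now that tests/__init__.py makes the directory an importable package. RunIf is the conditional-skip helper from tests/conftest.py; as seen elsewhere in this diff, it is applied as a decorator, for example:

from tests.conftest import RunIf

@RunIf(min_cuda_gpus=2, standalone=True)  # keywords as used elsewhere in this diff
def test_needs_two_gpus():
    ...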
2 changes: 1 addition & 1 deletion tests/test_adapter_v2.py
@@ -8,7 +8,7 @@
 import pytest
 import torch
 import yaml
-from conftest import RunIf
+from tests.conftest import RunIf
 from lightning import Fabric
 from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision
 from lightning.fabric.wrappers import _FabricOptimizer
2 changes: 1 addition & 1 deletion tests/test_ci.py
@@ -1,6 +1,6 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 
-from conftest import RunIf
+from tests.conftest import RunIf
 from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE
 
 
4 changes: 1 addition & 3 deletions tests/test_convert_lit_checkpoint.py
@@ -2,14 +2,11 @@
 
 import os
 from dataclasses import asdict
-from pathlib import Path
 from unittest.mock import ANY
-from urllib.request import urlretrieve
 
 import pytest
 import torch
 import yaml
-from conftest import RunIf
 from transformers import AutoConfig, AutoModelForCausalLM
 from transformers.models.falcon import FalconConfig, FalconForCausalLM
 from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
@@ -27,6 +24,7 @@
     copy_weights_phi,
     qkv_split,
 )
+from tests.conftest import RunIf
 
 
 def test_convert_lit_checkpoint(tmp_path):
2 changes: 1 addition & 1 deletion tests/test_generate_sequentially.py
@@ -11,7 +11,7 @@
 import pytest
 import torch
 import yaml
-from conftest import RunIf
+from tests.conftest import RunIf
 from lightning import Fabric
 
 from litgpt import Config
4 changes: 2 additions & 2 deletions tests/test_generate_tp.py
@@ -7,12 +7,12 @@
 import pytest
 import torch
 import yaml
-from conftest import RunIf
-from test_generate_sequentially import find_forward_hooks
 
 from litgpt import GPT, Config
 from litgpt.generate.tp import tensor_parallel, tensor_parallel_linear
 from litgpt.scripts.download import download_from_hub
+from tests.conftest import RunIf
+from tests.test_generate_sequentially import find_forward_hooks
 
 
 def test_tensor_parallel_linear():
4 changes: 2 additions & 2 deletions tests/test_lora.py
@@ -9,7 +9,6 @@
 import pytest
 import torch
 import yaml
-from conftest import RunIf
 from lightning import Fabric
 from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision
 from lightning.fabric.wrappers import _FabricOptimizer
@@ -22,11 +21,12 @@
 import litgpt.finetune.lora as module
 from litgpt.args import EvalArgs, TrainArgs
 from litgpt.data import Alpaca
-from litgpt.lora import GPT as LoRAGPT
 from litgpt.lora import CausalSelfAttention as LoRACausalSelfAttention
 from litgpt.lora import Config, LoRALinear, LoRAQKVLinear, lora_filter, mark_only_lora_as_trainable, merge_lora_weights
+from litgpt.lora import GPT as LoRAGPT
 from litgpt.model import GPT as BaseGPT
 from litgpt.scripts.convert_hf_checkpoint import copy_weights_hf_llama
+from tests.conftest import RunIf
 
 
 def test_lora_layer_replacement():
4 changes: 1 addition & 3 deletions tests/test_model.py
@@ -2,12 +2,9 @@
 
 from copy import deepcopy
 from functools import partial
-from pathlib import Path
-from urllib.request import urlretrieve
 
 import pytest
 import torch
-from conftest import RunIf
 from lightning import Fabric
 from lightning.fabric.utilities.imports import _IS_WINDOWS
 from lightning.fabric.utilities.init import _materialize_meta_tensors
@@ -37,6 +34,7 @@
     copy_weights_hf_llama,
     copy_weights_phi,
 )
+from tests.conftest import RunIf
 
 
 @torch.inference_mode()
4 changes: 1 addition & 3 deletions tests/test_pretrain.py
@@ -3,21 +3,19 @@
 import os
 from contextlib import redirect_stdout
 from io import StringIO
-from pathlib import Path
 from unittest import mock
 from unittest.mock import ANY, Mock
 
 import pytest
 import torch
-from conftest import RunIf
 from lightning.fabric.strategies import FSDPStrategy, SingleDeviceStrategy
 from torch.utils.data import DataLoader
 
-from test_utils import test_init_out_dir
 from litgpt import pretrain
 from litgpt.args import EvalArgs, TrainArgs
 from litgpt.config import Config
 from litgpt.pretrain import initialize_weights
+from tests.conftest import RunIf
 
 
 @RunIf(min_cuda_gpus=2, standalone=True)
2 changes: 1 addition & 1 deletion tests/test_prompts.py
@@ -112,6 +112,6 @@ def test_save_load_prompt_style(tmp_path):
     save_prompt_style(CustomPromptStyle(), checkpoint_dir)
     with open(checkpoint_dir / "prompt_style.yaml", "r", encoding="utf-8") as file:
         contents = yaml.safe_load(file)
-    assert contents == {"class_path": "test_prompts.CustomPromptStyle"}
+    assert contents == {"class_path": "tests.test_prompts.CustomPromptStyle"}
     loaded = load_prompt_style(checkpoint_dir)
     assert isinstance(loaded, CustomPromptStyle)
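The expected class_path changes because CustomPromptStyle is now defined in a module imported as part of the tests package, so its __module__ is "tests.test_prompts" rather than "test_prompts". A hypothetical sketch (not litgpt's actual implementation) of how such a class path is typically derived:

def class_path_for(cls: type) -> str:
    # "tests.test_prompts.CustomPromptStyle" when tests/ is a package,
    # "test_prompts.CustomPromptStyle" when the module is imported top-level
    return f"{cls.__module__}.{cls.__qualname__}"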
2 changes: 1 addition & 1 deletion tests/test_thunder_ddp.py
@@ -3,7 +3,7 @@
 
 import pytest
 import torch
-from conftest import RunIf
+from tests.conftest import RunIf
 from lightning import Fabric
 
 # support running without installing as a package
4 changes: 1 addition & 3 deletions tests/test_thunder_fsdp.py
@@ -6,7 +6,7 @@
 
 import pytest
 import torch
-from conftest import RunIf
+from tests.conftest import RunIf
 from lightning.fabric import Fabric
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
 
@@ -263,8 +263,6 @@ def set_up_planner(self, state_dict, metadata, is_coordinator):
 
 @RunIf(min_cuda_gpus=2, thunder=True, standalone=True)
 def test_save_load_sharded_checkpoint(tmp_path):
-    pytest.skip("Temporarily disabled, often exceeds 5 min timeout")
-
     strategy = ThunderFSDPStrategy(state_dict_type="sharded", broadcast_from=0)
     fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
     fabric.launch()
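This is the change the commit title refers to: the unconditional pytest.skip(...) at the top of test_save_load_sharded_checkpoint is removed so the test body runs again, and the Azure timeouts at the top of this diff are raised (job 20 to 30 minutes, standalone step 5 to 10 minutes), presumably to give it room. For reference, a skip placed like that short-circuits the whole test; an illustrative sketch only:

import pytest

def test_example():
    pytest.skip("Temporarily disabled")  # raises a Skipped exception immediately
    assert False  # never reached while the skip is in place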
2 changes: 1 addition & 1 deletion tests/test_thunder_pretrain.py
@@ -6,7 +6,7 @@
 from unittest.mock import Mock
 
 import torch
-from conftest import RunIf
+from tests.conftest import RunIf
 from torch.utils.data import DataLoader
 
 from litgpt import Config
2 changes: 1 addition & 1 deletion tests/test_unsloth_executor.py
@@ -1,10 +1,10 @@
 import pytest
 import torch
-from conftest import RunIf
 
 from litgpt import GPT, Config
 from litgpt.model import apply_rope, build_rope_cache
 from litgpt.utils import chunked_cross_entropy
+from tests.conftest import RunIf
 
 
 @RunIf(min_cuda_gpus=1, thunder=True)
2 changes: 1 addition & 1 deletion tests/test_utils.py
@@ -11,7 +11,7 @@
 import torch
 import torch.nn.functional as F
 import yaml
-from conftest import RunIf
+from tests.conftest import RunIf
 from lightning import Fabric
 from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
 from lightning.fabric.plugins import BitsandbytesPrecision
