Skip to content

Commit

Permalink
Enable loading train and val split in JSON data module (#1054)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Mar 8, 2024
1 parent 3aa7beb commit 42c1451
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 22 deletions.
64 changes: 46 additions & 18 deletions litgpt/data/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Union, Tuple, Any

import torch
from torch.utils.data import random_split, DataLoader
Expand All @@ -18,13 +18,14 @@ class JSON(LitDataModule):
"""Loads JSON data for supervised finetuning."""

json_path: Path
"""A path to a JSON file containing the data. The file should contain a list of samples (dicts).
Each dict must have the keys 'instruction' and 'output', and can optionally have a key 'input'
(see Alpaca)."""
"""A path to a JSON file or a directory with `train.json` and `val.json` containing the data.
The file(s) should contain a list of samples (dicts). Each dict must have the keys 'instruction' and 'output',
and can optionally have a key 'input' (see Alpaca)."""
mask_prompt: bool = False
"""Whether to mask the prompt section from the label (with ``ignore_index``)."""
test_split_fraction: float = 0.1
"""The fraction of the dataset to use for the test/validation dataset. The rest is used for training."""
test_split_fraction: Optional[float] = None
"""The fraction of the dataset to use for the validation dataset. The rest is used for training.
Only applies if you passed in a single file to `json_path`."""
prompt_style: Union[str, PromptStyle] = "alpaca"
"""The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles."""
ignore_index: int = -1
Expand All @@ -41,8 +42,16 @@ class JSON(LitDataModule):
test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)

def __post_init__(self):
if not self.json_path.is_file():
raise FileNotFoundError(f"The file {self.json_path} does not exist.")
if self.json_path.is_dir() and self.test_split_fraction is not None:
raise ValueError(
"If `json_path` is a directory, it must contain 'train.json' and 'val.json' files and"
f" hence `test_split_fraction` should not be set. Got `{self.test_split_fraction=}`."
)
if not self.json_path.exists():
raise FileNotFoundError(
"The `json_path` must be a file or a directory containing 'train.json' and 'val.json' files,"
f" but '{self.json_path!s}' does not exist."
)
if isinstance(self.prompt_style, str):
self.prompt_style = PromptStyle.from_name(self.prompt_style)

Expand All @@ -57,16 +66,7 @@ def connect(
self.max_seq_length = -1 if max_seq_length is None else max_seq_length

def setup(self, stage: str = "") -> None:
with open(self.json_path, "r", encoding="utf-8") as file:
data = json.load(file)

# Partition the dataset into train and test
train_data, test_data = random_split(
data,
[1.0 - self.test_split_fraction, self.test_split_fraction],
generator=torch.Generator().manual_seed(self.seed)
)
train_data, test_data = list(train_data), list(test_data)
train_data, test_data = self.get_splits()

self.train_dataset = SFTDataset(
data=train_data,
Expand Down Expand Up @@ -103,3 +103,31 @@ def val_dataloader(self) -> DataLoader:
num_workers=self.num_workers,
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index)
)

def get_splits(self) -> Tuple:
# A single file (gets split into train and test)
if self.json_path.is_file():
data = load_split(self.json_path)

# Partition the dataset into train and test
train_data, test_data = random_split(
data,
[1.0 - self.test_split_fraction, self.test_split_fraction],
generator=torch.Generator().manual_seed(self.seed)
)
return train_data, test_data

# A directory containing train.json and val.json
if (self.json_path / "train.json").is_file() and (self.json_path / f"val.json").is_file():
train_data = load_split(self.json_path / "train.json")
test_data = load_split(self.json_path / f"val.json")
return train_data, test_data

raise FileNotFoundError(
"The `json_path` must be a file or a directory containing 'train.json' and 'val.json' files."
)


def load_split(json_path: Path) -> Any:
with open(json_path, "r", encoding="utf-8") as file:
return json.load(file)
54 changes: 50 additions & 4 deletions tests/data/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ def apply(self, prompt, **kwargs):
with open(json_path, "w", encoding="utf-8") as fp:
json.dump(mock_data, fp)

with pytest.raises(FileNotFoundError):
JSON(tmp_path / "not exist")

# TODO: Make prompt template an argumenet
data = JSON(json_path, test_split_fraction=0.5, prompt_style=Style(), num_workers=0)
data.connect(tokenizer=mock_tokenizer, batch_size=2)
data.prepare_data() # does nothing
Expand Down Expand Up @@ -58,3 +54,53 @@ def apply(self, prompt, **kwargs):
assert isinstance(train_dataloader.dataset.prompt_style, Style)
assert isinstance(val_dataloader.dataset.prompt_style, Style)


def test_json_input_validation(tmp_path):
from litgpt.data import JSON

with pytest.raises(FileNotFoundError, match="The `json_path` must be a file or a directory"):
JSON(tmp_path / "not exist")

with pytest.raises(ValueError, match="`test_split_fraction` should not be set"):
JSON(tmp_path, test_split_fraction=0.5)

data = JSON(tmp_path)
data.prepare_data() # does nothing

# Empty directory
with pytest.raises(FileNotFoundError, match="must be a file or a directory containing"):
data.setup()

# Only train.json exists
(tmp_path / "train.json").touch()
with pytest.raises(FileNotFoundError, match="must be a file or a directory containing"):
data.setup()


def test_json_with_splits(tmp_path, mock_tokenizer):
from litgpt.data import JSON

mock_train_data = [
{"instruction": "Add", "input": "2+2", "output": "4"},
{"instruction": "Subtract", "input": "5-3", "output": "2"},
{"instruction": "Exponentiate", "input": "2^3", "output": "8"},
]
mock_test_data = [
{"instruction": "Multiply", "input": "6*4", "output": "24"},
{"instruction": "Divide", "input": "10/2", "output": "5"},
]
with open(tmp_path / "train.json", "w", encoding="utf-8") as fp:
json.dump(mock_train_data, fp)
with open(tmp_path / "val.json", "w", encoding="utf-8") as fp:
json.dump(mock_test_data, fp)

data = JSON(tmp_path, num_workers=0)
data.connect(tokenizer=mock_tokenizer, batch_size=2)
data.prepare_data() # does nothing
data.setup()

train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()

assert len(train_dataloader) == 2
assert len(val_dataloader) == 1
2 changes: 2 additions & 0 deletions tutorials/prepare_dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,8 @@ python litgpt/finetune/lora.py \
--checkpoint_dir "checkpoints/tiiuae/falcon-7b"
```

You can also pass a directory containing a `train.json` and `val.json` to `--data.json_path` to define a fixed train/val split.

 

### Preparing Custom Datasets Using LitDataModule
Expand Down

0 comments on commit 42c1451

Please sign in to comment.