Enable loading train and val split in JSON data module #1054

Merged · 5 commits · Mar 8, 2024
64 changes: 46 additions & 18 deletions litgpt/data/json.py
@@ -3,7 +3,7 @@
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Union, Tuple, Any

import torch
from torch.utils.data import random_split, DataLoader
@@ -18,13 +18,14 @@ class JSON(LitDataModule):
"""Loads JSON data for supervised finetuning."""

json_path: Path
"""A path to a JSON file containing the data. The file should contain a list of samples (dicts).
Each dict must have the keys 'instruction' and 'output', and can optionally have a key 'input'
(see Alpaca)."""
"""A path to a JSON file or a directory with `train.json` and `val.json` containing the data.
The file(s) should contain a list of samples (dicts). Each dict must have the keys 'instruction' and 'output',
and can optionally have a key 'input' (see Alpaca)."""
mask_prompt: bool = False
"""Whether to mask the prompt section from the label (with ``ignore_index``)."""
test_split_fraction: float = 0.1
"""The fraction of the dataset to use for the test/validation dataset. The rest is used for training."""
test_split_fraction: Optional[float] = None
"""The fraction of the dataset to use for the validation dataset. The rest is used for training.
Only applies if you passed in a single file to `json_path`."""
prompt_style: Union[str, PromptStyle] = "alpaca"
"""The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles."""
ignore_index: int = -1
@@ -41,8 +42,16 @@ class JSON(LitDataModule):
test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)

def __post_init__(self):
if not self.json_path.is_file():
raise FileNotFoundError(f"The file {self.json_path} does not exist.")
if self.json_path.is_dir() and self.test_split_fraction is not None:
raise ValueError(
"If `json_path` is a directory, it must contain 'train.json' and 'val.json' files and"
f" hence `test_split_fraction` should not be set. Got `{self.test_split_fraction=}`."
)
if not self.json_path.exists():
raise FileNotFoundError(
"The `json_path` must be a file or a directory containing 'train.json' and 'val.json' files,"
f" but '{self.json_path!s}' does not exist."
)
if isinstance(self.prompt_style, str):
self.prompt_style = PromptStyle.from_name(self.prompt_style)

@@ -57,16 +66,7 @@ def connect(
self.max_seq_length = -1 if max_seq_length is None else max_seq_length

def setup(self, stage: str = "") -> None:
with open(self.json_path, "r", encoding="utf-8") as file:
data = json.load(file)

# Partition the dataset into train and test
train_data, test_data = random_split(
data,
[1.0 - self.test_split_fraction, self.test_split_fraction],
generator=torch.Generator().manual_seed(self.seed)
)
train_data, test_data = list(train_data), list(test_data)
train_data, test_data = self.get_splits()

self.train_dataset = SFTDataset(
data=train_data,
@@ -103,3 +103,31 @@ def val_dataloader(self) -> DataLoader:
num_workers=self.num_workers,
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index)
)

def get_splits(self) -> Tuple:
# A single file (gets split into train and test)
if self.json_path.is_file():
data = load_split(self.json_path)

# Partition the dataset into train and test
train_data, test_data = random_split(
data,
[1.0 - self.test_split_fraction, self.test_split_fraction],
generator=torch.Generator().manual_seed(self.seed)
)
return train_data, test_data

# A directory containing train.json and val.json
if (self.json_path / "train.json").is_file() and (self.json_path / f"val.json").is_file():
train_data = load_split(self.json_path / "train.json")
test_data = load_split(self.json_path / f"val.json")
return train_data, test_data

raise FileNotFoundError(
"The `json_path` must be a file or a directory containing 'train.json' and 'val.json' files."
)


def load_split(json_path: Path) -> Any:
with open(json_path, "r", encoding="utf-8") as file:
return json.load(file)
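
For reference, here is a minimal usage sketch of the two modes enabled by this change. The `my_sft_data` directory name and the sample records are made up for illustration, and `connect()`/`setup()` still need a real tokenizer before the dataloaders can be built:

```python
import json
from pathlib import Path

from litgpt.data import JSON

# Hypothetical directory holding a fixed train/val split
data_dir = Path("my_sft_data")
data_dir.mkdir(exist_ok=True)
(data_dir / "train.json").write_text(json.dumps([
    {"instruction": "Add", "input": "2+2", "output": "4"},
    {"instruction": "Subtract", "input": "5-3", "output": "2"},
]))
(data_dir / "val.json").write_text(json.dumps([
    {"instruction": "Multiply", "input": "6*4", "output": "24"},
]))

# Directory mode: train.json and val.json are used as-is,
# so `test_split_fraction` must be left unset.
fixed_split = JSON(data_dir, num_workers=0)

# Single-file mode: one JSON file is split randomly by `test_split_fraction`.
single_file = JSON(data_dir / "train.json", test_split_fraction=0.5, num_workers=0)

# After connecting a tokenizer, both modes expose the same dataloaders:
# data.connect(tokenizer=tokenizer, batch_size=2)
# data.setup()
# train_loader = data.train_dataloader()
# val_loader = data.val_dataloader()
```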
54 changes: 50 additions & 4 deletions tests/data/test_json.py
@@ -24,10 +24,6 @@ def apply(self, prompt, **kwargs):
with open(json_path, "w", encoding="utf-8") as fp:
json.dump(mock_data, fp)

with pytest.raises(FileNotFoundError):
JSON(tmp_path / "not exist")

# TODO: Make prompt template an argument
data = JSON(json_path, test_split_fraction=0.5, prompt_style=Style(), num_workers=0)
data.connect(tokenizer=mock_tokenizer, batch_size=2)
data.prepare_data() # does nothing
@@ -58,3 +54,53 @@ def apply(self, prompt, **kwargs):
assert isinstance(train_dataloader.dataset.prompt_style, Style)
assert isinstance(val_dataloader.dataset.prompt_style, Style)


def test_json_input_validation(tmp_path):
from litgpt.data import JSON

with pytest.raises(FileNotFoundError, match="The `json_path` must be a file or a directory"):
JSON(tmp_path / "not exist")

with pytest.raises(ValueError, match="`test_split_fraction` should not be set"):
JSON(tmp_path, test_split_fraction=0.5)

data = JSON(tmp_path)
data.prepare_data() # does nothing

# Empty directory
with pytest.raises(FileNotFoundError, match="must be a file or a directory containing"):
data.setup()

# Only train.json exists
(tmp_path / "train.json").touch()
with pytest.raises(FileNotFoundError, match="must be a file or a directory containing"):
data.setup()


def test_json_with_splits(tmp_path, mock_tokenizer):
from litgpt.data import JSON

mock_train_data = [
{"instruction": "Add", "input": "2+2", "output": "4"},
{"instruction": "Subtract", "input": "5-3", "output": "2"},
{"instruction": "Exponentiate", "input": "2^3", "output": "8"},
]
mock_test_data = [
{"instruction": "Multiply", "input": "6*4", "output": "24"},
{"instruction": "Divide", "input": "10/2", "output": "5"},
]
with open(tmp_path / "train.json", "w", encoding="utf-8") as fp:
json.dump(mock_train_data, fp)
with open(tmp_path / "val.json", "w", encoding="utf-8") as fp:
json.dump(mock_test_data, fp)

data = JSON(tmp_path, num_workers=0)
data.connect(tokenizer=mock_tokenizer, batch_size=2)
data.prepare_data() # does nothing
data.setup()

train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()

assert len(train_dataloader) == 2
assert len(val_dataloader) == 1
2 changes: 2 additions & 0 deletions tutorials/prepare_dataset.md
@@ -366,6 +366,8 @@ python litgpt/finetune/lora.py \
--checkpoint_dir "checkpoints/tiiuae/falcon-7b"
```

You can also pass a directory containing a `train.json` and `val.json` to `--data.json_path` to define a fixed train/val split.
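
For example, assuming a directory named `my_sft_data/` that contains `train.json` and `val.json` (other flags as in the full command shown above):

```bash
python litgpt/finetune/lora.py \
  --data.json_path my_sft_data \
  --checkpoint_dir "checkpoints/tiiuae/falcon-7b"
```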

 

### Preparing Custom Datasets Using LitDataModule