From 465396163fc25db02f2fadeb6f1139420e5ab969 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 8 Mar 2024 13:33:37 +0100 Subject: [PATCH] Enable loading train and val split in JSON data module (#1054) --- litgpt/data/json.py | 64 ++++++++++++++++++++++++++---------- tests/data/test_json.py | 54 +++++++++++++++++++++++++++--- tutorials/prepare_dataset.md | 2 ++ 3 files changed, 98 insertions(+), 22 deletions(-) diff --git a/litgpt/data/json.py b/litgpt/data/json.py index 3f2d362922..20d227563b 100644 --- a/litgpt/data/json.py +++ b/litgpt/data/json.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass, field from pathlib import Path -from typing import Optional, Union +from typing import Optional, Union, Tuple, Any import torch from torch.utils.data import random_split, DataLoader @@ -18,13 +18,14 @@ class JSON(LitDataModule): """Loads JSON data for supervised finetuning.""" json_path: Path - """A path to a JSON file containing the data. The file should contain a list of samples (dicts). - Each dict must have the keys 'instruction' and 'output', and can optionally have a key 'input' - (see Alpaca).""" + """A path to a JSON file or a directory with `train.json` and `val.json` containing the data. + The file(s) should contain a list of samples (dicts). Each dict must have the keys 'instruction' and 'output', + and can optionally have a key 'input' (see Alpaca).""" mask_prompt: bool = False """Whether to mask the prompt section from the label (with ``ignore_index``).""" - test_split_fraction: float = 0.1 - """The fraction of the dataset to use for the test/validation dataset. The rest is used for training.""" + test_split_fraction: Optional[float] = None + """The fraction of the dataset to use for the validation dataset. The rest is used for training. + Only applies if you passed in a single file to `json_path`.""" prompt_style: Union[str, PromptStyle] = "alpaca" """The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles.""" ignore_index: int = -1 @@ -41,8 +42,16 @@ class JSON(LitDataModule): test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False) def __post_init__(self): - if not self.json_path.is_file(): - raise FileNotFoundError(f"The file {self.json_path} does not exist.") + if self.json_path.is_dir() and self.test_split_fraction is not None: + raise ValueError( + "If `json_path` is a directory, it must contain 'train.json' and 'val.json' files and" + f" hence `test_split_fraction` should not be set. Got `{self.test_split_fraction=}`." + ) + if not self.json_path.exists(): + raise FileNotFoundError( + "The `json_path` must be a file or a directory containing 'train.json' and 'val.json' files," + f" but '{self.json_path!s}' does not exist." + ) if isinstance(self.prompt_style, str): self.prompt_style = PromptStyle.from_name(self.prompt_style) @@ -57,16 +66,7 @@ def connect( self.max_seq_length = -1 if max_seq_length is None else max_seq_length def setup(self, stage: str = "") -> None: - with open(self.json_path, "r", encoding="utf-8") as file: - data = json.load(file) - - # Partition the dataset into train and test - train_data, test_data = random_split( - data, - [1.0 - self.test_split_fraction, self.test_split_fraction], - generator=torch.Generator().manual_seed(self.seed) - ) - train_data, test_data = list(train_data), list(test_data) + train_data, test_data = self.get_splits() self.train_dataset = SFTDataset( data=train_data, @@ -103,3 +103,31 @@ def val_dataloader(self) -> DataLoader: num_workers=self.num_workers, collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index) ) + + def get_splits(self) -> Tuple: + # A single file (gets split into train and test) + if self.json_path.is_file(): + data = load_split(self.json_path) + + # Partition the dataset into train and test + train_data, test_data = random_split( + data, + [1.0 - self.test_split_fraction, self.test_split_fraction], + generator=torch.Generator().manual_seed(self.seed) + ) + return train_data, test_data + + # A directory containing train.json and val.json + if (self.json_path / "train.json").is_file() and (self.json_path / f"val.json").is_file(): + train_data = load_split(self.json_path / "train.json") + test_data = load_split(self.json_path / f"val.json") + return train_data, test_data + + raise FileNotFoundError( + "The `json_path` must be a file or a directory containing 'train.json' and 'val.json' files." + ) + + +def load_split(json_path: Path) -> Any: + with open(json_path, "r", encoding="utf-8") as file: + return json.load(file) diff --git a/tests/data/test_json.py b/tests/data/test_json.py index 3e4a88c45c..bf1356a3c4 100644 --- a/tests/data/test_json.py +++ b/tests/data/test_json.py @@ -24,10 +24,6 @@ def apply(self, prompt, **kwargs): with open(json_path, "w", encoding="utf-8") as fp: json.dump(mock_data, fp) - with pytest.raises(FileNotFoundError): - JSON(tmp_path / "not exist") - - # TODO: Make prompt template an argumenet data = JSON(json_path, test_split_fraction=0.5, prompt_style=Style(), num_workers=0) data.connect(tokenizer=mock_tokenizer, batch_size=2) data.prepare_data() # does nothing @@ -58,3 +54,53 @@ def apply(self, prompt, **kwargs): assert isinstance(train_dataloader.dataset.prompt_style, Style) assert isinstance(val_dataloader.dataset.prompt_style, Style) + +def test_json_input_validation(tmp_path): + from litgpt.data import JSON + + with pytest.raises(FileNotFoundError, match="The `json_path` must be a file or a directory"): + JSON(tmp_path / "not exist") + + with pytest.raises(ValueError, match="`test_split_fraction` should not be set"): + JSON(tmp_path, test_split_fraction=0.5) + + data = JSON(tmp_path) + data.prepare_data() # does nothing + + # Empty directory + with pytest.raises(FileNotFoundError, match="must be a file or a directory containing"): + data.setup() + + # Only train.json exists + (tmp_path / "train.json").touch() + with pytest.raises(FileNotFoundError, match="must be a file or a directory containing"): + data.setup() + + +def test_json_with_splits(tmp_path, mock_tokenizer): + from litgpt.data import JSON + + mock_train_data = [ + {"instruction": "Add", "input": "2+2", "output": "4"}, + {"instruction": "Subtract", "input": "5-3", "output": "2"}, + {"instruction": "Exponentiate", "input": "2^3", "output": "8"}, + ] + mock_test_data = [ + {"instruction": "Multiply", "input": "6*4", "output": "24"}, + {"instruction": "Divide", "input": "10/2", "output": "5"}, + ] + with open(tmp_path / "train.json", "w", encoding="utf-8") as fp: + json.dump(mock_train_data, fp) + with open(tmp_path / "val.json", "w", encoding="utf-8") as fp: + json.dump(mock_test_data, fp) + + data = JSON(tmp_path, num_workers=0) + data.connect(tokenizer=mock_tokenizer, batch_size=2) + data.prepare_data() # does nothing + data.setup() + + train_dataloader = data.train_dataloader() + val_dataloader = data.val_dataloader() + + assert len(train_dataloader) == 2 + assert len(val_dataloader) == 1 diff --git a/tutorials/prepare_dataset.md b/tutorials/prepare_dataset.md index 81bede75bc..d35b258372 100644 --- a/tutorials/prepare_dataset.md +++ b/tutorials/prepare_dataset.md @@ -366,6 +366,8 @@ python litgpt/finetune/lora.py \ --checkpoint_dir "checkpoints/tiiuae/falcon-7b" ``` +You can also pass a directory containing a `train.json` and `val.json` to `--data.json_path` to define a fixed train/val split. +   ### Preparing Custom Datasets Using LitDataModule