
Commit

add keep column feature
speed1313 committed Nov 14, 2024
1 parent a36db8c commit 7f57408
Showing 2 changed files with 42 additions and 18 deletions.
11 changes: 10 additions & 1 deletion src/text2dataset/main.py
@@ -82,6 +82,12 @@
     default="ja",
     help="Target language for translation. This is used only for DeepL API.",
 )
+@click.option(
+    "--keep_columns",
+    type=str,
+    default="txt",
+    help="Columns to keep in the output dataset. Specify the column names separated by comma.",
+)
 def main(
     model_id: str,
     batch_size: int,
@@ -105,6 +111,7 @@ def main(
     top_p: float,
     max_tokens: int,
     target_lang: str,
+    keep_columns: str,
 ):
     # Text in source_column of the Dataset will be translated into Japanese.
     state = State(0, 0, 0)
@@ -127,8 +134,10 @@ def main(
     os.makedirs(output_dir, exist_ok=True)
     state_path = os.path.join(output_dir, "state.jsonl")
     ds = create_dataset(input_path, state)
+    # keep only the specified columns
+    ds = ds.select_columns(keep_columns.split(","))
     # batch dataloader
-    data_loader = ds["train"].batch(batch_size=batch_size)
+    data_loader = ds.batch(batch_size=batch_size)
 
     if use_wandb:
         config_parameters = dict(locals())
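
For context, a minimal sketch of what the new --keep_columns step does to the streaming dataset before batching. The select_columns/batch calls, the comma-separated format, and the data/english_quotes.json path come from this commit; the column names and batch size below are illustrative assumptions, not values from the repository.

from datasets import load_dataset

# Illustrative stand-ins for the CLI arguments (assumed values; only the
# comma-separated format of keep_columns is taken from the new help text).
keep_columns = "quote,author"
batch_size = 2

ds = load_dataset("json", data_files="data/english_quotes.json", streaming=True, split="train")
ds = ds.select_columns(keep_columns.split(","))  # keep only the requested columns
for batch in ds.batch(batch_size=batch_size):
    print(batch)  # dict of column -> list, containing only the kept columns
    break
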
49 changes: 32 additions & 17 deletions src/text2dataset/reader.py
@@ -2,33 +2,47 @@
 from text2dataset.utils import State
 
 
-def create_dataset(input_path: str, state: State) -> IterableDataset:
+def create_dataset(
+    input_path: str, state: State, split: str = "train"
+) -> IterableDataset:
     """
     Create a Iterabledataset from the input path.
     The input path can be a local file path or a HuggingFace Datasets dataset name.
     Use the state to skip already processed examples.
     """
     match input_path.split(".")[-1]:
         case "csv":
-            ds = load_dataset("csv", data_files=input_path, streaming=True)
+            ds = load_dataset("csv", data_files=input_path, streaming=True, split=split)
         case "json":
-            ds = load_dataset("json", data_files=input_path, streaming=True)
+            ds = load_dataset(
+                "json", data_files=input_path, streaming=True, split=split
+            )
         case "jsonl":
-            ds = load_dataset("json", data_files=input_path, streaming=True)
+            ds = load_dataset(
+                "json", data_files=input_path, streaming=True, split=split
+            )
         case "parquet":
-            ds = load_dataset("parquet", data_files=input_path, streaming=True)
+            ds = load_dataset(
+                "parquet", data_files=input_path, streaming=True, split=split
+            )
         case "tar":
-            ds = load_dataset("webdataset", data_files=input_path, streaming=True)
+            ds = load_dataset(
+                "webdataset", data_files=input_path, streaming=True, split=split
+            )
         case "arrow":
-            ds = load_dataset("arrow", data_files=input_path, streaming=True)
+            ds = load_dataset(
+                "arrow", data_files=input_path, streaming=True, split=split
+            )
         case "txt":
-            ds = load_dataset("text", data_files=input_path, streaming=True)
+            ds = load_dataset(
+                "text", data_files=input_path, streaming=True, split=split
+            )
         case _:
-            ds = load_dataset(input_path, streaming=True)
+            ds = load_dataset(input_path, streaming=True, split=split)
 
     # skip already processed examples
     if state.last_saved_example_num > 0:
-        ds["train"] = ds["train"].skip(state.last_saved_example_num)
+        ds = ds.skip(state.last_saved_example_num)
 
     return ds
 
@@ -39,7 +53,7 @@ def test_create_dataset():
         current_shard_id=0, last_saved_example_num=0, total_processed_examples=0
     )
     ds = create_dataset(input_path, state)
-    assert ds["train"] is not None
+    assert ds is not None
     # error happing case
     try:
         create_dataset("hogehuga", state)
@@ -55,10 +69,11 @@ def test_create_dataset():
         current_shard_id=0, last_saved_example_num=0, total_processed_examples=0
     )
     ds = create_dataset(input_path, state)
-    print(ds["train"])
-    # iterabledatasetdict to datasetdict
-    input_path = "data/english_quotes.json"
+    print(ds)
+    print(next(iter(ds)))
     state = State(
         current_shard_id=0, last_saved_example_num=10, total_processed_examples=10
     )
     ds = create_dataset(input_path, state)
-    print(ds["train"])
-    print(next(iter(ds["train"])))
+    ds.map(lambda x: x["txt"], batched=True)
+    print(ds)
+    print(next(iter(ds)))
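
For a quick sense of the new call pattern, here is a minimal usage sketch. The create_dataset signature, the State field names, and the data/english_quotes.json path come from the diff above; the import paths are assumptions based on the src/text2dataset layout.

from text2dataset.reader import create_dataset
from text2dataset.utils import State

# Resume from a checkpoint: 10 examples were already processed, so
# create_dataset() skips them before returning the streaming dataset.
state = State(
    current_shard_id=0, last_saved_example_num=10, total_processed_examples=10
)
ds = create_dataset("data/english_quotes.json", state, split="train")
print(next(iter(ds)))  # first not-yet-processed example (row 11 of the file)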
