Skip to content

Commit

Permalink
dataloader factory added
Browse files Browse the repository at this point in the history
  • Loading branch information
msalhab96 committed Jun 22, 2022
1 parent a0c035b commit 57d9dbb
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from typing import Union
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader
from interfaces import IDataLoader, IPipeline, IPadder


Expand Down Expand Up @@ -80,3 +80,27 @@ def __getitem__(self, idx: int):
mask = torch.BoolTensor(mask)
spk_id = torch.LongTensor([spk_id])
return speech, speech_length, mask, text, spk_id


def get_batch_loader(
data_loader: IDataLoader,
aud_pipeline: IPipeline,
text_pipeline: IPipeline,
aud_padder: IPadder,
text_padder: IPadder,
batch_size: int,
sep: str
):
return DataLoader(
Data(
data_loader=data_loader,
aud_pipeline=aud_pipeline,
text_pipeline=text_pipeline,
aud_padder=aud_padder,
text_padder=text_padder,
batch_size=batch_size,
sep=sep
),
batch_size=batch_size,
shuffle=False
)

0 comments on commit 57d9dbb

Please sign in to comment.