diff --git a/pyproject.toml b/pyproject.toml
index 133be93..a5d38d0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "SciAssist"
-version = "0.0.37"
+version = "0.1.4"
 authors = [
   { name="WING-NUS", email="dingyixi@hotmail.com" },
 ]
@@ -23,23 +23,23 @@ classifiers = [
 dependencies = [
     "beautifulsoup4~=4.9.0",
    "chardet~=3.0.4",
-    "datasets~=2.2.2",
+    "datasets~=2.15.0",
    "hydra-core>=1.1.0",
    "lxml",
    "matplotlib~=3.5.1",
    "nltk~=3.7",
-    "numpy~=1.19.2",
+    "numpy",
    "omegaconf~=2.2.2",
    "PyPDF2~=2.10.7",
    "python_magic~=0.4.18",
-    "pytorch_lightning~=1.7.1",
-    "requests~=2.21.0",
+    "pytorch_lightning~=2.0.4",
+    "requests~=2.22.0",
    "rich~=12.4.4",
    "seaborn~=0.11.2",
    "setuptools>=61.0",
    "torch>=1.12.0",
-    "torchmetrics>=0.7.0",
-    "transformers~=4.19.2",
+    "torchmetrics==0.11.4",
+    "transformers~=4.30.2",
    "wandb~=0.12.19",
    "pdfminer.six",
    "pandas~=1.4.3",
@@ -47,6 +47,7 @@ dependencies = [
     "torchcrf",
     "sacremoses",
     "seqeval",
+    "pytest~=7.4.3"
 ]
diff --git a/requirements.txt b/requirements.txt
index 917155a..12f2726 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 # --------- pytorch --------- #
 torch>=1.10.0
 torchvision>=0.11.0
-pytorch-lightning~=1.7.1
+pytorch-lightning~=2.0.4
 torchmetrics>=0.7.0
 
 # --------- hydra --------- #
@@ -28,15 +28,15 @@ pyrootutils
 python-dotenv~=0.20.0
 protobuf~=3.19.0
 rich~=12.4.4
-pytest~=7.1.2
+pytest~=7.4.3
 # sh~=1.14.2
 pudb  # debugger
 
 seaborn~=0.11.2
 omegaconf~=2.2.2
-transformers~=4.19.2
+transformers~=4.30.2
 packaging~=21.3
-datasets~=2.2.2
+datasets~=2.15.0
 beautifulsoup4~=4.9.0
 matplotlib~=3.5.1
 
@@ -46,7 +46,7 @@ pdfminer.six # windows pdf processing
 
 # --------- doc2json --------- #
 boto3~=1.9.147
-requests~=2.21.0
+requests~=2.22.0
 Flask~=1.0.2
 tqdm
 lxml
diff --git a/src/SciAssist/models/components/flant5_summarization.py b/src/SciAssist/models/components/flant5_summarization.py
index 7665645..daca25f 100644
--- a/src/SciAssist/models/components/flant5_summarization.py
+++ b/src/SciAssist/models/components/flant5_summarization.py
@@ -28,15 +28,15 @@ def forward(self, input_ids=None, attention_mask=None, labels=None):
             logits=outputs.logits
         )
 
-    def generate(self, input_ids=None, attention_mask=None, num_beams=5, num_return_sequences=1):
+    def generate(self, input_ids=None, attention_mask=None, num_beams=1, num_return_sequences=1, top_k=0, max_length=500, do_sample=False):
         diversity_penalty = 0.0
         if num_return_sequences>1:
             diversity_penalty = 1.0
         return self.flant5.generate(input_ids=input_ids, attention_mask=attention_mask,
                                     num_beams=num_beams,
                                     num_return_sequences=num_return_sequences,
-                                    num_beam_groups=num_return_sequences,
-                                    diversity_penalty=diversity_penalty,
-                                    max_length=300,
-                                    do_sample=False,
-                                    no_repeat_ngram_size=5 )
+                                    diversity_penalty=diversity_penalty,
+                                    top_k=top_k,
+                                    max_length=max_length,
+                                    do_sample=do_sample)
+
diff --git a/src/SciAssist/models/cora_module.py b/src/SciAssist/models/cora_module.py
index e292a45..cc05f1a 100644
--- a/src/SciAssist/models/cora_module.py
+++ b/src/SciAssist/models/cora_module.py
@@ -11,6 +11,7 @@
 from torchmetrics import MaxMetric
 from torchmetrics.classification.accuracy import Accuracy
 
+
 from SciAssist.datamodules.components.cora_label import num_labels, LABEL_NAMES
 from SciAssist.models.components.bert_token_classifier import BertForTokenClassifier
 from SciAssist.utils.data_utils import DataUtilsForTokenClassification
@@ -67,7 +68,7 @@ def training_step(self, batch: Any, batch_idx: int):
         self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
         return {"loss": loss}
 
-    def training_epoch_end(self, outputs: List[Any]):
+    def on_train_epoch_end(self):
         pass
 
     def validation_step(self, batch: Any, batch_idx: int):
@@ -91,7 +92,7 @@ def validation_step(self, batch: Any, batch_idx: int):
         self.log("val/macro_f1", macro_f1, on_step=False, on_epoch=True, prog_bar=True)
         return {"loss": loss, "preds": true_preds, "labels": true_labels}
 
-    def validation_epoch_end(self, outputs: List[Any]):
+    def on_validation_epoch_end(self):
         acc = self.val_acc.compute()
         self.val_acc_best.update(acc)
         self.log("val/acc_best", self.val_acc_best.compute(), on_epoch=True, prog_bar=True)
@@ -129,7 +130,7 @@ def test_step(self, batch: Any, batch_idx: int):
 
         return {"loss": loss, "preds": true_preds, "labels": true_labels}
 
-    def test_epoch_end(self, outputs: List[Any]):
+    def on_test_epoch_end(self):
         # wandb.init()
         acc = self.test_acc.compute()
         micro_f1 = self.test_micro_f1.compute()
diff --git a/src/SciAssist/models/mup_bart_module.py b/src/SciAssist/models/mup_bart_module.py
index 7f4eac4..096d919 100644
--- a/src/SciAssist/models/mup_bart_module.py
+++ b/src/SciAssist/models/mup_bart_module.py
@@ -73,7 +73,7 @@ def training_step(self, batch: Any, batch_idx: int):
         self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False)
         return {"loss": loss}
 
-    def training_epoch_end(self, outputs: List[Any]):
+    def on_train_epoch_end(self):
         pass
 
     def validation_step(self, batch: Any, batch_idx: int):
@@ -121,7 +121,7 @@ def validation_step(self, batch: Any, batch_idx: int):
 
         return result
 
-    def validation_epoch_end(self, outputs: List[Any]):
+    def on_validation_epoch_end(self):
         rouge = self.val_metric.compute()
         # bert = self.val_bertscore.compute()
         # self.val_best_Rouge1.update(rouge["rouge1_fmeasure"])
@@ -219,22 +219,14 @@ def test_step(self, batch: Any, batch_idx: int):
 
         return result
 
-    def test_epoch_end(self, outputs: List[Any]):
+    def on_test_epoch_end(self):
         # Save prediction results
         # with open(os.path.join(self.model.model_dir,"prediction.txt"),'w') as f:
         #     for batch in outputs:
         #         for res in batch["preds"]:
         #             f.write(res)
         #             f.write("\n")
-
-        for batch in outputs:
-            for id,res in zip(batch['id'],batch["preds"]):
-                with open("/home/dingyx/project/SciAssist/data/pdfs/summary_flant5/" + str(id.item()) +".txt","a") as f:
-                    # print("/home/dingyx/project/SciAssist/data/MUP_CTRLkeyword/" + str(id.item()) +".txt")
-                    f.write(res)
-                    f.write("\n")
-                    # f.write(str(len(res.split(" "))))
-
+
         P,R,F1 = bert_score.score(self.test_preds, self.test_labels, rescale_with_baseline=True, lang="en")
 
         # Compute average length of summaries
@@ -260,7 +252,6 @@
         self.log("test/gen_len", self.test_gen_len, on_step=False, on_epoch=True, prog_bar=True)
 
-    def on_epoch_end(self):
         self.val_metric.reset()
         # self.val_bertscore.reset()
diff --git a/src/SciAssist/pipelines/__init__.py b/src/SciAssist/pipelines/__init__.py
index 9dc586b..99789be 100644
--- a/src/SciAssist/pipelines/__init__.py
+++ b/src/SciAssist/pipelines/__init__.py
@@ -1,8 +1,7 @@
 # main developer: Yixi Ding
 
-from typing import Dict
-
 import torch
+from typing import Dict
 
 from SciAssist import BASE_CACHE_DIR
 from SciAssist.models.components.bert_dataset_extraction import BertForDatasetExtraction
@@ -47,7 +46,12 @@
         "model": FlanT5ForSummarization,
         "model_dict_url": "https://huggingface.co/spaces/dyxohjl666/Controlled-summarization/resolve/main/flant5-base-mup-scisumm-repeat5-kws.pt",
         "data_utils": DataUtilsForFlanT5,
-    }
+    },
"flan-t5-xl": { + "model": FlanT5ForSummarization, + "model_dict_url": None, + "data_utils": DataUtilsForFlanT5, + }, }, "dataset-extraction": { "default": { @@ -59,7 +63,7 @@ } -def load_model(config: Dict, cache_dir=BASE_CACHE_DIR, device="gpu"): +def load_model(config: Dict, checkpoint=None, cache_dir=BASE_CACHE_DIR, device="gpu"): ''' Args: @@ -77,7 +81,12 @@ def load_model(config: Dict, cache_dir=BASE_CACHE_DIR, device="gpu"): print("Loading the model...") model_class = config["model"] - model = model_class(cache_dir=cache_dir) + + if checkpoint!=None: + model = model_class(cache_dir=cache_dir,model_checkpoint=checkpoint) + else: + model = model_class(cache_dir=cache_dir) + map_location=None if device == "cpu": map_location = torch.device("cpu") diff --git a/src/SciAssist/pipelines/pipeline.py b/src/SciAssist/pipelines/pipeline.py index 9c5bef9..766dead 100644 --- a/src/SciAssist/pipelines/pipeline.py +++ b/src/SciAssist/pipelines/pipeline.py @@ -27,7 +27,7 @@ class Pipeline(): """ - def __init__(self, task_name: str, model_name: str = "default", device="gpu", + def __init__(self, task_name: str, model_name: str = "default", checkpoint: str = None, device="gpu", cache_dir=None, output_dir=None, temp_dir=None): self.device = device @@ -37,7 +37,7 @@ def __init__(self, task_name: str, model_name: str = "default", device="gpu", self.config = TASKS[task_name][model_name] self.model_name = model_name - self.model = load_model(config=self.config, cache_dir=self.cache_dir, device=self.device) + self.model = load_model(config=self.config, checkpoint=checkpoint,cache_dir=self.cache_dir, device=self.device) if device in ["cuda", "gpu"] and torch.cuda.is_available(): self.device = torch.device("cuda") self.model.cuda() diff --git a/src/SciAssist/pipelines/summarization.py b/src/SciAssist/pipelines/summarization.py index e4c1317..69794c5 100644 --- a/src/SciAssist/pipelines/summarization.py +++ b/src/SciAssist/pipelines/summarization.py @@ -4,13 +4,12 @@ import os from typing import List, Tuple, Optional, Dict -from datasets import Dataset - from SciAssist import BASE_TEMP_DIR, BASE_OUTPUT_DIR from SciAssist.pipelines.pipeline import Pipeline from SciAssist.pipelines.testing_pipeline import test from SciAssist.utils.pdf2text import process_pdf_file, get_bodytext from SciAssist.utils.windows_pdf2text import windows_get_bodytext +from datasets import Dataset class Summarization(Pipeline): @@ -63,7 +62,7 @@ def __init__( max_target_length=300, os_name=None, ): - super().__init__(task_name=task_name, model_name=model_name, device=device, + super().__init__(task_name=task_name, model_name=model_name, checkpoint=checkpoint,device=device, cache_dir=cache_dir, output_dir=output_dir, temp_dir=temp_dir) self.data_utils = self.data_utils( @@ -85,7 +84,10 @@ def predict( num_return_sequences=1, save_results=True, length = None, - keywords: List[str] = None + keywords: List[str] = None, + top_k=0, + max_length=500, + do_sample=False, ): """ @@ -143,14 +145,14 @@ def predict( if type in ["str", "string"]: results = self._summarize_for_string(example=input, num_beams=num_beams, - num_return_sequences=num_return_sequences,length=length, keywords=keywords) + num_return_sequences=num_return_sequences,length=length, keywords=keywords,top_k=top_k,max_length=max_length,do_sample=do_sample) elif type in ["txt", "text"]: results = self._summarize_for_text(filename=input, num_beams=num_beams, - num_return_sequences=num_return_sequences,length=length, keywords=keywords) + 
+                                               num_return_sequences=num_return_sequences, length=length, keywords=keywords, top_k=top_k, max_length=max_length, do_sample=do_sample)
         elif type == "pdf":
             results = self._summarize_for_pdf(filename=input, output_dir=output_dir, temp_dir=temp_dir,
                                               num_beams=num_beams, num_return_sequences=num_return_sequences,
-                                              length=length, keywords=keywords)
+                                              length=length, keywords=keywords, top_k=top_k, max_length=max_length, do_sample=do_sample)
 
         # Save predicted results as a text file
         if save_results and type not in ["str", "string"]:
@@ -162,11 +164,12 @@ def predict(
 
         return results
 
     def _to_device(self, batch):
-        if self.model_name in ["default", "bart-cnn-on-mup", "flan-t5", "t5"]:
-            return {
-                "input_ids": batch["input_ids"].to(self.device),
-                "attention_mask": batch["attention_mask"].to(self.device),
-            }
+        ''' Automatically move data to the same device as the model.
+        '''
+        return {
+            "input_ids": batch["input_ids"].to(self.device),
+            "attention_mask": batch["attention_mask"].to(self.device),
+        }
 
     def _summarize(
             self,
@@ -174,7 +177,10 @@ def _summarize(
             examples: List[str],
             num_beams=1,
             num_return_sequences=1,
             length=100,
-            keywords=None
+            keywords=None,
+            top_k=0,
+            max_length=500,
+            do_sample=False,
     ) -> List[str]:
 
         """ Summarize each text in the list.
@@ -197,7 +203,7 @@ def _summarize(
             batch = self._to_device(batch)
 
             # Get token ids of summary
-            pred = self.model.generate(batch["input_ids"], batch["attention_mask"], num_beams, num_return_sequences)
+            pred = self.model.generate(batch["input_ids"], batch["attention_mask"], num_beams, num_return_sequences, top_k=top_k, max_length=max_length, do_sample=do_sample)
 
             # Convert token ids to text
             decoded_preds = self.tokenizer.batch_decode(pred, skip_special_tokens=True)
@@ -210,7 +216,10 @@ def _summarize_for_string(
             self,
             num_beams=1,
             num_return_sequences=1,
             length=100,
-            keywords=None
+            keywords=None,
+            top_k=0,
+            max_length=500,
+            do_sample=False,
     ) -> Tuple[str, str]:
 
         """
@@ -226,7 +235,7 @@ def _summarize_for_string(
 
         """
         num = 10
-        res = self._summarize([example], num_beams, num_return_sequences,length=length, keywords=keywords)
+        res = self._summarize([example], num_beams, num_return_sequences, length=length, keywords=keywords, top_k=top_k, max_length=max_length, do_sample=do_sample)
         if length is not None:
             num = 5*math.ceil(length/50)
         # if keywords is not None:
@@ -239,7 +248,10 @@ def _summarize_for_text(
             self,
             filename: str,
             num_beams: int = 1,
             num_return_sequences: int = 1,
             length=100,
-            keywords=None
+            keywords=None,
+            top_k=0,
+            max_length=500,
+            do_sample=False,
     ) -> Tuple[str, str]:
 
         """
@@ -261,7 +273,7 @@ def _summarize_for_text(
 
         with open(filename, "r") as f:
             examples = f.readlines()
         examples = [" ".join(examples)]
-        res = self._summarize(examples, num_beams, num_return_sequences,length=length,keywords=keywords)
+        res = self._summarize(examples, num_beams, num_return_sequences, length=length, keywords=keywords, top_k=top_k, max_length=max_length, do_sample=do_sample)
         # if keywords is not None:
         #     examples = [extract_related_sentences(examples[0], keywords[0],num)]
         return {"summary": res, "raw_text": examples[0]}
 
     def _summarize_for_pdf(
             self,
@@ -274,7 +286,10 @@ def _summarize_for_pdf(
             num_beams: int = 1,
             num_return_sequences=1,
             length = 100,
-            keywords = None
+            keywords = None,
+            top_k=0,
+            max_length=500,
+            do_sample=False,
     ) -> Dict:
         """ Summarize a document from a PDF file.
@@ -298,9 +313,9 @@ def _summarize_for_pdf(
             text_file = windows_get_bodytext(path=filename, output_dir=output_dir)
 
         # Do summarization
-        return self._summarize_for_text(text_file, num_beams=num_beams, num_return_sequences=num_return_sequences, length=length, keywords=keywords)
+        return self._summarize_for_text(text_file, num_beams=num_beams, num_return_sequences=num_return_sequences, length=length, keywords=keywords, top_k=top_k, max_length=max_length, do_sample=do_sample)
 
     def evaluate(self):
-        return test()
\ No newline at end of file
+        return test()
diff --git a/src/SciAssist/pipelines/testing_pipeline.py b/src/SciAssist/pipelines/testing_pipeline.py
index 3a52a9c..8f79e1c 100644
--- a/src/SciAssist/pipelines/testing_pipeline.py
+++ b/src/SciAssist/pipelines/testing_pipeline.py
@@ -1,10 +1,9 @@
-import os
-from typing import List
-
 import hydra
+import os
+import pytorch_lightning.loggers
 from omegaconf import DictConfig
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
-from pytorch_lightning.loggers import LightningLoggerBase
+from typing import List
 
 from SciAssist import utils
 
@@ -17,8 +16,11 @@ def test(config: DictConfig) -> None:
         seed_everything(config.seed, workers=True)
 
     # Convert relative ckpt path to absolute path if necessary
-    if not os.path.isabs(config.ckpt_path):
+    if not os.path.isabs(config.ckpt_path) and config.ckpt_path != "None":
         config.ckpt_path = os.path.join(hydra.utils.get_original_cwd(), config.ckpt_path)
+    elif config.ckpt_path == "None":
+        config.ckpt_path = None
+    print(config.ckpt_path)
 
     # Init lightning datamodule
     log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
@@ -29,7 +31,7 @@ def test(config: DictConfig) -> None:
     model: LightningModule = hydra.utils.instantiate(config.model)
 
     # Init lightning loggers
-    logger: List[LightningLoggerBase] = []
+    logger: List[pytorch_lightning.loggers.logger.Logger] = []
     if "logger" in config:
         for _, lg_conf in config.logger.items():
             if "_target_" in lg_conf:
diff --git a/src/SciAssist/utils/__init__.py b/src/SciAssist/utils/__init__.py
index 7ba18bb..2a77b70 100644
--- a/src/SciAssist/utils/__init__.py
+++ b/src/SciAssist/utils/__init__.py
@@ -1,12 +1,11 @@
 import logging
-import warnings
-from typing import List, Sequence
-
 import pytorch_lightning as pl
 import rich.syntax
 import rich.tree
+import warnings
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning.utilities import rank_zero_only
+from typing import List, Sequence
 
 
 def get_logger(name=__name__) -> logging.Logger:
@@ -108,7 +107,7 @@ def log_hyperparameters(
     datamodule: pl.LightningDataModule,
     trainer: pl.Trainer,
     callbacks: List[pl.Callback],
-    logger: List[pl.loggers.base.LightningLoggerBase],
+    logger: List[pl.loggers.logger.Logger],
 ) -> None:
     """Controls which config parts are saved by Lightning loggers.
@@ -151,7 +150,7 @@ def finish(
     datamodule: pl.LightningDataModule,
     trainer: pl.Trainer,
     callbacks: List[pl.Callback],
-    logger: List[pl.loggers.LightningLoggerBase],
+    logger: List[pl.loggers.logger.Logger],
 ) -> None:
     """Makes sure everything closed properly."""
 
@@ -160,4 +159,4 @@ def finish(
         if isinstance(lg, pl.loggers.wandb.WandbLogger):
             import wandb
 
-            wandb.finish()
+            wandb.finish()
\ No newline at end of file
diff --git a/src/SciAssist/utils/data_utils.py b/src/SciAssist/utils/data_utils.py
index c4fcc52..23f2f4a 100644
--- a/src/SciAssist/utils/data_utils.py
+++ b/src/SciAssist/utils/data_utils.py
@@ -1,11 +1,10 @@
-from typing import List, Dict
-
 import nltk
 import numpy as np
 import torch
 from torch.utils.data import DataLoader, Dataset
 from transformers import AutoTokenizer
 from transformers import DataCollatorForSeq2Seq
+from typing import List, Dict
 
 from SciAssist import BASE_CACHE_DIR
 from SciAssist.datamodules.components.cora_label import label2id as cora_label2id
@@ -336,7 +335,7 @@ def collator(self):
 
         """
 
-        The collating function.
+        The collating function used in the DataLoader, which defines how the raw data is converted into batches.
 
         Returns:
             `function`: A collating function.
@@ -697,7 +696,8 @@ def collator(self):
 
         """
 
-        The collating function.
+        The collating function used in the DataLoader, which defines how the raw data is converted into batches.
+
 
         Returns:
             `function`: A collating function.
@@ -855,7 +855,7 @@ def leng(prompt, length):
 
         prompts = [ leng(prompt,length) for (prompt,length) in zip(prompts,lengths)]
         inputs = [ prompt + ": " + raw_text for (prompt,raw_text) in zip(prompts, inputs) ]
-
+
 
         # Setup the tokenizer for inputs
@@ -885,7 +885,8 @@ def collator(self):
 
         """
 
-        The collating function.
+        The collating function used in the DataLoader, which defines how the raw data is converted into batches.
+
 
         Returns:
             `function`: A collating function.
@@ -958,333 +959,6 @@ def get_dataloader(self, dataset, inputs_column="text", labels_column="summary")
 
         return dataloader
 
-
-# class DataUtilsForFiD():
-#     """
-#
-#     Args:
-#         tokenizer (`PretrainedTokenizer`, default to None):
-#             The tokenizer for tokenization.
-#         checkpoint (`str`):
-#             The checkpoint from which the tokenizer is loaded.
-#         model_max_length (`int`, *optional*): The max sequence length the model accepts.
-#         max_source_length (`int`, *optional*): The max length of the input text.
-#         max_target_length (`int`, *optional*): The max length of the generated summary.
-#     """
-#
-#     def __init__(self, tokenizer = None, model_class = FiDT5,
-#                  checkpoint = "google/flan-t5-large",
-#                  model_max_length = 64,
-#                  max_source_length = 64,
-#                  max_target_length = 128,
-#     ):
-#
-#         self.checkpoint = checkpoint
-#         self.model_max_length = model_max_length
-#         self.max_source_length = max_source_length
-#         self.max_target_length = max_target_length
-#         self.model_class = model_class
-#
-#         if tokenizer is None:
-#             self.tokenizer = AutoTokenizer.from_pretrained(
-#                 self.checkpoint,
-#                 model_max_length = self.model_max_length,
-#                 cache_dir=BASE_CACHE_DIR,
-#             )
-#         else:
-#             self.tokenizer = tokenizer
-#
-#
-#     def tokenize_and_align_labels(self, examples, inputs_column="text", labels_column="summary", token_per_paragraph=50):
-#
-#         """
-#
-#         Process the dataset for model input, for example, do tokenization and prepare label_ids.
-#
-#         Args:
-#             examples (`Dataset`): { "text": [s1, s2, ...], "summary": [l1, l2, ...]}
-#             inputs (`str`): The name of input column
-#             labels (`str`): The name of target column
-#
-#         Returns:
-#             `Dict`: {"input_ids": input_ids, "attention_mask": attention_mask, "labels": label_ids }
-#
-#         """
-#
-#         # Select input column
-#         inputs = examples[inputs_column]
-#         dataset = {"paragraphs": []}
-#
-#         for input in inputs:
-#             texts = ["Please give a summary of the following text: "]
-#             tokens = input.split(" ")
-#             index = 0
-#             while index+token_per_paragraph < min(len(tokens),120*token_per_paragraph):
-#                 p = " ".join(tokens[index:index+token_per_paragraph])
-#                 texts.append(p)
-#                 index += token_per_paragraph
-#             texts.append(" ".join(tokens[index:index+token_per_paragraph]))
-#             dataset["paragraphs"].append(texts)
-#
-#
-#         # Select target column
-#         if labels_column in examples.keys():
-#             labels = examples[labels_column]
-#             dataset["labels"] = labels
-#
-#         return dataset
-#
-#     def collator(self):
-#
-#         """
-#
-#         The collating function.
-#
-#         Returns:
-#             `function`: A collating function.
-#
-#             For example, **DataCollatorForSeq2Seq(...)**.
-#
-#             You can also custom a collating function, but remember that `collator()` needs to return a **function**.
-#         """
-#
-#         from SciAssist.utils.collators.CollatorForFid import DataCollatorForFid
-#
-#         return DataCollatorForFid(self.max_source_length, self.tokenizer, self.max_target_length)
-#
-#     def postprocess(self, preds, labels):
-#
-#         """
-#         Process model's outputs and get the final results rather than simple ids.
-#
-#         Args:
-#             preds (Tensor): Prediction labels, the output of the model.
-#             labels (Tensor): True labels
-#
-#         Returns:
-#             `(LongTensor, LongTensor)`: decoded_preds, decoded_labels
-#
-#         """
-#
-#         decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
-#
-#         labels = np.array(labels.to("cpu"))
-#         # Replace -100 in the labels as we can't decode them.
-#         labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id)
-#
-#         decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
-#
-#         decoded_preds = [pred.strip() for pred in decoded_preds]
-#         decoded_labels = [label.strip() for label in decoded_labels]
-#
-#         # rougeLSum expects newline after each sentence
-#         decoded_preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in decoded_preds]
-#         decoded_labels = ["\n".join(nltk.sent_tokenize(label)) for label in decoded_labels]
-#
-#         return decoded_preds, decoded_labels
-#
-#     def get_dataloader(self, dataset, inputs_column="text", labels_column="summary"):
-#
-#         """
-#         Generate DataLoader for a dataset.
-#
-#         Args:
-#             dataset (`Dataset`): The raw dataset.
-#             inputs_column (`str`): Column name of the inputs.
-#             labels_column (`str`): Column name of the labels.
-#
-#         Returns:
-#             `DataLoader`: A dataloader for the dataset. Will be used for inference.
-#         """
-#
-#         tokenized_example = dataset.map(
-#             lambda x: self.tokenize_and_align_labels(x, inputs_column=inputs_column, labels_column=labels_column),
-#             batched=True,
-#             remove_columns=dataset.column_names
-#         )
-#         dataloader = DataLoader(
-#             dataset=tokenized_example,
-#             batch_size=8,
-#             collate_fn=self.collator(),
-#         )
-#
-#         return dataloader
-#
-#
-#
-# class DataUtilsForFrost():
-#     """
-#
-#     Args:
-#         tokenizer (`PretrainedTokenizer`, default to None):
-#             The tokenizer for tokenization.
-#         checkpoint (`str`):
-#             The checkpoint from which the tokenizer is loaded.
-#         model_max_length (`int`, *optional*): The max sequence length the model accepts.
-#         max_source_length (`int`, *optional*): The max length of the input text.
-#         max_target_length (`int`, *optional*): The max length of the generated summary.
-#     """
-#
-#
-#     def __init__(self, tokenizer = None, model_class = FrostForSummarization,
-#                  checkpoint = "pegasus/frost",
-#                  model_max_length = 1024,
-#                  max_source_length = 1024,
-#                  max_target_length = 128,
-#     ):
-#
-#         self.checkpoint = checkpoint
-#         self.model_max_length = model_max_length
-#         self.max_source_length = max_source_length
-#         self.max_target_length = max_target_length
-#         self.model_class = model_class
-#
-#         if tokenizer is None:
-#             self.tokenizer = PegasusTokenizer.from_pretrained(
-#                 self.checkpoint,
-#                 cache_dir=BASE_CACHE_DIR,
-#                 model_max_length=self.model_max_length,
-#             )
-#         else:
-#             self.tokenizer = tokenizer
-#
-#         # FROST Constants
-#         self.ENTITYCHAIN_START_TOKEN = "[CONTENT]"
-#         self.SUMMARY_START_TOKEN = "[SUMMARY]"
-#         self.ENTITY_SEPARATOR = " | "
-#         self.ENTITY_SENTENCE_SEPARATOR = " ||| "
-#
-#         # Prepare Spacy processor
-#         self.SPACY_MODEL_OR_PATH = "en_core_web_sm"
-#         self.SPACY_PROCESSOR = spacy.load(self.SPACY_MODEL_OR_PATH)
-#
-#     def get_frost_labels(self, text):
-#         """Gets Spacy Frost processor."""
-#         entity_plans = []
-#         for text_sent in self.SPACY_PROCESSOR(text.replace("\n", " ")).sents:
-#             entity_plans.append(
-#                 self.ENTITY_SEPARATOR.join(
-#                     [entity.text for entity in self.SPACY_PROCESSOR(text_sent.text).ents]))
-#         text_with_entityplans = (
-#             self.ENTITYCHAIN_START_TOKEN + " " +
-#             self.ENTITY_SENTENCE_SEPARATOR.join(entity_plans) + " " +
-#             self.SUMMARY_START_TOKEN + " " + text)
-#         return text_with_entityplans
-#
-#
-#     def tokenize_and_align_labels(self, examples, inputs_column="text", labels_column="summary"):
-#
-#         """
-#
-#         Process the dataset for model input, for example, do tokenization and prepare label_ids.
-#
-#         Args:
-#             examples (`Dataset`): { "text": [s1, s2, ...], "summary": [l1, l2, ...]}
-#             inputs (`str`): The name of input column
-#             labels (`str`): The name of target column
-#
-#         Returns:
-#             `Dict`: {"input_ids": input_ids, "attention_mask": attention_mask, "labels": label_ids }
-#
-#         """
-#
-#         # Select input column
-#         inputs = examples[inputs_column]
-#
-#         # Setup the tokenizer for inputs
-#         model_inputs = self.tokenizer(inputs, max_length=self.max_target_length, padding="max_length", truncation=True)
-#
-#         # Select target column
-#         if labels_column in examples.keys():
-#             labels = examples[labels_column]
-#             labels = [self.get_frost_labels(label) for label in labels]
-#
-#             # Setup the tokenizer for targets
-#             with self.tokenizer.as_target_tokenizer():
-#                 labels = self.tokenizer(labels, max_length=self.max_target_length, padding="max_length", truncation=True)
-#             # Ignore padding in the loss
-#             labels["input_ids"] = [
-#                 [(l if l != self.tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
-#             ]
-#
-#             model_inputs["labels"] = labels["input_ids"]
-#
-#         return model_inputs
-#
-#     def collator(self):
-#
-#         """
-#
-#         The collating function.
-#
-#         Returns:
-#             `function`: A collating function.
-#
-#             For example, **DataCollatorForSeq2Seq(...)**.
-#
-#             You can also custom a collating function, but remember that `collator()` needs to return a **function**.
-#         """
-#
-#
-#         return DataCollatorForSeq2Seq(self.tokenizer, model=self.model_class, pad_to_multiple_of=8)
-#
-#     def postprocess(self, preds, labels):
-#
-#         """
-#         Process model's outputs and get the final results rather than simple ids.
-#
-#         Args:
-#             preds (Tensor): Prediction labels, the output of the model.
-#             labels (Tensor): True labels
-#
-#         Returns:
-#             `(LongTensor, LongTensor)`: decoded_preds, decoded_labels
-#
-#         """
-#
-#         decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
-#
-#         labels = np.array(labels.to("cpu"))
-#         # Replace -100 in the labels as we can't decode them.
-#         labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id)
-#
-#         decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=False)
-#
-#         decoded_preds = [pred.strip() for pred in decoded_preds]
-#         decoded_labels = [label.strip() for label in decoded_labels]
-#
-#         # rougeLSum expects newline after each sentence
-#         decoded_preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in decoded_preds]
-#         decoded_labels = ["\n".join(nltk.sent_tokenize(label)) for label in decoded_labels]
-#
-#         return decoded_preds, decoded_labels
-#
-#     def get_dataloader(self, dataset, inputs_column="text", labels_column="summary"):
-#
-#         """
-#         Generate DataLoader for a dataset.
-#
-#         Args:
-#             dataset (`Dataset`): The raw dataset.
-#             inputs_column (`str`): Column name of the inputs.
-#             labels_column (`str`): Column name of the labels.
-#
-#         Returns:
-#             `DataLoader`: A dataloader for the dataset. Will be used for inference.
-#         """
-#
-#         tokenized_example = dataset.map(
-#             lambda x: self.tokenize_and_align_labels(x, inputs_column=inputs_column, labels_column=labels_column),
-#             batched=True,
-#             remove_columns=dataset.column_names
-#         )
-#         dataloader = DataLoader(
-#             dataset=tokenized_example,
-#             batch_size=8,
-#             collate_fn=self.collator(),
-#         )
-#
-#         return dataloader
-
 class DataUtilsForT5():
     """
@@ -1366,7 +1040,8 @@ def collator(self):
 
         """
 
-        The collating function.
+        The collating function used in the DataLoader, which defines how the raw data is converted into batches.
+
 
        Returns:
            `function`: A collating function.
@@ -1439,4 +1114,3 @@ def get_dataloader(self, dataset, inputs_column="text", labels_column="summary")
 
         return dataloader
 
-
diff --git a/src/SciAssist/utils/windows_pdf2text.py b/src/SciAssist/utils/windows_pdf2text.py
index 4b45fa4..f507071 100644
--- a/src/SciAssist/utils/windows_pdf2text.py
+++ b/src/SciAssist/utils/windows_pdf2text.py
@@ -1,5 +1,4 @@
 import os
-
 from pdfminer.high_level import extract_pages
 
 from SciAssist import BASE_OUTPUT_DIR
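
---

Usage note: the sketch below shows how the user-facing additions in this patch fit together — the new `checkpoint` argument threaded through `Summarization` -> `Pipeline` -> `load_model()`, and the new generation knobs (`top_k`, `max_length`, `do_sample`) forwarded through `predict()` to `model.generate()`. It is a minimal sketch, not shipped code: the checkpoint filename and the input PDF are hypothetical placeholders, and keyword names are taken only from the signatures changed above.

```python
from SciAssist import Summarization

# New in this patch: an optional fine-tuned checkpoint. When given,
# load_model() calls model_class(cache_dir=..., model_checkpoint=<path>)
# instead of downloading the default weights from model_dict_url.
pipeline = Summarization(
    model_name="flan-t5",           # or the new "flan-t5-xl" entry in TASKS
    checkpoint="my-finetuned.pt",   # hypothetical local file; omit to use the defaults
    device="cpu",
)

# The new kwargs pass through predict() -> _summarize() -> generate().
# With do_sample=True, top_k restricts sampling to the k most probable
# tokens; max_length caps the summary length in tokens (default now 500,
# up from the previously hard-coded 300).
result = pipeline.predict(
    "paper.pdf", type="pdf",        # "paper.pdf" is an illustrative input
    num_beams=1, do_sample=True, top_k=50, max_length=500,
)
print(result["summary"])
```

Note that `generate()` no longer passes `num_beam_groups`: diverse beam search is incompatible with sampling, so `diversity_penalty` now only takes effect when the underlying `transformers` call is configured for group beam search.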