v0.1.1 #41

Open: wants to merge 53 commits into base: main
Changes from 34 commits
2ecc995
Make the library compatible with pytorch-lightning 1.7 and specify py…
dyxohjl666 Mar 8, 2023
29aa3d4
Rename the model to match FLANT5
dyxohjl666 Mar 8, 2023
0d75ca4
Add function used to extract text by section and allow to set the suf…
dyxohjl666 Mar 8, 2023
3b22485
Modify the rules to identify different sections
dyxohjl666 Mar 8, 2023
ef7bd7f
Add FlanT5 datautils
dyxohjl666 Apr 2, 2023
ff7515b
Add controlled summarization based on FlanT5
dyxohjl666 Apr 2, 2023
98423d4
Add pipeline configs
dyxohjl666 Apr 3, 2023
6ee075e
Update the pdf processor
dyxohjl666 Apr 3, 2023
98e1bb9
Update the version
dyxohjl666 Apr 14, 2023
1737d37
Update datamodule for mixed Mup and SciSumm dataset
dyxohjl666 Apr 14, 2023
624a702
Add bert-score library
dyxohjl666 Apr 14, 2023
4e03647
Merge controlled summarization and uncontrolled summarization
dyxohjl666 Apr 14, 2023
10ec7f8
Merge branch 'main' of https://github.com/WING-NUS/SciAssist into con…
dyxohjl666 Apr 14, 2023
173e59c
Update the model url
dyxohjl666 Apr 21, 2023
b95b1f7
Merge branch 'main' of https://github.com/WING-NUS/SciAssist into con…
dyxohjl666 Apr 21, 2023
3773e06
Remove unused classes
dyxohjl666 Apr 29, 2023
eb130bf
Add function to extract abstracts
dyxohjl666 Apr 29, 2023
2a28523
Update the default num_beams=5
dyxohjl666 Apr 29, 2023
1b3076f
Remove test codes.
dyxohjl666 Apr 29, 2023
0088d8e
Update the rule to extract body_text
dyxohjl666 Apr 29, 2023
1dce402
Add rules for body_text extraction
dyxohjl666 May 4, 2023
6775041
Set the default num_beams to 5 to match the training procedure
dyxohjl666 May 4, 2023
1900306
Fix the problem when the model is trained on a different device from …
dyxohjl666 May 20, 2023
abfe09f
Merge branch 'main' of https://github.com/WING-NUS/SciAssist into con…
dyxohjl666 May 20, 2023
a90f649
Allow specify the checkpoint of model_class
dyxohjl666 Nov 18, 2023
5d25df4
Update pyproject.toml
dyxohjl666 Nov 18, 2023
e6f9911
Add interface for generate strategies
dyxohjl666 Nov 19, 2023
0677e99
Update dependencies
dyxohjl666 Nov 19, 2023
bb74b19
Update dependencies
dyxohjl666 Nov 19, 2023
d6d826a
Fix logger problem caused by higher version of lightning
dyxohjl666 Nov 19, 2023
c4bb633
Fix default value
dyxohjl666 Nov 19, 2023
71e80c9
Fix dependency problems
dyxohjl666 Nov 19, 2023
33254d0
Fix dependency problems
dyxohjl666 Nov 28, 2023
fb102c9
Merge branch 'main' into controlled-FFT
dyxohjl666 Nov 29, 2023
2bef53c
Update hook functions due to new version of lightning
dyxohjl666 Dec 7, 2023
d24eeec
Remove unnecessary log only used for test
dyxohjl666 Dec 7, 2023
a4f4aaa
Change the default strategy to greedy search
dyxohjl666 Dec 7, 2023
2b961ae
Update datasets library to keep consistent with demo
dyxohjl666 Dec 7, 2023
014e47d
Automatically put the data on the same device with the model for all …
dyxohjl666 Dec 7, 2023
fbcfa3f
Remove hyphen at the end of each line
dyxohjl666 Dec 7, 2023
973e450
Update pypi version
dyxohjl666 Dec 7, 2023
abe82e3
Merge remote-tracking branch 'origin/controlled-FFT' into controlled-FFT
dyxohjl666 Dec 7, 2023
fd27be6
Remove unused lines
dyxohjl666 Dec 7, 2023
9719874
Update cora_module.py
dyxohjl666 Dec 7, 2023
c44902b
Update pypi version
dyxohjl666 Dec 7, 2023
41e6c18
Update mup_bart_module.py
dyxohjl666 Dec 7, 2023
d0779ba
Update mup_bart_module.py
dyxohjl666 Dec 7, 2023
a2feae9
Update mup_bart_module.py
dyxohjl666 Dec 7, 2023
77f35ca
Update summarization.py
dyxohjl666 Dec 7, 2023
2a0a87d
Update summarization.py
dyxohjl666 Dec 7, 2023
3e35f44
Update summarization.py
dyxohjl666 Dec 7, 2023
fe6798f
Update data_utils.py
dyxohjl666 Dec 7, 2023
fe075b4
Update data_utils.py
dyxohjl666 Dec 7, 2023
13 changes: 7 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "SciAssist"
version = "0.0.37"
version = "0.1.1"
authors = [
{ name="WING-NUS", email="[email protected]" },
]
Expand All @@ -28,25 +28,26 @@ dependencies = [
"lxml",
"matplotlib~=3.5.1",
"nltk~=3.7",
"numpy~=1.19.2",
"numpy",
"omegaconf~=2.2.2",
"PyPDF2~=2.10.7",
"python_magic~=0.4.18",
"pytorch_lightning~=1.7.1",
"requests~=2.21.0",
"pytorch_lightning~=2.0.4",
"requests~=2.22.0",
"rich~=12.4.4",
"seaborn~=0.11.2",
"setuptools>=61.0",
"torch>=1.12.0",
"torchmetrics>=0.7.0",
"transformers~=4.19.2",
"torchmetrics==0.11.4",
"transformers~=4.30.2",
"wandb~=0.12.19",
"pdfminer.six",
"pandas~=1.4.3",
"pytorch-crf",
"torchcrf",
"sacremoses",
"seqeval",
"pytest~=7.4.3"
]


Expand Down
8 changes: 4 additions & 4 deletions requirements.txt
@@ -1,7 +1,7 @@
# --------- pytorch --------- #
torch>=1.10.0
torchvision>=0.11.0
pytorch-lightning~=1.7.1
pytorch-lightning~=2.0.4
torchmetrics>=0.7.0

# --------- hydra --------- #
Expand All @@ -28,13 +28,13 @@ pyrootutils
python-dotenv~=0.20.0
protobuf~=3.19.0
rich~=12.4.4
pytest~=7.1.2
pytest~=7.4.3
# sh~=1.14.2
pudb # debugger

seaborn~=0.11.2
omegaconf~=2.2.2
transformers~=4.19.2
transformers~=4.30.2
packaging~=21.3
datasets~=2.2.2
beautifulsoup4~=4.9.0
Expand All @@ -46,7 +46,7 @@ pdfminer.six # windows pdf processing

# --------- doc2json --------- #
boto3~=1.9.147
requests~=2.21.0
requests~=2.22.0
Flask~=1.0.2
tqdm
lxml
Expand Down
13 changes: 7 additions & 6 deletions src/SciAssist/models/components/flant5_summarization.py
Expand Up @@ -28,15 +28,16 @@ def forward(self, input_ids=None, attention_mask=None, labels=None):
logits=outputs.logits
)

def generate(self, input_ids=None, attention_mask=None, num_beams=5, num_return_sequences=1):
def generate(self, input_ids=None, attention_mask=None, num_beams=5, num_return_sequences=1, top_k=0, max_length=500, do_sample=True):
diversity_penalty = 0.0
if num_return_sequences>1:
diversity_penalty = 1.0
return self.flant5.generate(input_ids=input_ids, attention_mask=attention_mask,
num_beams=num_beams,
num_return_sequences=num_return_sequences,
num_beam_groups=num_return_sequences,
diversity_penalty=diversity_penalty,
max_length=300,
do_sample=False,
no_repeat_ngram_size=5 )
# num_beam_groups=num_return_sequences,
# diversity_penalty=diversity_penalty,
top_k=top_k,
max_length=max_length,
do_sample=do_sample,)
# no_repeat_ngram_size=5 )
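
Note on the new `generate` arguments: the combination of `num_beams` and `do_sample` decides which decoding mode Hugging Face `generate` uses. The sketch below is only an illustration of that mapping as described in the transformers generation documentation; `decoding_strategy` is a hypothetical helper, not part of SciAssist, and it ignores `num_beam_groups` (diverse beam search), which this diff comments out.

```python
def decoding_strategy(num_beams: int = 1, do_sample: bool = False) -> str:
    """Name the decoding mode selected by a (num_beams, do_sample) pair.

    Hypothetical helper for illustration only; it does not call any model.
    """
    if num_beams < 1:
        raise ValueError("num_beams must be >= 1")
    if num_beams == 1:
        return "multinomial sampling" if do_sample else "greedy search"
    return "beam-search multinomial sampling" if do_sample else "beam search"


# Defaults of FlanT5ForSummarization.generate() in this diff
print(decoding_strategy(num_beams=5, do_sample=True))
# Defaults of Summarization.predict() in this diff
print(decoding_strategy(num_beams=5, do_sample=False))
# Defaults of Summarization._summarize() in this diff
print(decoding_strategy(num_beams=1, do_sample=False))
```

The three calls show that the defaults differ between layers in this diff: the model wrapper defaults to `do_sample=True`, while `predict` passes `do_sample=False` explicitly, so the pipeline default wins when called through `predict`.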
19 changes: 14 additions & 5 deletions src/SciAssist/pipelines/__init__.py
@@ -1,8 +1,7 @@
# main developer: Yixi Ding <[email protected]>

from typing import Dict

import torch
from typing import Dict

from SciAssist import BASE_CACHE_DIR
from SciAssist.models.components.bert_dataset_extraction import BertForDatasetExtraction
Expand Down Expand Up @@ -47,7 +46,12 @@
"model": FlanT5ForSummarization,
"model_dict_url": "https://huggingface.co/spaces/dyxohjl666/Controlled-summarization/resolve/main/flant5-base-mup-scisumm-repeat5-kws.pt",
"data_utils": DataUtilsForFlanT5,
}
},
"flan-t5-xl": {
"model": FlanT5ForSummarization,
"model_dict_url": None,
"data_utils": DataUtilsForFlanT5,
},
},
"dataset-extraction": {
"default": {
Expand All @@ -59,7 +63,7 @@

}

def load_model(config: Dict, cache_dir=BASE_CACHE_DIR, device="gpu"):
def load_model(config: Dict, checkpoint=None, cache_dir=BASE_CACHE_DIR, device="gpu"):
'''

Args:
Expand All @@ -77,7 +81,12 @@ def load_model(config: Dict, cache_dir=BASE_CACHE_DIR, device="gpu"):

print("Loading the model...")
model_class = config["model"]
model = model_class(cache_dir=cache_dir)

if checkpoint!=None:
model = model_class(cache_dir=cache_dir,model_checkpoint=checkpoint)
else:
model = model_class(cache_dir=cache_dir)

map_location=None
if device == "cpu":
map_location = torch.device("cpu")
Expand Down
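
Note on the `checkpoint` argument: the branch above passes `model_checkpoint` only when a checkpoint is given, so each model class keeps its own default otherwise. A minimal sketch of the same idea, assuming a stand-in model class; `DummyModel`, `build_model`, and the default checkpoint string are hypothetical and only mirror the shape of `load_model`. It uses `is not None`, the usual Python idiom for this check.

```python
from typing import Any, Dict, Optional


class DummyModel:
    """Hypothetical stand-in for a model class such as FlanT5ForSummarization."""

    def __init__(self, cache_dir: str, model_checkpoint: str = "default-checkpoint"):
        self.cache_dir = cache_dir
        self.model_checkpoint = model_checkpoint


def build_model(config: Dict[str, Any], checkpoint: Optional[str] = None,
                cache_dir: str = ".cache"):
    """Instantiate config["model"], forwarding the checkpoint only if provided."""
    kwargs = {"cache_dir": cache_dir}
    if checkpoint is not None:
        # Override the model class's own default checkpoint.
        kwargs["model_checkpoint"] = checkpoint
    return config["model"](**kwargs)


default_model = build_model({"model": DummyModel})
custom_model = build_model({"model": DummyModel}, checkpoint="google/flan-t5-xl")
print(default_model.model_checkpoint, custom_model.model_checkpoint)
```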
4 changes: 2 additions & 2 deletions src/SciAssist/pipelines/pipeline.py
Expand Up @@ -27,7 +27,7 @@ class Pipeline():

"""

def __init__(self, task_name: str, model_name: str = "default", device="gpu",
def __init__(self, task_name: str, model_name: str = "default", checkpoint: str = None, device="gpu",
cache_dir=None, output_dir=None, temp_dir=None):

self.device = device
Expand All @@ -37,7 +37,7 @@ def __init__(self, task_name: str, model_name: str = "default", device="gpu",

self.config = TASKS[task_name][model_name]
self.model_name = model_name
self.model = load_model(config=self.config, cache_dir=self.cache_dir, device=self.device)
self.model = load_model(config=self.config, checkpoint=checkpoint,cache_dir=self.cache_dir, device=self.device)
if device in ["cuda", "gpu"] and torch.cuda.is_available():
self.device = torch.device("cuda")
self.model.cuda()
Expand Down
48 changes: 31 additions & 17 deletions src/SciAssist/pipelines/summarization.py
Expand Up @@ -4,13 +4,12 @@
import os
from typing import List, Tuple, Optional, Dict

from datasets import Dataset

from SciAssist import BASE_TEMP_DIR, BASE_OUTPUT_DIR
from SciAssist.pipelines.pipeline import Pipeline
from SciAssist.pipelines.testing_pipeline import test
from SciAssist.utils.pdf2text import process_pdf_file, get_bodytext
from SciAssist.utils.windows_pdf2text import windows_get_bodytext
from datasets import Dataset


class Summarization(Pipeline):
Expand Down Expand Up @@ -63,8 +62,8 @@ def __init__(
max_target_length=300,
os_name=None,
):
super().__init__(task_name=task_name, model_name=model_name, device=device,
cache_dir=cache_dir, output_dir=output_dir, temp_dir=temp_dir)
super().__init__(task_name=task_name, model_name=model_name, checkpoint=checkpoint,device=device,
cache_dir=cache_dir, output_dir=output_dir, temp_dir=temp_dir, )
Review comment (Collaborator): Remove the trailing comma.


self.data_utils = self.data_utils(
tokenizer=tokenizer,
Expand All @@ -81,11 +80,14 @@ def predict(
self, input: str, type: str = "pdf",
output_dir=None,
temp_dir=None,
num_beams=1,
num_beams=5,
num_return_sequences=1,
save_results=True,
length = None,
keywords: List[str] = None
keywords: List[str] = None,
top_k=0,
max_length=500,
do_sample=False,
):
"""

Expand Down Expand Up @@ -143,14 +145,14 @@ def predict(

if type in ["str", "string"]:
results = self._summarize_for_string(example=input, num_beams=num_beams,
num_return_sequences=num_return_sequences,length=length, keywords=keywords)
num_return_sequences=num_return_sequences,length=length, keywords=keywords,top_k=top_k,max_length=max_length,do_sample=do_sample)
elif type in ["txt", "text"]:
results = self._summarize_for_text(filename=input, num_beams=num_beams,
num_return_sequences=num_return_sequences,length=length, keywords=keywords)
num_return_sequences=num_return_sequences,length=length, keywords=keywords,top_k=top_k,max_length=max_length,do_sample=do_sample)
elif type == "pdf":
results = self._summarize_for_pdf(filename=input, output_dir=output_dir, temp_dir=temp_dir,
num_beams=num_beams, num_return_sequences=num_return_sequences,
length=length, keywords=keywords)
length=length, keywords=keywords,top_k=top_k,max_length=max_length,do_sample=do_sample)

# Save predicted results as a text file
if save_results and type not in ["str", "string"]:
Expand All @@ -174,7 +176,10 @@ def _summarize(
num_beams=1,
num_return_sequences=1,
length=100,
keywords=None
keywords=None,
top_k=0,
max_length=500,
do_sample=False,
) -> List[str]:
"""
Summarize each text in the list.
Expand All @@ -197,7 +202,7 @@ def _summarize(
batch = self._to_device(batch)

# Get token ids of summary
pred = self.model.generate(batch["input_ids"], batch["attention_mask"], num_beams, num_return_sequences)
pred = self.model.generate(batch["input_ids"], batch["attention_mask"], num_beams, num_return_sequences,top_k=top_k,max_length=max_length,do_sample=do_sample)
# Convert token ids to text
decoded_preds = self.tokenizer.batch_decode(pred, skip_special_tokens=True)

Expand All @@ -210,7 +215,10 @@ def _summarize_for_string(
num_beams=1,
num_return_sequences=1,
length=100,
keywords=None
keywords=None,
top_k=0,
max_length=500,
do_sample=False,
) -> Tuple[str, str]:

"""
Expand All @@ -226,7 +234,7 @@ def _summarize_for_string(

"""
num = 10
res = self._summarize([example], num_beams, num_return_sequences,length=length, keywords=keywords)
res = self._summarize([example], num_beams, num_return_sequences,length=length, keywords=keywords,top_k=top_k,max_length=max_length,do_sample=do_sample)
if length is not None:
num = 5*math.ceil(length/50)
# if keywords is not None:
Expand All @@ -239,7 +247,10 @@ def _summarize_for_text(
num_beams: int = 1,
num_return_sequences: int = 1,
length=100,
keywords=None
keywords=None,
top_k=0,
max_length=500,
do_sample=False,
) -> Tuple[str, str]:
"""

Expand All @@ -261,7 +272,7 @@ def _summarize_for_text(
with open(filename, "r") as f:
examples = f.readlines()
examples = [" ".join(examples)]
res = self._summarize(examples, num_beams, num_return_sequences,length=length,keywords=keywords)
res = self._summarize(examples, num_beams, num_return_sequences,length=length,keywords=keywords,top_k=top_k,max_length=max_length,do_sample=do_sample)
# if keywords is not None:
# examples = [extract_related_sentences(examples[0], keywords[0],num)]
return {"summary": res, "raw_text": examples[0]}
Expand All @@ -274,7 +285,10 @@ def _summarize_for_pdf(
num_beams: int = 1,
num_return_sequences=1,
length = 100,
keywords = None
keywords = None,
top_k=0,
max_length=500,
do_sample=False,
) -> Dict:
"""
Summarize a document from a PDF file.
Expand All @@ -298,7 +312,7 @@ def _summarize_for_pdf(
text_file = windows_get_bodytext(path=filename, output_dir=output_dir)

# Do summarization
return self._summarize_for_text(text_file, num_beams=num_beams, num_return_sequences=num_return_sequences, length=length, keywords=keywords)
return self._summarize_for_text(text_file, num_beams=num_beams, num_return_sequences=num_return_sequences, length=length, keywords=keywords,top_k=top_k,max_length=max_length,do_sample=do_sample)


def evaluate(self):
Expand Down
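
Note on the repeated keyword arguments: this diff threads `top_k`, `max_length`, and `do_sample` through `predict`, `_summarize`, `_summarize_for_string`, `_summarize_for_text`, and `_summarize_for_pdf`. One possible alternative, shown here only as a sketch and not what the PR does, is to bundle the generation settings in a single object. `GenerationOptions`, `fake_generate`, and `summarize` are hypothetical names; the defaults are the ones `predict` uses in this diff.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict


@dataclass(frozen=True)
class GenerationOptions:
    """Hypothetical bundle of the generation settings passed around in this diff."""
    num_beams: int = 5
    num_return_sequences: int = 1
    top_k: int = 0
    max_length: int = 500
    do_sample: bool = False


def fake_generate(text: str, **kwargs: Any) -> Dict[str, Any]:
    """Stand-in for model.generate(); it just echoes the keyword arguments."""
    return kwargs


def summarize(text: str, options: GenerationOptions = GenerationOptions()) -> Dict[str, Any]:
    """Forward every option in one step instead of listing each argument."""
    return fake_generate(text, **asdict(options))


print(summarize("some text"))
print(summarize("some text", GenerationOptions(do_sample=True, top_k=50)))
```

With this shape, adding a further generation setting would touch the dataclass and the final `generate` call only, rather than every intermediate signature.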
14 changes: 8 additions & 6 deletions src/SciAssist/pipelines/testing_pipeline.py
@@ -1,10 +1,9 @@
import os
from typing import List

import hydra
import os
import pytorch_lightning.loggers
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import LightningLoggerBase
from typing import List

from SciAssist import utils

Expand All @@ -17,8 +16,11 @@ def test(config: DictConfig) -> None:
seed_everything(config.seed, workers=True)

# Convert relative ckpt path to absolute path if necessary
if not os.path.isabs(config.ckpt_path):
if not os.path.isabs(config.ckpt_path) and config.ckpt_path != "None":
config.ckpt_path = os.path.join(hydra.utils.get_original_cwd(), config.ckpt_path)
elif config.ckpt_path == "None":
Review comment (Collaborator, Author): Allow testing on the original weights.

config.ckpt_path = None
print(config.ckpt_path)

# Init lightning datamodule
log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
Expand All @@ -29,7 +31,7 @@ def test(config: DictConfig) -> None:
model: LightningModule = hydra.utils.instantiate(config.model)

# Init lightning loggers
logger: List[LightningLoggerBase] = []
logger: List[pytorch_lightning.loggers.logger.Logger] = []
if "logger" in config:
for _, lg_conf in config.logger.items():
if "_target_" in lg_conf:
Expand Down
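
Note on the `ckpt_path` handling: the hunk above treats the string `"None"` as "no checkpoint" so that a test run can use the original weights, and otherwise makes a relative path absolute. A minimal sketch of that rule as a standalone function, assuming the base directory is passed in explicitly (the diff uses `hydra.utils.get_original_cwd()` for it); `resolve_ckpt_path` is a hypothetical name.

```python
import os
from typing import Optional


def resolve_ckpt_path(ckpt_path: Optional[str], base_dir: str) -> Optional[str]:
    """Return None for "no checkpoint", otherwise an absolute checkpoint path."""
    # The config may carry the literal string "None" instead of a real null.
    if ckpt_path is None or ckpt_path == "None":
        return None
    if os.path.isabs(ckpt_path):
        return ckpt_path
    return os.path.join(base_dir, ckpt_path)


print(resolve_ckpt_path("None", "/work"))
print(resolve_ckpt_path("checkpoints/last.ckpt", "/work"))
```

Checking for `"None"` first keeps the two cases in one place, so the path is never joined with the working directory by accident.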
11 changes: 5 additions & 6 deletions src/SciAssist/utils/__init__.py
Review comment (Collaborator, Author): The logger is part of the training procedure. The FFT strategy requires lightning>=2.0, so the logger class name has to change to stay compatible with the newer version.

@@ -1,12 +1,11 @@
import logging
import warnings
from typing import List, Sequence

import pytorch_lightning as pl
import rich.syntax
import rich.tree
import warnings
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
from typing import List, Sequence


def get_logger(name=__name__) -> logging.Logger:
Expand Down Expand Up @@ -108,7 +107,7 @@ def log_hyperparameters(
datamodule: pl.LightningDataModule,
trainer: pl.Trainer,
callbacks: List[pl.Callback],
logger: List[pl.loggers.base.LightningLoggerBase],
logger: List[pl.loggers.logger.Logger],
) -> None:
"""Controls which config parts are saved by Lightning loggers.

Expand Down Expand Up @@ -151,7 +150,7 @@ def finish(
datamodule: pl.LightningDataModule,
trainer: pl.Trainer,
callbacks: List[pl.Callback],
logger: List[pl.loggers.LightningLoggerBase],
logger: List[pl.loggers.logger.Logger],
) -> None:
"""Makes sure everything closed properly."""

Expand All @@ -160,4 +159,4 @@ def finish(
if isinstance(lg, pl.loggers.wandb.WandbLogger):
import wandb

wandb.finish()
wandb.finish()
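
Note on the logger rename: this diff replaces `pl.loggers.base.LightningLoggerBase` with `pl.loggers.logger.Logger` in the type hints. If both old and new Lightning versions had to be supported, the class could be looked up at import time. The sketch below is one way to do that, not what the PR does; `resolve_logger_base` and `LoggerBase` are hypothetical names, and the two module paths are the ones that appear in this diff. It returns `None` when pytorch_lightning is not installed.

```python
import importlib
from typing import Optional

# Candidate (module, attribute) pairs taken from this diff, newest name first.
_CANDIDATES = [
    ("pytorch_lightning.loggers.logger", "Logger"),
    ("pytorch_lightning.loggers.base", "LightningLoggerBase"),
]


def resolve_logger_base() -> Optional[type]:
    """Return the first logger base class that can be imported, else None."""
    for module_name, attr in _CANDIDATES:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # Module missing in this Lightning version, or Lightning not installed.
            continue
        cls = getattr(module, attr, None)
        if isinstance(cls, type):
            return cls
    return None


LoggerBase = resolve_logger_base()
print(LoggerBase)
```

Since the PR pins `pytorch_lightning~=2.0.4`, using the new name directly, as the diff does, is sufficient there.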