Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v0.1.1 #41

Open
wants to merge 53 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
2ecc995
Make the library compatible with pytorch-lightning 1.7 and specify py…
dyxohjl666 Mar 8, 2023
29aa3d4
Rename the model to match FLANT5
dyxohjl666 Mar 8, 2023
0d75ca4
Add function used to extract text by section and allow to set the suf…
dyxohjl666 Mar 8, 2023
3b22485
Modify the rules to identify different sections
dyxohjl666 Mar 8, 2023
ef7bd7f
Add FlanT5 datautils
dyxohjl666 Apr 2, 2023
ff7515b
Add controlled summarization based on FlanT5
dyxohjl666 Apr 2, 2023
98423d4
Add pipeline configs
dyxohjl666 Apr 3, 2023
6ee075e
Update the pdf processor
dyxohjl666 Apr 3, 2023
98e1bb9
Update the version
dyxohjl666 Apr 14, 2023
1737d37
Update datamodule for mixed Mup and SciSumm dataset
dyxohjl666 Apr 14, 2023
624a702
Add bert-score library
dyxohjl666 Apr 14, 2023
4e03647
Merge controlled summarization and uncontrolled summarization
dyxohjl666 Apr 14, 2023
10ec7f8
Merge branch 'main' of https://github.com/WING-NUS/SciAssist into con…
dyxohjl666 Apr 14, 2023
173e59c
Update the model url
dyxohjl666 Apr 21, 2023
b95b1f7
Merge branch 'main' of https://github.com/WING-NUS/SciAssist into con…
dyxohjl666 Apr 21, 2023
3773e06
Remove unused classes
dyxohjl666 Apr 29, 2023
eb130bf
Add function to extract abstracts
dyxohjl666 Apr 29, 2023
2a28523
Update the default num_beams=5
dyxohjl666 Apr 29, 2023
1b3076f
Remove test codes.
dyxohjl666 Apr 29, 2023
0088d8e
Update the rule to extract body_text
dyxohjl666 Apr 29, 2023
1dce402
Add rules for body_text extraction
dyxohjl666 May 4, 2023
6775041
Set the default num_beams to 5 to match the training procedure
dyxohjl666 May 4, 2023
1900306
Fix the problem when the model is trained on a different device from …
dyxohjl666 May 20, 2023
abfe09f
Merge branch 'main' of https://github.com/WING-NUS/SciAssist into con…
dyxohjl666 May 20, 2023
a90f649
Allow specifying the checkpoint of model_class
dyxohjl666 Nov 18, 2023
5d25df4
Update pyproject.toml
dyxohjl666 Nov 18, 2023
e6f9911
Add interface for generate strategies
dyxohjl666 Nov 19, 2023
0677e99
Update dependencies
dyxohjl666 Nov 19, 2023
bb74b19
Update dependencies
dyxohjl666 Nov 19, 2023
d6d826a
Fix logger problem caused by higher version of lightning
dyxohjl666 Nov 19, 2023
c4bb633
Fix default value
dyxohjl666 Nov 19, 2023
71e80c9
Fix dependency problems
dyxohjl666 Nov 19, 2023
33254d0
Fix dependency problems
dyxohjl666 Nov 28, 2023
fb102c9
Merge branch 'main' into controlled-FFT
dyxohjl666 Nov 29, 2023
2bef53c
Update hook functions due to new version of lightning
dyxohjl666 Dec 7, 2023
d24eeec
Remove unnecessary log only used for test
dyxohjl666 Dec 7, 2023
a4f4aaa
Change the default strategy to greedy search
dyxohjl666 Dec 7, 2023
2b961ae
Update datasets library to keep consistent with demo
dyxohjl666 Dec 7, 2023
014e47d
Automatically put the data on the same device with the model for all …
dyxohjl666 Dec 7, 2023
fbcfa3f
Remove hyphen at the end of each line
dyxohjl666 Dec 7, 2023
973e450
Update pypi version
dyxohjl666 Dec 7, 2023
abe82e3
Merge remote-tracking branch 'origin/controlled-FFT' into controlled-FFT
dyxohjl666 Dec 7, 2023
fd27be6
Remove unused lines
dyxohjl666 Dec 7, 2023
9719874
Update cora_module.py
dyxohjl666 Dec 7, 2023
c44902b
Update pypi version
dyxohjl666 Dec 7, 2023
41e6c18
Update mup_bart_module.py
dyxohjl666 Dec 7, 2023
d0779ba
Update mup_bart_module.py
dyxohjl666 Dec 7, 2023
a2feae9
Update mup_bart_module.py
dyxohjl666 Dec 7, 2023
77f35ca
Update summarization.py
dyxohjl666 Dec 7, 2023
2a0a87d
Update summarization.py
dyxohjl666 Dec 7, 2023
3e35f44
Update summarization.py
dyxohjl666 Dec 7, 2023
fe6798f
Update data_utils.py
dyxohjl666 Dec 7, 2023
fe075b4
Update data_utils.py
dyxohjl666 Dec 7, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "SciAssist"
version = "0.0.37"
version = "0.1.4"
authors = [
{ name="WING-NUS", email="[email protected]" },
]
Expand All @@ -23,30 +23,31 @@ classifiers = [
dependencies = [
"beautifulsoup4~=4.9.0",
"chardet~=3.0.4",
"datasets~=2.2.2",
"datasets~=2.15.0",
"hydra-core>=1.1.0",
"lxml",
"matplotlib~=3.5.1",
"nltk~=3.7",
"numpy~=1.19.2",
"numpy",
"omegaconf~=2.2.2",
"PyPDF2~=2.10.7",
"python_magic~=0.4.18",
"pytorch_lightning~=1.7.1",
"requests~=2.21.0",
"pytorch_lightning~=2.0.4",
qolina marked this conversation as resolved.
Show resolved Hide resolved
"requests~=2.22.0",
"rich~=12.4.4",
"seaborn~=0.11.2",
"setuptools>=61.0",
"torch>=1.12.0",
"torchmetrics>=0.7.0",
"transformers~=4.19.2",
"torchmetrics==0.11.4",
"transformers~=4.30.2",
"wandb~=0.12.19",
"pdfminer.six",
"pandas~=1.4.3",
"pytorch-crf",
"torchcrf",
"sacremoses",
"seqeval",
"pytest~=7.4.3"
]


Expand Down
10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# --------- pytorch --------- #
torch>=1.10.0
torchvision>=0.11.0
pytorch-lightning~=1.7.1
pytorch-lightning~=2.0.4
torchmetrics>=0.7.0

# --------- hydra --------- #
Expand All @@ -28,15 +28,15 @@ pyrootutils
python-dotenv~=0.20.0
protobuf~=3.19.0
rich~=12.4.4
pytest~=7.1.2
pytest~=7.4.3
# sh~=1.14.2
pudb # debugger

seaborn~=0.11.2
omegaconf~=2.2.2
transformers~=4.19.2
transformers~=4.30.2
packaging~=21.3
datasets~=2.2.2
datasets~=2.15.0
beautifulsoup4~=4.9.0

matplotlib~=3.5.1
Expand All @@ -46,7 +46,7 @@ pdfminer.six # windows pdf processing

# --------- doc2json --------- #
boto3~=1.9.147
requests~=2.21.0
requests~=2.22.0
Flask~=1.0.2
tqdm
lxml
Expand Down
12 changes: 6 additions & 6 deletions src/SciAssist/models/components/flant5_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def forward(self, input_ids=None, attention_mask=None, labels=None):
logits=outputs.logits
)

def generate(self, input_ids=None, attention_mask=None, num_beams=5, num_return_sequences=1):
def generate(self, input_ids=None, attention_mask=None, num_beams=1, num_return_sequences=1, top_k=0, max_length=500, do_sample=False):
diversity_penalty = 0.0
if num_return_sequences>1:
diversity_penalty = 1.0
return self.flant5.generate(input_ids=input_ids, attention_mask=attention_mask,
num_beams=num_beams,
num_return_sequences=num_return_sequences,
num_beam_groups=num_return_sequences,
diversity_penalty=diversity_penalty,
max_length=300,
do_sample=False,
no_repeat_ngram_size=5 )
diversity_penalty = diversity_penalty,
top_k=top_k,
qolina marked this conversation as resolved.
Show resolved Hide resolved
max_length=max_length,
do_sample=do_sample,)

7 changes: 4 additions & 3 deletions src/SciAssist/models/cora_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torchmetrics import MaxMetric
from torchmetrics.classification.accuracy import Accuracy


from SciAssist.datamodules.components.cora_label import num_labels, LABEL_NAMES
from SciAssist.models.components.bert_token_classifier import BertForTokenClassifier
from SciAssist.utils.data_utils import DataUtilsForTokenClassification
Expand Down Expand Up @@ -67,7 +68,7 @@ def training_step(self, batch: Any, batch_idx: int):
self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
return {"loss": loss}

def training_epoch_end(self, outputs: List[Any]):
def on_training_epoch_end(self):
pass
qolina marked this conversation as resolved.
Show resolved Hide resolved

def validation_step(self, batch: Any, batch_idx: int):
Expand All @@ -91,7 +92,7 @@ def validation_step(self, batch: Any, batch_idx: int):
self.log("val/macro_f1", macro_f1, on_step=False, on_epoch=True, prog_bar=True)
return {"loss": loss, "preds": true_preds, "labels": true_labels}

def validation_epoch_end(self, outputs: List[Any]):
def on_validation_epoch_end(self):
acc = self.val_acc.compute()
self.val_acc_best.update(acc)
self.log("val/acc_best", self.val_acc_best.compute(), on_epoch=True, prog_bar=True)
Expand Down Expand Up @@ -129,7 +130,7 @@ def test_step(self, batch: Any, batch_idx: int):

return {"loss": loss, "preds": true_preds, "labels": true_labels}

def test_epoch_end(self, outputs: List[Any]):
def on_test_epoch_end(self):
# wandb.init()
acc = self.test_acc.compute()
micro_f1 = self.test_micro_f1.compute()
Expand Down
17 changes: 4 additions & 13 deletions src/SciAssist/models/mup_bart_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def training_step(self, batch: Any, batch_idx: int):
self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False)
return {"loss": loss}

def training_epoch_end(self, outputs: List[Any]):
def on_training_epoch_end(self):
pass

def validation_step(self, batch: Any, batch_idx: int):
Expand Down Expand Up @@ -121,7 +121,7 @@ def validation_step(self, batch: Any, batch_idx: int):

return result

def validation_epoch_end(self, outputs: List[Any]):
def on_validation_epoch_end(self):
rouge = self.val_metric.compute()
# bert = self.val_bertscore.compute()
# self.val_best_Rouge1.update(rouge["rouge1_fmeasure"])
Expand Down Expand Up @@ -219,22 +219,14 @@ def test_step(self, batch: Any, batch_idx: int):
return result


def test_epoch_end(self, outputs: List[Any]):
def on_test_epoch_end(self):
# Save prediction results
# with open(os.path.join(self.model.model_dir,"prediction.txt"),'w') as f:
# for batch in outputs:
# for res in batch["preds"]:
# f.write(res)
# f.write("\n")

for batch in outputs:
for id,res in zip(batch['id'],batch["preds"]):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `outputs` parameter is not supported by the new hook function (`on_test_epoch_end`), so this loop over `outputs` is removed.

with open("/home/dingyx/project/SciAssist/data/pdfs/summary_flant5/" + str(id.item()) +".txt","a") as f:
# print("/home/dingyx/project/SciAssist/data/MUP_CTRLkeyword/" + str(id.item()) +".txt")
f.write(res)
f.write("\n")
# f.write(str(len(res.split(" "))))


P,R,F1 = bert_score.score(self.test_preds, self.test_labels,
rescale_with_baseline=True, lang="en")
# Compute average length of summaries
Expand All @@ -260,7 +252,6 @@ def test_epoch_end(self, outputs: List[Any]):
self.log("test/gen_len", self.test_gen_len, on_step=False, on_epoch=True, prog_bar=True)



def on_epoch_end(self):
self.val_metric.reset()
# self.val_bertscore.reset()
Expand Down
19 changes: 14 additions & 5 deletions src/SciAssist/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# main developer: Yixi Ding <[email protected]>

from typing import Dict

import torch
from typing import Dict

from SciAssist import BASE_CACHE_DIR
from SciAssist.models.components.bert_dataset_extraction import BertForDatasetExtraction
Expand Down Expand Up @@ -47,7 +46,12 @@
"model": FlanT5ForSummarization,
"model_dict_url": "https://huggingface.co/spaces/dyxohjl666/Controlled-summarization/resolve/main/flant5-base-mup-scisumm-repeat5-kws.pt",
"data_utils": DataUtilsForFlanT5,
}
},
"flan-t5-xl": {
"model": FlanT5ForSummarization,
"model_dict_url": None,
"data_utils": DataUtilsForFlanT5,
},
},
"dataset-extraction": {
"default": {
Expand All @@ -59,7 +63,7 @@

}

def load_model(config: Dict, cache_dir=BASE_CACHE_DIR, device="gpu"):
def load_model(config: Dict, checkpoint=None, cache_dir=BASE_CACHE_DIR, device="gpu"):
qolina marked this conversation as resolved.
Show resolved Hide resolved
'''

Args:
Expand All @@ -77,7 +81,12 @@ def load_model(config: Dict, cache_dir=BASE_CACHE_DIR, device="gpu"):

print("Loading the model...")
model_class = config["model"]
model = model_class(cache_dir=cache_dir)

if checkpoint!=None:
model = model_class(cache_dir=cache_dir,model_checkpoint=checkpoint)
else:
model = model_class(cache_dir=cache_dir)

map_location=None
if device == "cpu":
map_location = torch.device("cpu")
Expand Down
4 changes: 2 additions & 2 deletions src/SciAssist/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Pipeline():

"""

def __init__(self, task_name: str, model_name: str = "default", device="gpu",
def __init__(self, task_name: str, model_name: str = "default", checkpoint: str = None, device="gpu",
cache_dir=None, output_dir=None, temp_dir=None):

self.device = device
Expand All @@ -37,7 +37,7 @@ def __init__(self, task_name: str, model_name: str = "default", device="gpu",

self.config = TASKS[task_name][model_name]
self.model_name = model_name
self.model = load_model(config=self.config, cache_dir=self.cache_dir, device=self.device)
self.model = load_model(config=self.config, checkpoint=checkpoint,cache_dir=self.cache_dir, device=self.device)
if device in ["cuda", "gpu"] and torch.cuda.is_available():
self.device = torch.device("cuda")
self.model.cuda()
Expand Down
Loading