-
Notifications
You must be signed in to change notification settings - Fork 0
/
doc_manager.py
88 lines (77 loc) · 2.92 KB
/
doc_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import re
import json
import torch
import textract
import numpy as np
from tqdm import trange
from transformers import AutoTokenizer, AutoModel
def get_all_files(folder):
    """Recursively collect the full paths of every file under *folder*."""
    return [
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(folder)
        for filename in filenames
    ]
class Processor:
    """Classify PDF documents with a BERT backbone plus a stored linear head.

    Instead of saving the entire torch model (which can be very heavy) the
    linear head's weights and bias are merged into a single numpy matrix:

        self.wb.shape = [max_position_embeddings + 1, n_classes]
        w = self.wb[:-1]   # [max_position_embeddings, n_classes]
        b = self.wb[-1]    # [n_classes]

    You can train any model as you want with the code in the accompanying
    README file.
    """

    def __init__(self, hf_backbone, np_path, class_to_id):
        """
        Args:
            hf_backbone: HuggingFace model id; must be a BERT variant.
            np_path: path to the numpy file holding the merged
                weight+bias matrix (see class docstring).
            class_to_id: path to a JSON file with the class labels.
        """
        assert "bert" in hf_backbone.lower(), "Supports only BERT Models"
        self.tokenizer = AutoTokenizer.from_pretrained(hf_backbone)
        # use a torch.device object in both branches (the original mixed
        # torch.device("cuda:0") with the bare string "cpu")
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        self.model = AutoModel.from_pretrained(hf_backbone).to(self.device)
        self.model.eval()
        self.model_config = self.model.config
        # merged weight+bias matrix of the linear classification head
        self.wb = np.load(np_path)
        # class_to_id is a path to a file that has the labels
        with open(class_to_id, "r") as f:
            # NOTE(review): assumed to be indexable by class index
            # (a JSON list, most likely) — confirm against the training code
            self.class_to_id = json.load(f)

    def classify(self, all_text):
        """Return predicted class indices (numpy array) for a batch of strings."""
        # maximum sequence length supported by the BERT model
        maxlen = self.model_config.max_position_embeddings
        encodings = self.tokenizer(all_text, return_tensors="pt", padding="longest")
        with torch.no_grad():
            # truncate to the model's capacity and move tensors to the device
            encodings = {k: v[:, :maxlen].to(self.device) for k, v in encodings.items()}
            output = self.model(**encodings).last_hidden_state
            logits = torch.sum(output, dim=1)  # pool the outputs over the sequence axis
            logits = logits.cpu().numpy()
        # linear kernel on the pooled features:
        #   logits @ weights + bias, then argmax over the class axis
        classes = (logits @ self.wb[:-1, :] + self.wb[-1]).argmax(-1)
        return classes

    def read_pdf(self, filename) -> str:
        """Extract text from a PDF file using a per-OS backend.

        Raises:
            RuntimeError: on an unsupported operating system. (The original
                fell through both branches and crashed with an opaque
                UnboundLocalError on `text`.)
        """
        if os.name == "nt":
            # windows bypass for reading PDF files
            import PyPDF2
            with open(filename, 'rb') as pdf_file:
                read_pdf = PyPDF2.PdfFileReader(pdf_file)
                # NOTE(review): only page 0 is extracted — original behaviour,
                # kept as-is; confirm whether all pages were intended
                text = read_pdf.getPage(0).extractText()
        elif os.name == "posix":
            # linux bypass for reading PDF files
            text = textract.process(filename, method='pdfminer')
            text = text.decode("utf-8")
        else:
            raise RuntimeError(f"Unsupported OS for PDF extraction: {os.name}")
        return text

    def process(self, files):
        """Read, clean, and classify a list of PDF paths; return label strings."""
        # load the text in the documents
        all_text = []
        pbar = trange(len(files))
        for i in pbar:
            f = files[i]
            pbar.set_description(f"Opening >> {f[:30]}")
            text = self.read_pdf(f)
            # collapse whitespace runs to single spaces; raw string fixes the
            # invalid "\s" escape warning in the original
            text = re.sub(r"\s+", " ", text)
            all_text.append(text)
        # get class indices from the model
        classes = self.classify(all_text)
        # cast numpy integers to plain ints before indexing
        # NOTE(review): assumes class_to_id is a list indexed by class id;
        # if it is a dict keyed by strings this needs str(c) instead — confirm
        class_labels = [self.class_to_id[int(c)] for c in classes]
        return class_labels