Skip to content

Commit

Permalink
Add a Word2Vec embedder transformer compatible with sklearn pipelines (for the SVM model)
Browse files Browse the repository at this point in the history
  • Loading branch information
advaithsrao committed Nov 16, 2023
1 parent b7708a7 commit 1a14a9b
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 13 deletions.
2 changes: 1 addition & 1 deletion detector/modeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,4 +381,4 @@ def save_model(
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)

save_model(self.model, path)
save_model(self.model, path)
2 changes: 1 addition & 1 deletion tests/test_modeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_evaluate_and_log(x, y_true, y_pred):

def test_word2vec_embedding(mail):
    """The embedder returns one 300-dim vector per input document.

    `mail` is a pytest fixture (a sample email text) defined elsewhere in
    the test suite. `transform` returns an array of embeddings, so we take
    the first (and only) row for the single-document input.
    """
    embedder = Word2VecEmbedder()
    embedding = embedder.transform(mail)[0]
    # word2vec-google-news-300 produces 300-dimensional vectors
    assert len(embedding) == 300

def test_tp_sampler():
Expand Down
35 changes: 24 additions & 11 deletions utils/util_modeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import wandb
# from torch.utils.data import Sampler
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.base import BaseEstimator, TransformerMixin

def get_f1_score(
y_true: list[int],
Expand Down Expand Up @@ -91,35 +92,47 @@ def evaluate_and_log(
log_file.write(log_content)


class Word2VecEmbedder(BaseEstimator, TransformerMixin):
    """Sklearn-compatible transformer that maps text documents to Word2Vec
    document embeddings (mean of the word vectors of each document).

    Inheriting from ``BaseEstimator``/``TransformerMixin`` makes the class
    usable inside an sklearn ``Pipeline`` (provides ``fit_transform`` etc.).
    """

    def __init__(
        self,
        model_name: str = 'word2vec-google-news-300',
        tokenizer=None
    ):
        """Load the pretrained embedding model.

        Args:
            model_name (str): gensim-downloader model name. The default is
                a 300-dimensional Word2Vec model.
            tokenizer: object with a ``tokenize(text)`` method. Defaults to
                ``RegexpTokenizer(r'\\w+')``. A ``None`` sentinel is used so
                the default tokenizer is not a single shared instance
                created at function-definition time.

        NOTE(review): downloading the model here means every construction
        (including sklearn ``clone``) re-loads it — acceptable for current
        usage, but heavy.
        """
        self.model = gensim.downloader.load(model_name)
        self.tokenizer = tokenizer if tokenizer is not None else RegexpTokenizer(r'\w+')

    def fit(
        self,
        X,
        y=None
    ):
        """No-op fit (the embedding model is pretrained). Returns self per
        the sklearn transformer contract."""
        return self

    def transform(
        self,
        X
    ):
        """Calculate Word2Vec embeddings for the given text(s).

        Args:
            X (list | str): list of text documents, or a single document.

        Returns:
            np.ndarray: array of shape (n_documents, vector_size) with one
            mean-pooled Word2Vec embedding per document.
        """
        # Accept a bare string as a convenience: treat it as one document.
        if isinstance(X, str):
            X = [X]

        embeddings = []

        for text in X:
            # Tokenize the document; out-of-vocabulary words map to zeros.
            words = self.tokenizer.tokenize(text)
            word_vectors = [
                self.model[word] if word in self.model else np.zeros(self.model.vector_size)
                for word in words
            ]
            # Guard against empty documents: np.mean([]) would emit a
            # RuntimeWarning and yield NaN. Use a zero vector instead.
            if not word_vectors:
                word_vectors = [np.zeros(self.model.vector_size)]
            # Mean of the word embeddings = document embedding.
            document_embedding = np.mean(word_vectors, axis=0)
            embeddings.append(document_embedding)

        return np.array(embeddings)


class TPSampler:
Expand Down

0 comments on commit 1a14a9b

Please sign in to comment.