
Commit

finalize citation embedding
ChenS676 committed Jul 10, 2024
1 parent 99ffa43 commit eb53bc8
Showing 3 changed files with 51 additions and 2 deletions.
49 changes: 49 additions & 0 deletions core/data_utils/load_data_lp.py
@@ -29,6 +29,11 @@
from typing import Dict, Tuple, List, Union
import torch
from lpda.lcc_3 import find_scc_direc, use_lcc_direc
import re
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from nltk.tokenize import word_tokenize
from gensim.models import Word2Vec
from tqdm import tqdm

FILE = 'core/dataset/ogbn_products_orig/ogbn-products.csv'
FILE_PATH = get_git_repo_root_path() + '/'
@@ -206,6 +211,7 @@ def load_taglp_citationv8(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
        cfg.include_negatives,
        cfg.split_labels
    )

    return splits, text, data


@@ -274,11 +280,54 @@ def load_taplp_pwc_small(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:



def preprocess(text):
    # Remove non-alphanumeric characters
    text = re.sub(r'\W+', ' ', text)
    # Tokenize and convert to lowercase
    tokens = word_tokenize(text.lower())
    return ' '.join(tokens)
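# Example of the preprocessing (illustrative input, not from the dataset):
#   preprocess('Attention Is All You Need!') -> 'attention is all you need'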

# Function to get the average embedding for a whole text (e.g., title and abstract combined)
def get_average_embedding(text, model):
    # preprocess() returns a whitespace-joined string, so split it back into tokens
    tokens = preprocess(text).split()
    embeddings = [model.wv[token] for token in tokens if token in model.wv]
    if embeddings:
        return np.mean(embeddings, axis=0)
    else:
        # Return a zero vector if none of the tokens are in the vocabulary
        return np.zeros(model.vector_size)
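# A minimal usage sketch (assumes a trained 128-dim model as in the test code below):
#   vec = get_average_embedding('graph neural networks', model)
#   vec.shape == (128,)   # mean of the in-vocabulary token vectors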

# TEST CODE
if __name__ == '__main__':
    args = init_cfg_test()
    args = config_device(args)

    splits, text, data = load_taglp_citationv8(args.data)
    print(data)

    preprocessed_texts = [preprocess(t[0]) for t in tqdm(text)]
    print(len(preprocessed_texts))

    if args.data.method == 'w2v':
        # Train a Word2Vec model; gensim expects token lists, so split the
        # preprocessed strings back into tokens
        tokenized_texts = [t.split() for t in preprocessed_texts]
        model = Word2Vec(sentences=tokenized_texts, vector_size=128, window=5, min_count=1, workers=10)
        # Mean-pool token vectors into one 128-dim feature per node
        w2v_nodefeat = np.array([get_average_embedding(t[0], model) for t in text])
        x = torch.tensor(w2v_nodefeat, dtype=torch.float)
    else:
        # TF-IDF features as the alternative embedding method
        vectorizer = TfidfVectorizer(max_features=128)
        tfidf_matrix = vectorizer.fit_transform(preprocessed_texts)
        x = torch.tensor(tfidf_matrix.toarray(), dtype=torch.float)

    data.x = x
    torch.save(data, f'citationv8_{args.data.method}.pt')

    # Optional smoke test for the pwc_small loader:
    # from lpda.lcc_3 import use_lcc
    # splits, text, data = load_taplp_pwc_small(args.data)
    # print(splits)
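# A minimal downstream sketch (hypothetical; the filename assumes method='w2v'):
#   data = torch.load('citationv8_w2v.pt')
#   print(data.x.shape)  # -> torch.Size([num_nodes, 128])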
2 changes: 1 addition & 1 deletion core/data_utils/load_data_nc.py
@@ -451,7 +451,7 @@ def load_graph_citationv8() -> Data:
def load_text_citationv8() -> List[str]:
    df = pd.read_csv(FILE_PATH + 'core/dataset/citationv8_orig/Citation-2015.csv')
    return [
-        f'Text: {ti}\n'
+        ti
        for ti in zip(df['text'])
    ]
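# Note: zip() over a single column yields 1-tuples, e.g.
#   list(zip(['a', 'b'])) == [('a',), ('b',)]
# which is why the test code in load_data_lp.py indexes each element as t[0].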

2 changes: 1 addition & 1 deletion core/graphgps/utility/utils.py
@@ -245,7 +245,7 @@ def init_cfg_test():
        'split_labels': True,
        'device': 'cpu',
        'split_index': [0.8, 0.15, 0.05],
-        'method': 'tfidf'
+        'method': 'w2v'
    },
    'train': {
        'device': 'cpu'
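# Note: cfg.data.method selects the feature pipeline in load_data_lp.py
# ('w2v' vs. 'tfidf') and names the saved file, e.g. citationv8_w2v.pt.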
