
Commit

update load nc
ChenS676 committed Jul 11, 2024
1 parent 9f3e757 commit 019df33
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions core/data_utils/load_data_nc.py
@@ -58,7 +58,6 @@ def get_node_mask_ogb(num_nodes: int, idx_splits: Dict[str, torch.Tensor]) -> tu
    return train_mask, val_mask, test_mask



# Function to parse the arxiv_2023 dataset
def load_graph_arxiv23() -> Data:
    return torch.load(FILE_PATH + 'core/dataset/arxiv_2023/graph.pt')
@@ -205,17 +204,18 @@ def load_text_cora(data_citeid) -> List[str]:


# Function to parse the ogbn-products dataset

def load_graph_product():
    raise NotImplementedError
    # Add your implementation here


def load_text_product() -> List[str]:
    text = pd.read_csv(FILE_PATH + 'core/dataset/ogbn_products_orig/ogbn-products_subset.csv')
    text = [f'Product:{ti}; Description: {cont}\n' for ti,
            cont in zip(text['title'], text['content'])]
    return text


# Function to parse the ogbn-products dataset (graph and text)
def load_tag_product() -> Tuple[Data, List[str]]:
    data = torch.load(FILE_PATH + 'core/dataset/ogbn_products_orig/ogbn-products_subset.pt')
@@ -345,6 +345,7 @@ def load_graph_pubmed(use_mask) -> Data:
edge_attrs = None,
graph_attrs = None
)


# Function to parse PubMed dataset
def load_text_pubmed() -> List[str]:
@@ -379,8 +380,7 @@ def load_text_ogbn_arxiv():
'Title: ' + ti + '\n' + 'Abstract: ' + ab
for ti, ab in zip(df['title'], df['abs'])
]




def load_graph_ogbn_arxiv(use_mask):
    dataset = PygNodePropPredDataset(root='./generated_dataset',
@@ -448,6 +448,7 @@ def load_graph_citationv8() -> Data:
    graph = from_dgl(graph)
    return graph


def load_embedded_citationv8(method) -> Data:
    return torch.load(FILE_PATH + f'core/dataset/citationv8/citationv8_{method}.pt')

@@ -521,7 +522,6 @@ def load_text_pwc_small(method) -> List[str]:
    return pd.read_csv(FILE_PATH + f'core/dataset/pwc_small/pwc_{method}_small_text.csv')



def extract_lcc_pwc_undir() -> Data:
    # return the largest connected component with text attrs
    graph = torch.load(FILE_PATH + 'core/dataset/pwc_large/pwc_tfidf_large_undir.pt')
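A minimal usage sketch (not part of this commit) for the product loaders touched above, assuming the module is importable as core.data_utils.load_data_nc, that FILE_PATH points at the repository root, and that the ogbn-products subset files exist under core/dataset/ogbn_products_orig/:

# Hypothetical usage sketch; the import path and data locations are assumptions.
from core.data_utils.load_data_nc import load_text_product, load_tag_product

texts = load_text_product()        # list of "Product:...; Description: ..." strings
print(len(texts), texts[0][:80])

data, text = load_tag_product()    # PyG Data object plus the matching text list
print(data.num_nodes, len(text))

Both loaders read the same ogbn-products subset, so the returned text list is expected to align row-for-row with the nodes in the Data object.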
