feat: split embedding propagation calculations across multiple GPUs #20
twndus committed Jun 23, 2024
1 parent 19b9240 commit 0bb5e48
Showing 3 changed files with 69 additions and 25 deletions.
data/datasets/ngcf_data_pipeline.py (6 changes: 4 additions & 2 deletions)
@@ -38,11 +38,13 @@ def _set_laplacian_matrix(self, df):
         diagonal_degree_matrix = torch.tensor(diagonal_degree_matrix).float().to('cuda')
         adjacency_matrix = torch.tensor(adjacency_matrix).float().to('cuda')
         self.laplacian_matrix = torch.matmul(diagonal_degree_matrix, adjacency_matrix)
-        adjacency_matrix = adjacency_matrix.cpu().detach()
+        del adjacency_matrix
         self.laplacian_matrix = torch.matmul(self.laplacian_matrix, diagonal_degree_matrix)
-        self.laplacian_matrix = self.laplacian_matrix.to(self.cfg.device)
+        self.laplacian_matrix = self.laplacian_matrix.to('cpu')
         logger.info('done...')
+
+        del diagonal_degree_matrix

     def preprocess(self) -> pd.DataFrame:
         df = super().preprocess()
         self._set_laplacian_matrix(df)
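For context, the matrix this pipeline builds is the symmetric normalized Laplacian D^-1/2 A D^-1/2 of the bipartite user-item graph, computed on the GPU and, after this commit, parked on the CPU. A minimal sketch under assumptions: the dense interaction matrix R and the helper name build_laplacian are illustrative, not this repo's actual code.

    import numpy as np
    import torch

    def build_laplacian(R: np.ndarray, device: str = 'cuda') -> torch.Tensor:
        # Hypothetical helper mirroring _set_laplacian_matrix; R is num_users x num_items
        num_users, num_items = R.shape
        n = num_users + num_items
        # bipartite adjacency A = [[0, R], [R^T, 0]]
        A = np.zeros((n, n), dtype=np.float32)
        A[:num_users, num_users:] = R
        A[num_users:, :num_users] = R.T
        # diagonal D^-1/2, guarding against isolated nodes
        deg = A.sum(axis=1)
        d_inv_sqrt = np.divide(1.0, np.sqrt(deg), out=np.zeros_like(deg), where=deg > 0)
        D = torch.diag(torch.from_numpy(d_inv_sqrt)).to(device)
        A_t = torch.from_numpy(A).to(device)
        # two GPU matmuls, then park the result on the CPU, as this commit does
        return (D @ A_t @ D).cpu()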
models/ngcf.py (81 changes: 61 additions & 20 deletions)
@@ -6,12 +6,12 @@

 class NGCF(BaseModel):

-    def __init__(self, cfg, num_users, num_items, laplacian_matrix):
+    def __init__(self, cfg, num_users, num_items): #, laplacian_matrix):
         super().__init__()
         self.cfg = cfg
         self.num_users = num_users
         self.num_items = num_items
-        self.laplacian_matrix = laplacian_matrix
+        # self.laplacian_matrix = laplacian_matrix
         self.embedding = nn.Embedding(
             num_users+num_items, cfg.embed_size, dtype=torch.float32)

@@ -27,11 +27,11 @@ def _init_weights(self):
         if isinstance(child, nn.Embedding):
             nn.init.xavier_uniform_(child.weight)

-    def forward(self, user_id, item_id):
+    def forward(self, user_id, item_id, laplacian_matrix):
         user_embed_list, item_embed_list = [self.embedding(user_id),], [self.embedding(self.num_users+item_id)]
         last_embed = self.embedding.weight
         for w1, w2 in zip(self.W1, self.W2):
-            last_embed: torch.Tensor = self.embedding_propagation(last_embed, w1, w2)
+            last_embed: torch.Tensor = self.embedding_propagation(last_embed, w1, w2, laplacian_matrix)
             user_embed_list.append(last_embed[user_id])
             item_embed_list.append(last_embed[self.num_users + item_id])

@@ -40,27 +40,68 @@ def forward(self, user_id, item_id):

         return torch.sum(user_embed * item_embed, dim=1)

-    def embedding_propagation(self, last_embed: torch.Tensor, w1, w2):
+    def _embedding_propagation(self, last_embed: torch.Tensor, w1, w2):
         identity_matrix = torch.eye(*self.laplacian_matrix.size())
-        matrix = self.laplacian_matrix.to('cpu') + identity_matrix
+        matrix = (self.laplacian_matrix.to('cpu') + identity_matrix).to(self.cfg.device)

         # split calculation to avoid GPU memory shortage
-        chunk_size = 32
-        embed_list = []
-        for chunk_idx in range(0, self.num_users + self.num_items, chunk_size):
-            matrix_concat = matrix[chunk_idx : (chunk_idx + chunk_size)]
-            term1 = torch.matmul(matrix_concat.to(self.cfg.device), last_embed)
-            term1 = w1(term1)
+        # chunk_size = 32
+        # embed_list = []
+        # for chunk_idx in range(0, self.num_users + self.num_items, chunk_size):
+        #     matrix_concat = matrix[chunk_idx : (chunk_idx + chunk_size)]
+        #     term1 = torch.matmul(matrix_concat.to(self.cfg.device), last_embed)
+        #     term1 = w1(term1)
+        #
+        #     laplacian_concat = self.laplacian_matrix[chunk_idx : (chunk_idx + chunk_size)]
+        #     neighbor_embeddings = torch.matmul(laplacian_concat, last_embed)
+        #
+        #     last_embed_concat = last_embed[chunk_idx : (chunk_idx + chunk_size)]
+        #     term2 = torch.mul(neighbor_embeddings, last_embed_concat)
+        #     term2 = w2(term2)
+        #     embed_list.append(term1 + term2)
+        #
+        # embed_list = torch.concat(embed_list, dim=0)

-            laplacian_concat = self.laplacian_matrix[chunk_idx : (chunk_idx + chunk_size)]
-            neighbor_embeddings = torch.matmul(laplacian_concat, last_embed)
+        term1 = torch.matmul(matrix, last_embed)
+        term1 = w1(term1)

-            last_embed_concat = last_embed[chunk_idx : (chunk_idx + chunk_size)]
-            term2 = torch.mul(neighbor_embeddings, last_embed_concat)
-            term2 = w2(term2)
-            embed_list.append(term1 + term2)
+        neighbor_embeddings = torch.matmul(self.laplacian_matrix, last_embed)

-        embed_list = torch.concat(embed_list, dim=0)
+        term2 = torch.mul(neighbor_embeddings, last_embed)
+        term2 = w2(term2)

-        return nn.functional.leaky_relu(embed_list)
+        # return nn.functional.leaky_relu(embed_list)
+        return nn.functional.leaky_relu(term1 + term2)
+
+    def embedding_propagation(self, last_embed: torch.Tensor, w1, w2, laplacian_matrix):
+        device0 = torch.device('cuda:0')
+        device1 = torch.device('cuda:1')
+
+        # Midpoint for a row-split of last_embed (computed but not used below)
+        mid = last_embed.size(0) // 2
+
+        # Build (laplacian + identity) on the CPU; moved to each GPU below
+        identity_matrix = torch.eye(last_embed.size(0))
+        matrix = laplacian_matrix + identity_matrix
+
+        # Compute term1 on GPU0
+        term1_part0 = torch.matmul(matrix.to(device0), last_embed.to(device0))
+        term1_part0 = w1(term1_part0)
+
+        # Compute term2 on GPU1
+        w2 = w2.to(device1)
+        neighbor_embeddings = torch.matmul(laplacian_matrix.to(device1), last_embed.to(device1))
+        term2_part1 = torch.mul(neighbor_embeddings, last_embed.to(device1))
+        term2_part1 = w2(term2_part1)
+
+        # Transfer term2_part1 to GPU0
+        term2_part1 = term2_part1.to(device0)
+
+        # Combine term1 and term2 on GPU0
+        combined_result = term1_part0 + term2_part1
+        # combined_result = term1_part0
+
+        # Apply activation function
+        result = nn.functional.leaky_relu(combined_result)
+
+        return result  # result lives on device0
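Both the retired chunked version and the new two-GPU version compute the same NGCF layer update. Writing L for the normalized Laplacian, I for the identity, E^(l) for the stacked user/item embedding matrix, and W_1, W_2 for the layer's linear maps, the rule the code implements is:

    E^{(l+1)} = \mathrm{LeakyReLU}\!\left( (\mathcal{L} + I)\, E^{(l)} W_1 + \big( \mathcal{L} E^{(l)} \odot E^{(l)} \big) W_2 \right)

The two terms touch different weights and stay independent until the final sum, which is what makes the device split natural: each GPU holds one n-by-n times n-by-d product, and only the n-by-d result of term2 crosses devices per layer. A self-contained sketch of that split with toy sizes (requires two visible CUDA devices; every name here is illustrative, not this repo's code):

    import torch
    import torch.nn as nn

    n, d = 512, 64
    L = torch.rand(n, n)              # normalized Laplacian, CPU-resident
    E = torch.rand(n, d)              # current layer embeddings
    w1 = nn.Linear(d, d).to('cuda:0')
    w2 = nn.Linear(d, d).to('cuda:1')

    # term1 = (L + I) E W1, computed on cuda:0
    term1 = w1((L + torch.eye(n)).to('cuda:0') @ E.to('cuda:0'))
    # term2 = (L E * E) W2, computed on cuda:1
    E1 = E.to('cuda:1')
    term2 = w2((L.to('cuda:1') @ E1) * E1)
    # one cross-device copy, then combine and activate on cuda:0
    out = nn.functional.leaky_relu(term1 + term2.to('cuda:0'))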
trainers/ngcf_trainer.py (7 changes: 4 additions & 3 deletions)
@@ -25,9 +25,10 @@ def __init__(self, cfg: DictConfig, num_items: int, num_users: int, laplacian_matrix):
         logger.info(f'[DEVICE] device = {self.device}')
         self.num_items = num_items
         self.num_users = num_users
-        self.model = NGCF(self.cfg, num_users, num_items, laplacian_matrix).to(self.device)
+        self.model = NGCF(self.cfg, num_users, num_items).to(self.device)
         self.optimizer: Optimizer = self._optimizer(self.cfg.optimizer, self.model, self.cfg.lr, self.cfg.weight_decay)
         self.loss = self._loss()
+        self.laplacian_matrix = laplacian_matrix

     def _loss(self):
         return BPRLoss()
@@ -104,8 +105,8 @@ def train(self, train_dataloader: DataLoader) -> float:
         for data in tqdm(train_dataloader):
             user_id, pos_item, neg_item = data['user_id'].to(self.device), data['pos_item'].to(self.device), \
                 data['neg_item'].to(self.device)
-            pos_pred = self.model(user_id, pos_item)
-            neg_pred = self.model(user_id, neg_item)
+            pos_pred = self.model(user_id, pos_item, self.laplacian_matrix)
+            neg_pred = self.model(user_id, neg_item, self.laplacian_matrix)

             self.optimizer.zero_grad()
             loss = self.loss(pos_pred, neg_pred)
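Taken together, the new contract is: the data pipeline leaves the Laplacian on the CPU, the trainer stores it and passes it into every forward call, and embedding_propagation shards the work across cuda:0/cuda:1 itself. A hedged usage sketch, where cfg, R, and train_dataloader stand in for the repo's Hydra config, interaction matrix, and DataLoader, and build_laplacian is the illustrative helper above:

    from trainers.ngcf_trainer import NGCFTrainer

    laplacian_matrix = build_laplacian(R)                  # CPU-resident tensor
    trainer = NGCFTrainer(cfg, num_items, num_users, laplacian_matrix)
    assert trainer.laplacian_matrix.device.type == 'cpu'   # never moved wholesale to one GPU

    epoch_loss = trainer.train(train_dataloader)           # each step: model(user, item, laplacian)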
