From f77c3fea5e56086186528f58df86aa7d9c3e3eb8 Mon Sep 17 00:00:00 2001 From: HotBento <434365819@qq.com> Date: Fri, 9 Aug 2024 13:22:44 +0800 Subject: [PATCH] Update collector.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 只在最终get_data_struct将tensor转移到cpu,极大节约cpu占用。 --- recbole/evaluator/collector.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/recbole/evaluator/collector.py b/recbole/evaluator/collector.py index 07a913721..23a3e86dc 100644 --- a/recbole/evaluator/collector.py +++ b/recbole/evaluator/collector.py @@ -43,12 +43,12 @@ def set(self, name: str, value): def update_tensor(self, name: str, value: torch.Tensor): if name not in self._data_dict: - self._data_dict[name] = value.cpu().clone().detach() + self._data_dict[name] = value.clone().detach() else: if not isinstance(self._data_dict[name], torch.Tensor): raise ValueError("{} is not a tensor.".format(name)) self._data_dict[name] = torch.cat( - (self._data_dict[name], value.cpu().clone().detach()), dim=0 + (self._data_dict[name], value.clone().detach()), dim=0 ) def __str__(self): @@ -149,6 +149,7 @@ def eval_batch_collect( positive_i(Torch.Tensor): the positive item id for each user. """ if self.register.need("rec.items"): + # get topk _, topk_idx = torch.topk( scores_tensor, max(self.topk), dim=-1 @@ -156,6 +157,7 @@ def eval_batch_collect( self.data_struct.update_tensor("rec.items", topk_idx) if self.register.need("rec.topk"): + _, topk_idx = torch.topk( scores_tensor, max(self.topk), dim=-1 ) # n_users x k @@ -167,6 +169,7 @@ def eval_batch_collect( self.data_struct.update_tensor("rec.topk", result) if self.register.need("rec.meanrank"): + desc_scores, desc_index = torch.sort(scores_tensor, dim=-1, descending=True) # get the index of positive items in the ranking list @@ -185,6 +188,7 @@ def eval_batch_collect( self.data_struct.update_tensor("rec.meanrank", result) if self.register.need("rec.score"): + self.data_struct.update_tensor("rec.score", scores_tensor) if self.register.need("data.label"): @@ -219,6 +223,8 @@ def get_data_struct(self): """Get all the evaluation resource that been collected. And reset some of outdated resource. """ + for key in self.data_struct._data_dict: + self.data_struct._data_dict[key] = self.data_struct._data_dict[key].cpu() returned_struct = copy.deepcopy(self.data_struct) for key in ["rec.topk", "rec.meanrank", "rec.score", "rec.items", "data.label"]: if key in self.data_struct: