diff --git a/recbole/evaluator/collector.py b/recbole/evaluator/collector.py index 07a913721..23a3e86dc 100644 --- a/recbole/evaluator/collector.py +++ b/recbole/evaluator/collector.py @@ -43,12 +43,12 @@ def set(self, name: str, value): def update_tensor(self, name: str, value: torch.Tensor): if name not in self._data_dict: - self._data_dict[name] = value.cpu().clone().detach() + self._data_dict[name] = value.clone().detach() else: if not isinstance(self._data_dict[name], torch.Tensor): raise ValueError("{} is not a tensor.".format(name)) self._data_dict[name] = torch.cat( - (self._data_dict[name], value.cpu().clone().detach()), dim=0 + (self._data_dict[name], value.clone().detach()), dim=0 ) def __str__(self): @@ -149,6 +149,7 @@ def eval_batch_collect( positive_i(Torch.Tensor): the positive item id for each user. """ if self.register.need("rec.items"): + # get topk _, topk_idx = torch.topk( scores_tensor, max(self.topk), dim=-1 @@ -156,6 +157,7 @@ def eval_batch_collect( self.data_struct.update_tensor("rec.items", topk_idx) if self.register.need("rec.topk"): + _, topk_idx = torch.topk( scores_tensor, max(self.topk), dim=-1 ) # n_users x k @@ -167,6 +169,7 @@ def eval_batch_collect( self.data_struct.update_tensor("rec.topk", result) if self.register.need("rec.meanrank"): + desc_scores, desc_index = torch.sort(scores_tensor, dim=-1, descending=True) # get the index of positive items in the ranking list @@ -185,6 +188,7 @@ def eval_batch_collect( self.data_struct.update_tensor("rec.meanrank", result) if self.register.need("rec.score"): + self.data_struct.update_tensor("rec.score", scores_tensor) if self.register.need("data.label"): @@ -219,6 +223,8 @@ def get_data_struct(self): """Get all the evaluation resource that been collected. And reset some of outdated resource. """ + for key in self.data_struct._data_dict: + self.data_struct._data_dict[key] = self.data_struct._data_dict[key].cpu() returned_struct = copy.deepcopy(self.data_struct) for key in ["rec.topk", "rec.meanrank", "rec.score", "rec.items", "data.label"]: if key in self.data_struct: