We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
在参考作者代码对官方pipeline进行修改之后,终于统一了数据格式的问题,但是后来又遇到了一个问题,与合理的gt数目有关,源代码是这样的: # 如果有预测框(花心大萝卜)匹配到了1个以上的gt时,做特殊处理。 if (anchor_matching_gt > 1).float().sum() > 0: # 首先,找到与花心大萝卜具有最小cost的gt。 # 找到 花心大萝卜 的下标(这是在anchor_matching_gt.shape[N, A]中的下标)。假设有R个花心大萝卜。 indexes = torch.where(anchor_matching_gt > 1) index = torch.stack((indexes[0], indexes[1]), 1) # [R, 2] 每个花心大萝卜2个坐标。第0个坐标表示第几张图片,第1个坐标表示第几个格子。 cost_t = cost.permute(0, 2, 1) # [N, G, A] -> [N, A, G] 转置好提取其cost cost2 = self.gather_nd(cost_t, index) # [R, G] 抽出 R个花心大萝卜 与 gt 两两之间的cost。 cost2 = cost2.permute(1, 0) # [G, R] gt 与 R个花心大萝卜 两两之间的cost。 cost_argmin = cost2.argmin(axis=0) # [R, ] 为 每个花心大萝卜 找到 与其cost最小的gt 的下标 我的代码会在进入这个判断后稳定跑飞,报错位置在求取下标index这一步 indexes = torch.where(anchor_matching_gt > 1) 报错信息为: RuntimeError: numel: integer multiplication overflow 不太清楚这个where计算量在哪里,感觉非常的迷惑,希望能指点一下为什么会出现这个问题
# 如果有预测框(花心大萝卜)匹配到了1个以上的gt时,做特殊处理。 if (anchor_matching_gt > 1).float().sum() > 0: # 首先,找到与花心大萝卜具有最小cost的gt。 # 找到 花心大萝卜 的下标(这是在anchor_matching_gt.shape[N, A]中的下标)。假设有R个花心大萝卜。 indexes = torch.where(anchor_matching_gt > 1) index = torch.stack((indexes[0], indexes[1]), 1) # [R, 2] 每个花心大萝卜2个坐标。第0个坐标表示第几张图片,第1个坐标表示第几个格子。 cost_t = cost.permute(0, 2, 1) # [N, G, A] -> [N, A, G] 转置好提取其cost cost2 = self.gather_nd(cost_t, index) # [R, G] 抽出 R个花心大萝卜 与 gt 两两之间的cost。 cost2 = cost2.permute(1, 0) # [G, R] gt 与 R个花心大萝卜 两两之间的cost。 cost_argmin = cost2.argmin(axis=0) # [R, ] 为 每个花心大萝卜 找到 与其cost最小的gt 的下标
indexes = torch.where(anchor_matching_gt > 1)
RuntimeError: numel: integer multiplication overflow
The text was updated successfully, but these errors were encountered:
No branches or pull requests
在参考作者代码对官方pipeline进行修改之后,终于统一了数据格式的问题,但是后来又遇到了一个问题,与合理的gt数目有关,源代码是这样的:
# 如果有预测框(花心大萝卜)匹配到了1个以上的gt时,做特殊处理。 if (anchor_matching_gt > 1).float().sum() > 0: # 首先,找到与花心大萝卜具有最小cost的gt。 # 找到 花心大萝卜 的下标(这是在anchor_matching_gt.shape[N, A]中的下标)。假设有R个花心大萝卜。 indexes = torch.where(anchor_matching_gt > 1) index = torch.stack((indexes[0], indexes[1]), 1) # [R, 2] 每个花心大萝卜2个坐标。第0个坐标表示第几张图片,第1个坐标表示第几个格子。 cost_t = cost.permute(0, 2, 1) # [N, G, A] -> [N, A, G] 转置好提取其cost cost2 = self.gather_nd(cost_t, index) # [R, G] 抽出 R个花心大萝卜 与 gt 两两之间的cost。 cost2 = cost2.permute(1, 0) # [G, R] gt 与 R个花心大萝卜 两两之间的cost。 cost_argmin = cost2.argmin(axis=0) # [R, ] 为 每个花心大萝卜 找到 与其cost最小的gt 的下标
我的代码会在进入这个判断后稳定跑飞,报错位置在求取下标index这一步
indexes = torch.where(anchor_matching_gt > 1)
报错信息为:
RuntimeError: numel: integer multiplication overflow
不太清楚这个where计算量在哪里,感觉非常的迷惑,希望能指点一下为什么会出现这个问题
The text was updated successfully, but these errors were encountered: