From eff2c10c2e76c735a70a6b995b571213adffbbb7 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 25 Aug 2021 18:14:04 -0700
Subject: [PATCH] fix parameterlist dataparallel issue

---
 axial_attention/axial_attention.py | 12 ++++++------
 setup.py                           |  2 +-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/axial_attention/axial_attention.py b/axial_attention/axial_attention.py
index a13803d..5eb6b6f 100644
--- a/axial_attention/axial_attention.py
+++ b/axial_attention/axial_attention.py
@@ -104,18 +104,18 @@ def __init__(self, dim, shape, emb_dim_index = 1):
         total_dimensions = len(shape) + 2
         ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]
 
-        for axial_dim, axial_dim_index in zip(shape, ax_dim_indexes):
+        self.num_axials = len(shape)
+
+        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
             shape = [1] * total_dimensions
             shape[emb_dim_index] = dim
             shape[axial_dim_index] = axial_dim
             parameter = nn.Parameter(torch.randn(*shape))
-            parameters.append(parameter)
-
-        self.params = nn.ParameterList(parameters)
+            setattr(self, f'param_{i}', parameter)
 
     def forward(self, x):
-        for param in self.params:
-            x = x + param
+        for i in range(self.num_axials):
+            x = x + getattr(self, f'param_{i}')
         return x
 
 # attention
diff --git a/setup.py b/setup.py
index 72c7b95..d25dc9f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'axial_attention',
   packages = find_packages(),
-  version = '0.6.0',
+  version = '0.6.1',
   license='MIT',
   description = 'Axial Attention',
   author = 'Phil Wang',
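
Note (editorial sketch, not part of the patch): the change works around the known limitation that parameters stored in an nn.ParameterList are not replicated by nn.DataParallel (replicas beyond device 0 see an empty list, and PyTorch emits a warning to that effect). Registering each positional embedding as a plain module attribute via setattr lets DataParallel copy them like any other parameter. A minimal usage sketch follows; the dim/shape values are arbitrary, and at least two CUDA devices are assumed to actually exercise the replication path:

    import torch
    import torch.nn as nn
    from axial_attention.axial_attention import AxialPositionalEmbedding

    # (batch, dim, height, width) feature map, embedding dimension at index 1
    pos_emb = AxialPositionalEmbedding(dim = 64, shape = (32, 32), emb_dim_index = 1)

    if torch.cuda.device_count() > 1:
        # Pre-patch, the ParameterList-backed embeddings were not copied to the
        # replicas that DataParallel creates; post-patch they are replicated normally.
        model = nn.DataParallel(pos_emb.cuda())
    else:
        model = pos_emb  # single-device fallback; the issue does not surface here

    x = torch.randn(2, 64, 32, 32, device = next(model.parameters()).device)
    out = model(x)    # each axial embedding is broadcast-added to x
    print(out.shape)  # torch.Size([2, 64, 32, 32])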