diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index 968830d..fb6dd29 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -19,9 +19,10 @@ import torch import torch.distributed as dist -from fairseq.dataclass.configs import DistributedTrainingConfig, FairseqConfig from omegaconf import open_dict +from fairseq.dataclass.configs import DistributedTrainingConfig, FairseqConfig + try: import torch_xla.core.xla_model as xm except ImportError: diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index dae5da3..7c16cc9 100644 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -9,6 +9,7 @@ RotaryPositionMultiHeadedAttention, ) from .fp32_group_norm import Fp32GroupNorm +from .gelu import gelu, gelu_accurate from .grad_multiply import GradMultiply from .gumbel_vector_quantizer import GumbelVectorQuantizer from .layer_norm import Fp32LayerNorm, LayerNorm @@ -16,8 +17,6 @@ from .positional_encoding import RelPositionalEncoding from .same_pad import SamePad, SamePad2d from .transpose_last import TransposeLast -from .gelu import gelu, gelu_accurate - __all__ = [ "Fp32GroupNorm", diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index 73c7c69..ad186ef 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -3,10 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import defaultdict + import torch + from fairseq import utils from fairseq.dataclass.utils import gen_parser_from_dataclass -from collections import defaultdict class FairseqOptimizer(object):