diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index e3f316b..33b6f2e 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -18,16 +18,15 @@
 import numpy as np
 import torch
+from omegaconf import DictConfig, OmegaConf, open_dict
+
 from fairseq.data import data_utils
 from fairseq.dataclass.configs import CheckpointConfig
-from fairseq.dataclass.utils import (
-    convert_namespace_to_omegaconf,
-    overwrite_args_by_name,
-)
+from fairseq.dataclass.utils import (convert_namespace_to_omegaconf,
+                                     overwrite_args_by_name)
 from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
 from fairseq.file_io import PathManager
 from fairseq.models import FairseqDecoder, FairseqEncoder
-from omegaconf import DictConfig, OmegaConf, open_dict
 
 logger = logging.getLogger(__name__)
@@ -100,9 +99,9 @@ def is_better(a, b):
                 cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
             )
         ] = worst_best is None or is_better(val_loss, worst_best)
-    checkpoint_conds[
-        "checkpoint_last{}.pt".format(suffix)
-    ] = not cfg.no_last_checkpoints
+    checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = (
+        not cfg.no_last_checkpoints
+    )
 
     extra_state = {
         "train_iterator": epoch_itr.state_dict(),
@@ -116,7 +115,9 @@ def is_better(a, b):
     # attributes
     if hasattr(trainer.task, "get_checkpoint_dict"):
         extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
-        logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint")
+        logger.info(
+            f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint"
+        )
 
     if hasattr(save_checkpoint, "best"):
         extra_state.update({"best": save_checkpoint.best})
diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py
index bc8c938..d45dd71 100644
--- a/fairseq/data/dictionary.py
+++ b/fairseq/data/dictionary.py
@@ -8,6 +8,7 @@
 from multiprocessing import Pool
 
 import torch
+
 # from fairseq import utils
 # from fairseq.data import data_utils
 # from fairseq.file_chunker_utils import Chunker, find_offsets
@@ -358,7 +359,7 @@ def merge_result(counter):
         chunks = zip(offsets, offsets[1:])
         pool = Pool(processes=num_workers)
         results = []
-        for (start_offset, end_offset) in chunks:
+        for start_offset, end_offset in chunks:
             results.append(
                 pool.apply_async(
                     Dictionary._add_file_to_dictionary_single_worker,
diff --git a/fairseq/data/encoders/__init__.py b/fairseq/data/encoders/__init__.py
index 7cbe00a..074367b 100644
--- a/fairseq/data/encoders/__init__.py
+++ b/fairseq/data/encoders/__init__.py
@@ -9,7 +9,6 @@
 
 from fairseq import registry
 
-
 build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry(
     "--tokenizer",
     default=None,
diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py
index 2bde7fc..6239eca 100644
--- a/fairseq/data/fairseq_dataset.py
+++ b/fairseq/data/fairseq_dataset.py
@@ -4,8 +4,10 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+
 import numpy as np
 import torch.utils.data
+
 from fairseq.data import data_utils
 
 logger = logging.getLogger(__name__)
diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py
index 6a5a42a..a739232 100644
--- a/fairseq/data/iterators.py
+++ b/fairseq/data/iterators.py
@@ -15,8 +15,8 @@
 import numpy as np
 import torch
 
-from fairseq.data import data_utils
+from fairseq.data import data_utils
 
 logger = logging.getLogger(__name__)
diff --git a/fairseq/dataclass/__init__.py b/fairseq/dataclass/__init__.py
index 25408d2..2d241c3 100644
--- a/fairseq/dataclass/__init__.py
+++ b/fairseq/dataclass/__init__.py
@@ -6,7 +6,6 @@
 from .configs import FairseqDataclass
 from .constants import ChoiceEnum
 
-
 __all__ = [
     "FairseqDataclass",
     "ChoiceEnum",
diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py
index ba4d7e5..f1194c3 100644
--- a/fairseq/dataclass/configs.py
+++ b/fairseq/dataclass/configs.py
@@ -11,17 +11,15 @@
 import torch
 from omegaconf import II, MISSING
 
-from fairseq.dataclass.constants import (
-    DATASET_IMPL_CHOICES,
-    DDP_BACKEND_CHOICES,
-    DDP_COMM_HOOK_CHOICES,
-    GENERATION_CONSTRAINTS_CHOICES,
-    GENERATION_DECODING_FORMAT_CHOICES,
-    LOG_FORMAT_CHOICES,
-    PIPELINE_CHECKPOINT_CHOICES,
-    PRINT_ALIGNMENT_CHOICES,
-    ZERO_SHARDING_CHOICES,
-)
+from fairseq.dataclass.constants import (DATASET_IMPL_CHOICES,
+                                         DDP_BACKEND_CHOICES,
+                                         DDP_COMM_HOOK_CHOICES,
+                                         GENERATION_CONSTRAINTS_CHOICES,
+                                         GENERATION_DECODING_FORMAT_CHOICES,
+                                         LOG_FORMAT_CHOICES,
+                                         PIPELINE_CHECKPOINT_CHOICES,
+                                         PRINT_ALIGNMENT_CHOICES,
+                                         ZERO_SHARDING_CHOICES)
 
 
 @dataclass
@@ -108,7 +106,7 @@ class CommonConfig(FairseqDataclass):
             "help": "log progress every N batches (when progress bar is disabled)"
         },
     )
-    log_format: Optional[LOG_FORMAT_CHOICES] = field( # type: ignore
+    log_format: Optional[LOG_FORMAT_CHOICES] = field(  # type: ignore
         default=None, metadata={"help": "log format to use"}
     )
     log_file: Optional[str] = field(
@@ -298,10 +296,10 @@ class DistributedTrainingConfig(FairseqDataclass):
             "help": "do not spawn multiple processes even if multiple GPUs are visible"
         },
     )
-    ddp_backend: DDP_BACKEND_CHOICES = field( # type: ignore
+    ddp_backend: DDP_BACKEND_CHOICES = field(  # type: ignore
         default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"}
     )
-    ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field( # type: ignore
+    ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field(  # type: ignore
         default="none", metadata={"help": "communication hook"}
     )
     bucket_cap_mb: int = field(
@@ -428,11 +426,11 @@ class DistributedTrainingConfig(FairseqDataclass):
             "equal the length of the --pipeline-decoder-balance argument"
         },
     )
-    pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field( # type: ignore
+    pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field(  # type: ignore
         default="never",
         metadata={"help": "checkpointing mode for pipeline model parallelism"},
     )
-    zero_sharding: ZERO_SHARDING_CHOICES = field( # type: ignore
+    zero_sharding: ZERO_SHARDING_CHOICES = field(  # type: ignore
         default="none", metadata={"help": "ZeRO sharding"}
     )
     fp16: bool = II("common.fp16")
@@ -488,7 +486,7 @@ class DatasetConfig(FairseqDataclass):
             "help": "maximum sequence length in batch will be a multiplier of this value"
         },
    )
-    dataset_impl: Optional[DATASET_IMPL_CHOICES] = field( # type: ignore
+    dataset_impl: Optional[DATASET_IMPL_CHOICES] = field(  # type: ignore
         default=None, metadata={"help": "output dataset implementation"}
     )
     data_buffer_size: int = field(
@@ -921,7 +919,7 @@ class GenerationConfig(FairseqDataclass):
             "help": "sample from the smallest set whose cumulative probability mass exceeds p for next words"
         },
     )
-    constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field( # type: ignore
+    constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field(  # type: ignore
         default=None,
         metadata={
             "help": "enables lexically constrained decoding",
@@ -944,7 +942,7 @@ class GenerationConfig(FairseqDataclass):
         default=-1.0,
         metadata={"help": "strength of diversity penalty for Diverse Siblings Search"},
     )
-    print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field( # type: ignore
+    print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field(  # type: ignore
         default=None,
         metadata={
             "help": "if set, uses attention feedback to compute and print alignment to source tokens "
@@ -1012,7 +1010,7 @@ class GenerationConfig(FairseqDataclass):
         },
     )
     # special decoding format for advanced decoding.
-    decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field( # type: ignore
+    decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field(  # type: ignore
         default=None,
         metadata={"help": "special decoding format for advanced decoding."},
     )
diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py
index f6467d5..f59aa06 100644
--- a/fairseq/dataclass/utils.py
+++ b/fairseq/dataclass/utils.py
@@ -13,11 +13,12 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type
 
-from fairseq.dataclass import FairseqDataclass
-from fairseq.dataclass.configs import FairseqConfig
 from hydra.core.global_hydra import GlobalHydra
 from hydra.experimental import compose, initialize
-from omegaconf import DictConfig, OmegaConf, open_dict, _utils
+from omegaconf import DictConfig, OmegaConf, _utils, open_dict
+
+from fairseq.dataclass import FairseqDataclass
+from fairseq.dataclass.configs import FairseqConfig
 
 logger = logging.getLogger(__name__)
@@ -344,7 +345,8 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
     no_dc = True
     if hasattr(args, "arch"):
-        from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_MODEL_NAME_REGISTRY
+        from fairseq.models import (ARCH_MODEL_NAME_REGISTRY,
+                                    ARCH_MODEL_REGISTRY)
 
         if args.arch in ARCH_MODEL_REGISTRY:
             m_cls = ARCH_MODEL_REGISTRY[args.arch]
diff --git a/fairseq/distributed/__init__.py b/fairseq/distributed/__init__.py
index d278ccb..0f2feb7 100644
--- a/fairseq/distributed/__init__.py
+++ b/fairseq/distributed/__init__.py
@@ -3,14 +3,11 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .fully_sharded_data_parallel import (
-    fsdp_enable_wrap,
-    fsdp_wrap,
-    FullyShardedDataParallel,
-)
+from .fully_sharded_data_parallel import (FullyShardedDataParallel,
+                                          fsdp_enable_wrap, fsdp_wrap)
 
 __all__ = [
     "fsdp_enable_wrap",
     "fsdp_wrap",
     "FullyShardedDataParallel",
-]
\ No newline at end of file
+]
diff --git a/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/distributed/fully_sharded_data_parallel.py
index 7656d2e..50eeced 100644
--- a/fairseq/distributed/fully_sharded_data_parallel.py
+++ b/fairseq/distributed/fully_sharded_data_parallel.py
@@ -7,12 +7,13 @@
 from typing import Optional
 
 import torch
+
 from fairseq.dataclass.configs import DistributedTrainingConfig
 from fairseq.distributed import utils as dist_utils
 
-
 try:
-    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP # type: ignore
+    from fairscale.nn.data_parallel import \
+        FullyShardedDataParallel as FSDP  # type: ignore
 
     has_FSDP = True
 except ImportError:
@@ -91,7 +92,7 @@ def size(self) -> int:
 @contextlib.contextmanager
 def fsdp_enable_wrap(cfg: DistributedTrainingConfig):
     try:
-        from fairscale.nn import enable_wrap # type: ignore
+        from fairscale.nn import enable_wrap  # type: ignore
     except ImportError:
         raise ImportError(
             "Cannot find FullyShardedDataParallel. "
@@ -131,7 +132,7 @@ def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
         min_num_params (int, Optional): minimum number of layer params to wrap
     """
     try:
-        from fairscale.nn import wrap # type: ignore
+        from fairscale.nn import wrap  # type: ignore
 
         if min_num_params is not None:
             num_params = sum(p.numel() for p in module.parameters())
diff --git a/fairseq/file_io.py b/fairseq/file_io.py
index 8eca70a..f3f37fe 100644
--- a/fairseq/file_io.py
+++ b/fairseq/file_io.py
@@ -10,7 +10,6 @@
 import shutil
 from typing import List, Optional
 
-
 logger = logging.getLogger(__file__)
diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py
index a7b0b40..965deed 100644
--- a/fairseq/logging/metrics.py
+++ b/fairseq/logging/metrics.py
@@ -18,7 +18,6 @@
 from .meters import *
 
-
 # Aggregation contexts are considered "active" when inside the scope
 # created by the :func:`aggregate` context manager.
 _aggregators = OrderedDict()
@@ -329,7 +328,7 @@ def load_state_dict(state_dict):
 
 def xla_metrics_report():
     try:
-        import torch_xla.debug.metrics as met # type: ignore
+        import torch_xla.debug.metrics as met  # type: ignore
 
         print(met.metrics_report())
     except ImportError:
diff --git a/fairseq/models/fairseq_decoder.py b/fairseq/models/fairseq_decoder.py
index 13b73d6..fc69226 100644
--- a/fairseq/models/fairseq_decoder.py
+++ b/fairseq/models/fairseq_decoder.py
@@ -6,9 +6,10 @@
 from typing import Dict, List, Optional, Tuple
 
 import torch.nn as nn
-from fairseq import utils
 from torch import Tensor
 
+from fairseq import utils
+
 
 class FairseqDecoder(nn.Module):
     """Base class for decoders."""
diff --git a/fairseq/models/fairseq_encoder.py b/fairseq/models/fairseq_encoder.py
index 08cbde1..fd8da62 100644
--- a/fairseq/models/fairseq_encoder.py
+++ b/fairseq/models/fairseq_encoder.py
@@ -9,7 +9,6 @@
 import torch.nn as nn
 from torch import Tensor
 
-
 EncoderOut = NamedTuple(
     "EncoderOut",
     [
diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py
index 6e4ae6d..6b606f4 100644
--- a/fairseq/models/fairseq_model.py
+++ b/fairseq/models/fairseq_model.py
@@ -13,15 +13,14 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from omegaconf import DictConfig
+from torch import Tensor
+
 from fairseq import utils
 from fairseq.data import Dictionary
-from fairseq.dataclass.utils import (
-    convert_namespace_to_omegaconf,
-    gen_parser_from_dataclass,
-)
+from fairseq.dataclass.utils import (convert_namespace_to_omegaconf,
+                                     gen_parser_from_dataclass)
 from fairseq.models import FairseqDecoder, FairseqEncoder
-from omegaconf import DictConfig
-from torch import Tensor
 
 logger = logging.getLogger(__name__)
diff --git a/fairseq/models/hubert/hubert.py b/fairseq/models/hubert/hubert.py
index 8c4b8d0..958628c 100644
--- a/fairseq/models/hubert/hubert.py
+++ b/fairseq/models/hubert/hubert.py
@@ -17,18 +17,14 @@
 from fairseq.data.dictionary import Dictionary
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
 from fairseq.models import BaseFairseqModel, register_model
-from fairseq.models.wav2vec.wav2vec2 import (
-    EXTRACTOR_MODE_CHOICES,
-    MASKING_DISTRIBUTION_CHOICES,
-    LAYER_TYPE_CHOICES,
-    ConvFeatureExtractionModel,
-    TransformerEncoder,
-)
+from fairseq.models.wav2vec.wav2vec2 import (EXTRACTOR_MODE_CHOICES,
+                                             LAYER_TYPE_CHOICES,
+                                             MASKING_DISTRIBUTION_CHOICES,
+                                             ConvFeatureExtractionModel,
+                                             TransformerEncoder)
 from fairseq.modules import GradMultiply, LayerNorm
-from fairseq.tasks.hubert_pretraining import (
-    HubertPretrainingConfig,
-    HubertPretrainingTask,
-)
+from fairseq.tasks.hubert_pretraining import (HubertPretrainingConfig,
+                                              HubertPretrainingTask)
 
 logger = logging.getLogger(__name__)
@@ -37,7 +33,7 @@ class HubertConfig(FairseqDataclass):
     label_rate: float = II("task.label_rate")
 
-    extractor_mode: EXTRACTOR_MODE_CHOICES = field( # type: ignore
+    extractor_mode: EXTRACTOR_MODE_CHOICES = field(  # type: ignore
         default="default",
         metadata={
             "help": "mode for feature extractor. default has a single group "
@@ -57,10 +53,10 @@ class HubertConfig(FairseqDataclass):
     encoder_attention_heads: int = field(
         default=12, metadata={"help": "num encoder attention heads"}
     )
-    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( # type: ignore
+    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(  # type: ignore
         default="gelu", metadata={"help": "activation function to use"}
     )
-    layer_type: LAYER_TYPE_CHOICES = field( # type: ignore
+    layer_type: LAYER_TYPE_CHOICES = field(  # type: ignore
         default="transformer", metadata={"help": "layer type in encoder"}
     )
 
@@ -133,7 +129,7 @@ class HubertConfig(FairseqDataclass):
         default=0.65,
         metadata={"help": "probability of replacing a token with mask"},
     )
-    mask_selection: MASKING_DISTRIBUTION_CHOICES = field( # type: ignore
+    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(  # type: ignore
         default="static", metadata={"help": "how to choose mask length"}
     )
     mask_other: float = field(
@@ -161,7 +157,7 @@ class HubertConfig(FairseqDataclass):
         default=0.0,
         metadata={"help": "probability of replacing a feature with 0"},
     )
-    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( # type: ignore
+    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(  # type: ignore
         default="static",
         metadata={"help": "how to choose mask length for channel masking"},
     )
diff --git a/fairseq/models/wav2vec/utils.py b/fairseq/models/wav2vec/utils.py
index dd52d86..38f6861 100644
--- a/fairseq/models/wav2vec/utils.py
+++ b/fairseq/models/wav2vec/utils.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
+
 import torch.nn.functional as F
diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py
index d940ffe..dad59e5 100644
--- a/fairseq/models/wav2vec/wav2vec2.py
+++ b/fairseq/models/wav2vec/wav2vec2.py
@@ -16,19 +16,13 @@
 from fairseq.data.data_utils import compute_mask_indices
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
 from fairseq.distributed import fsdp_wrap
+from fairseq.distributed.fully_sharded_data_parallel import \
+    FullyShardedDataParallel
 from fairseq.models import BaseFairseqModel, register_model
-from fairseq.distributed.fully_sharded_data_parallel import FullyShardedDataParallel
-from fairseq.modules import (
-    Fp32GroupNorm,
-    Fp32LayerNorm,
-    GradMultiply,
-    GumbelVectorQuantizer,
-    LayerNorm,
-    MultiheadAttention,
-    RelPositionalEncoding,
-    SamePad,
-    TransposeLast,
-)
+from fairseq.modules import (Fp32GroupNorm, Fp32LayerNorm, GradMultiply,
+                             GumbelVectorQuantizer, LayerNorm,
+                             MultiheadAttention, RelPositionalEncoding,
+                             SamePad, TransposeLast)
 from fairseq.modules.checkpoint_activations import checkpoint_wrapper
 from fairseq.modules.conformer_layer import ConformerWav2Vec2EncoderLayer
 from fairseq.modules.transformer_sentence_encoder import init_bert_params
@@ -43,7 +37,7 @@
 
 @dataclass
 class Wav2Vec2Config(FairseqDataclass):
-    extractor_mode: EXTRACTOR_MODE_CHOICES = field( # type: ignore
+    extractor_mode: EXTRACTOR_MODE_CHOICES = field(  # type: ignore
         default="default",
         metadata={
             "help": "mode for feature extractor. default has a single group norm with d "
@@ -63,10 +57,10 @@ class Wav2Vec2Config(FairseqDataclass):
     encoder_attention_heads: int = field(
         default=12, metadata={"help": "num encoder attention heads"}
     )
-    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( # type: ignore
+    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(  # type: ignore
         default="gelu", metadata={"help": "activation function to use"}
     )
-    layer_type: LAYER_TYPE_CHOICES = field( # type: ignore
+    layer_type: LAYER_TYPE_CHOICES = field(  # type: ignore
         default="transformer", metadata={"help": "layer type in encoder"}
     )
 
     # dropouts
@@ -160,7 +154,7 @@ class Wav2Vec2Config(FairseqDataclass):
     mask_prob: float = field(
         default=0.65, metadata={"help": "probability of replacing a token with mask"}
     )
-    mask_selection: MASKING_DISTRIBUTION_CHOICES = field( # type: ignore
+    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(  # type: ignore
         default="static", metadata={"help": "how to choose mask length"}
     )
     mask_other: float = field(
@@ -197,7 +191,7 @@ class Wav2Vec2Config(FairseqDataclass):
         default=0.0, metadata={"help": "probability of replacing a feature with 0"}
     )
     mask_channel_before: bool = False
-    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( # type: ignore
+    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(  # type: ignore
         default="static",
         metadata={"help": "how to choose mask length for channel masking"},
     )
@@ -291,15 +285,9 @@ class Wav2Vec2Config(FairseqDataclass):
     fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"})
 
     # Adapter num
-    adp_num: int = field(
-        default=-1
-    )
-    adp_dim: int = field(
-        default=64
-    )
-    adp_act_fn: str = field(
-        default="relu"
-    )
+    adp_num: int = field(default=-1)
+    adp_dim: int = field(default=64)
+    adp_act_fn: str = field(default="relu")
     adp_trf_idx: str = field(
         default="all",
     )
@@ -975,7 +963,9 @@ def build_encoder_layer(self, args: Wav2Vec2Config, **kwargs):
         if args.adp_trf_idx == "all":
             use_adp = True
         else:
-            adp_trf_idx = list(range(*[int(g) for g in args.adp_trf_idx.split(":")]))
+            adp_trf_idx = list(
+                range(*[int(g) for g in args.adp_trf_idx.split(":")])
+            )
             if kwargs.get("layer_idx", None) in adp_trf_idx:
                 use_adp = True
         if use_adp:
@@ -1009,7 +999,12 @@ def build_encoder_layer(self, args: Wav2Vec2Config, **kwargs):
             layer = checkpoint_wrapper(layer)
         return layer
 
-    def __init__(self, args: Wav2Vec2Config, skip_pos_conv: bool = False, override_encoder_layer: int = None):
+    def __init__(
+        self,
+        args: Wav2Vec2Config,
+        skip_pos_conv: bool = False,
+        override_encoder_layer: int = None,
+    ):
         super().__init__()
 
         self.dropout = args.dropout
@@ -1052,9 +1047,11 @@ def make_conv_block(e, k, g, l):
             self.embedding_dim,
             args.conv_pos,
             args.conv_pos_groups,
-            is_batch_norm=args.conv_pos_batch_norm
-            if hasattr(args, "conv_pos_batch_norm")
-            else False,
+            is_batch_norm=(
+                args.conv_pos_batch_norm
+                if hasattr(args, "conv_pos_batch_norm")
+                else False
+            ),
         )
 
         if override_encoder_layer is None:
@@ -1063,7 +1060,10 @@ def make_conv_block(e, k, g, l):
             encoder_layers = override_encoder_layer
 
         self.layers = nn.ModuleList(
-            [self.build_encoder_layer(args, layer_idx=ii) for ii in range(encoder_layers)]
+            [
+                self.build_encoder_layer(args, layer_idx=ii)
+                for ii in range(encoder_layers)
+            ]
         )
         self.layer_norm_first = args.layer_norm_first
         self.layer_norm = LayerNorm(self.embedding_dim)
@@ -1127,9 +1127,8 @@ def extract_features(
                 if isinstance(layer, FullyShardedDataParallel):
                     layer_check = layer.unwrapped_module
                 if (corpus_key is None) or (
-                    not isinstance(layer_check, (
-                        TransformerSentenceEncoderWithAdapterLayer,
-                    )
+                    not isinstance(
+                        layer_check, (TransformerSentenceEncoderWithAdapterLayer,)
                     )
                 ):
                     x, (z, lr) = layer(
@@ -1405,7 +1404,6 @@ def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
         else:
             raise ValueError(f"unsupported {act_fn}")
 
-
         self.input_dim = input_dim
         self.reset_parameters()
@@ -1426,7 +1424,7 @@ def reset_parameters(self):
     def forward(self, x, adapter_id):
         ii = adapter_id
         h = x
-        h = F.layer_norm(h, (self.input_dim, ), self.ln_W[ii], self.ln_b[ii])
+        h = F.layer_norm(h, (self.input_dim,), self.ln_W[ii], self.ln_b[ii])
         h = F.linear(h, self.W_a[ii], self.b_a[ii])
         h = self.act_fn(h)
         h = F.linear(h, self.W_b[ii], self.b_b[ii])
@@ -1434,8 +1432,9 @@ def forward(self, x, adapter_id):
         return outputs
 
     def extra_repr(self):
-        return ('adapter={}, input_dim={}, hidden_dim={}'.format(self.adapter_num, self.input_dim, self.hidden_dim))
-
+        return "adapter={}, input_dim={}, hidden_dim={}".format(
+            self.adapter_num, self.input_dim, self.hidden_dim
+        )
 
 class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
@@ -1468,12 +1467,13 @@ def __init__(
             activation_dropout=activation_dropout,
             activation_fn=activation_fn,
             layer_norm_first=layer_norm_first,
-
         )
 
         self.adapter_num = adapter_num
         self.adapter_dim = adapter_dim
-        self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn)
+        self.adapter_layer = AdapterFast(
+            adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn
+        )
 
     def forward(
         self,
diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py
index 4679bd7..0dd7d3f 100644
--- a/fairseq/modules/__init__.py
+++ b/fairseq/modules/__init__.py
@@ -3,20 +3,16 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from .espnet_multihead_attention import (ESPNETMultiHeadedAttention,
+                                         RelPositionMultiHeadedAttention,
+                                         RotaryPositionMultiHeadedAttention)
 from .fp32_group_norm import Fp32GroupNorm
 from .grad_multiply import GradMultiply
 from .gumbel_vector_quantizer import GumbelVectorQuantizer
 from .layer_norm import Fp32LayerNorm, LayerNorm
 from .multihead_attention import MultiheadAttention
+from .positional_encoding import RelPositionalEncoding
 from .same_pad import SamePad, SamePad2d
-from .positional_encoding import (
-    RelPositionalEncoding,
-)
-from .espnet_multihead_attention import (
-    ESPNETMultiHeadedAttention,
-    RelPositionMultiHeadedAttention,
-    RotaryPositionMultiHeadedAttention,
-)
 from .transpose_last import TransposeLast
 
 __all__ = [
diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py
index aa0b592..0717271 100644
--- a/fairseq/modules/checkpoint_activations.py
+++ b/fairseq/modules/checkpoint_activations.py
@@ -8,6 +8,7 @@
 import torch
 import torch.utils.checkpoint as checkpoint
+
 from fairseq import utils
diff --git a/fairseq/modules/conformer_layer.py b/fairseq/modules/conformer_layer.py
index 964af24..9accbc7 100644
--- a/fairseq/modules/conformer_layer.py
+++ b/fairseq/modules/conformer_layer.py
@@ -8,13 +8,10 @@
 
 import torch
 
-from fairseq.modules import (
-    ESPNETMultiHeadedAttention,
-    LayerNorm,
-    MultiheadAttention,
-    RelPositionMultiHeadedAttention,
-    RotaryPositionMultiHeadedAttention,
-)
+from fairseq.modules import (ESPNETMultiHeadedAttention, LayerNorm,
+                             MultiheadAttention,
+                             RelPositionMultiHeadedAttention,
+                             RotaryPositionMultiHeadedAttention)
 from fairseq.utils import get_activation_fn
diff --git a/fairseq/modules/espnet_multihead_attention.py b/fairseq/modules/espnet_multihead_attention.py
index 82bc0d7..20f13c3 100644
--- a/fairseq/modules/espnet_multihead_attention.py
+++ b/fairseq/modules/espnet_multihead_attention.py
@@ -12,9 +12,7 @@
 from torch import nn
 
 from fairseq.modules.rotary_positional_embedding import (
-    RotaryPositionalEmbedding,
-    apply_rotary_pos_emb,
-)
+    RotaryPositionalEmbedding, apply_rotary_pos_emb)
 
 
 class ESPNETMultiHeadedAttention(nn.Module):
diff --git a/fairseq/modules/fairseq_dropout.py b/fairseq/modules/fairseq_dropout.py
index 3cddca7..ebce274 100644
--- a/fairseq/modules/fairseq_dropout.py
+++ b/fairseq/modules/fairseq_dropout.py
@@ -9,7 +9,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-
 logger = logging.getLogger(__name__)
@@ -31,7 +30,7 @@ def make_generation_fast_(
         name: str,
         retain_dropout: bool = False,
         retain_dropout_modules: Optional[List[str]] = None,
-        **kwargs
+        **kwargs,
    ):
         if retain_dropout:
             if retain_dropout_modules is not None and self.module_name is None:
diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py
index 262132d..c1bf690 100644
--- a/fairseq/modules/multihead_attention.py
+++ b/fairseq/modules/multihead_attention.py
@@ -20,9 +20,10 @@
     _xformers_available = False
 
 from fairseq import utils
+from fairseq.models.fairseq_incremental_decoder import \
+    FairseqIncrementalDecoder
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from fairseq.modules.quant_noise import quant_noise
-from fairseq.models.fairseq_incremental_decoder import FairseqIncrementalDecoder
 
 # TODO: move this into xformers?
@@ -194,33 +195,15 @@ def _get_reserve_head_index(self, num_heads_to_keep: int):
             start_idx = i * self.head_dim
             end_idx = (i + 1) * self.head_dim
             k_proj_heads_norm.append(
-                torch.sum(
-                    torch.abs(
-                        self.k_proj.weight[
-                            start_idx:end_idx,
-                        ]
-                    )
-                ).tolist()
+                torch.sum(torch.abs(self.k_proj.weight[start_idx:end_idx,])).tolist()
                 + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
             )
             q_proj_heads_norm.append(
-                torch.sum(
-                    torch.abs(
-                        self.q_proj.weight[
-                            start_idx:end_idx,
-                        ]
-                    )
-                ).tolist()
+                torch.sum(torch.abs(self.q_proj.weight[start_idx:end_idx,])).tolist()
                 + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
             )
             v_proj_heads_norm.append(
-                torch.sum(
-                    torch.abs(
-                        self.v_proj.weight[
-                            start_idx:end_idx,
-                        ]
-                    )
-                ).tolist()
+                torch.sum(torch.abs(self.v_proj.weight[start_idx:end_idx,])).tolist()
                 + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
             )
@@ -251,26 +234,14 @@ def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]):
         for ele in reserve_head_index:
             start_idx, end_idx = ele
-            new_q_weight.append(
-                self.q_proj.weight[
-                    start_idx:end_idx,
-                ]
-            )
+            new_q_weight.append(self.q_proj.weight[start_idx:end_idx,])
             new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
 
-            new_k_weight.append(
-                self.k_proj.weight[
-                    start_idx:end_idx,
-                ]
-            )
+            new_k_weight.append(self.k_proj.weight[start_idx:end_idx,])
             new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
 
-            new_v_weight.append(
-                self.v_proj.weight[
-                    start_idx:end_idx,
-                ]
-            )
+            new_v_weight.append(self.v_proj.weight[start_idx:end_idx,])
             new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
 
             new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
diff --git a/fairseq/modules/positional_encoding.py b/fairseq/modules/positional_encoding.py
index 67f6353..cf6aba3 100644
--- a/fairseq/modules/positional_encoding.py
+++ b/fairseq/modules/positional_encoding.py
@@ -3,9 +3,10 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import torch.nn as nn
 import math
+
 import torch
+import torch.nn as nn
 
 
 class PositionalEncoding(nn.Module):
diff --git a/fairseq/modules/rotary_positional_embedding.py b/fairseq/modules/rotary_positional_embedding.py
index b74028b..61fe3b3 100644
--- a/fairseq/modules/rotary_positional_embedding.py
+++ b/fairseq/modules/rotary_positional_embedding.py
@@ -34,6 +34,7 @@ def forward(self, x, seq_len: int = 0):
             self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))
         return self.cos_cached, self.sin_cached
 
+
 # rotary pos emb helpers:
 def rotate_half(x):
     x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py
index 018e1d2..92c3829 100644
--- a/fairseq/modules/transformer_sentence_encoder.py
+++ b/fairseq/modules/transformer_sentence_encoder.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch.nn as nn
+
 from fairseq.modules import MultiheadAttention
 from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
diff --git a/fairseq/optim/amp_optimizer.py b/fairseq/optim/amp_optimizer.py
index cfe57d0..94ef719 100644
--- a/fairseq/optim/amp_optimizer.py
+++ b/fairseq/optim/amp_optimizer.py
@@ -6,9 +6,10 @@
 import logging
 
 import torch
-from fairseq import optim
 from omegaconf import DictConfig
 
+from fairseq import optim
+
 logger = logging.getLogger(__name__)
diff --git a/fairseq/registry.py b/fairseq/registry.py
index 904ffcd..528b426 100644
--- a/fairseq/registry.py
+++ b/fairseq/registry.py
@@ -4,13 +4,14 @@
 # LICENSE file in the root directory of this source tree.
 
 from argparse import Namespace
-
 from typing import Union
 
-from fairseq.dataclass import FairseqDataclass
-from fairseq.dataclass.utils import merge_with_parent
+
 from hydra.core.config_store import ConfigStore
 from omegaconf import DictConfig
 
+from fairseq.dataclass import FairseqDataclass
+from fairseq.dataclass.utils import merge_with_parent
+
 REGISTRIES = {}
diff --git a/fairseq/search.py b/fairseq/search.py
index c7378bb..7a63306 100644
--- a/fairseq/search.py
+++ b/fairseq/search.py
@@ -4,18 +4,16 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
-
 from typing import List, Optional
 
 import torch
 import torch.nn as nn
-from fairseq.token_generation_constraints import (
-    ConstraintState,
-    OrderedConstraintState,
-    UnorderedConstraintState,
-)
 from torch import Tensor
 
+from fairseq.token_generation_constraints import (ConstraintState,
+                                                  OrderedConstraintState,
+                                                  UnorderedConstraintState)
+
 
 class Search(nn.Module):
     def __init__(self, tgt_dict):
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index e39d1d6..7322c9e 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -10,14 +10,15 @@
 from typing import Any, Callable, Dict, List
 
 import torch
+from omegaconf import DictConfig
+
 from fairseq import search, tokenizer, utils
-from fairseq.logging import metrics
-from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators
+from fairseq.data import (Dictionary, FairseqDataset, data_utils, encoders,
+                          iterators)
 from fairseq.dataclass import FairseqDataclass
 from fairseq.dataclass.utils import gen_parser_from_dataclass
+from fairseq.logging import metrics
 from fairseq.optim.amp_optimizer import AMPOptimizer
-from omegaconf import DictConfig
-
 
 logger = logging.getLogger(__name__)
@@ -411,10 +412,8 @@ def build_generator(
             compute_alignment=getattr(args, "print_alignment", False),
         )
 
-        from fairseq.sequence_generator import (
-            SequenceGenerator,
-            SequenceGeneratorWithAlignment,
-        )
+        from fairseq.sequence_generator import (SequenceGenerator,
+                                                SequenceGeneratorWithAlignment)
 
         # Choose search strategy. Defaults to Beam Search.
         sampling = getattr(args, "sampling", False)
diff --git a/fairseq/tasks/hubert_pretraining.py b/fairseq/tasks/hubert_pretraining.py
index d7e58d3..0317b12 100644
--- a/fairseq/tasks/hubert_pretraining.py
+++ b/fairseq/tasks/hubert_pretraining.py
@@ -8,16 +8,16 @@
 import logging
 import os
 import sys
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple
 
 import numpy as np
+from omegaconf import MISSING
 
-from dataclasses import dataclass, field
 from fairseq.data import Dictionary
 from fairseq.dataclass.configs import FairseqDataclass
 from fairseq.tasks import register_task
 from fairseq.tasks.fairseq_task import FairseqTask
-from omegaconf import MISSING
 
 logger = logging.getLogger(__name__)
diff --git a/fairseq/tokenizer.py b/fairseq/tokenizer.py
index 42131f7..8c4d694 100644
--- a/fairseq/tokenizer.py
+++ b/fairseq/tokenizer.py
@@ -5,7 +5,6 @@
 
 import re
 
-
 SPACE_NORMALIZER = re.compile(r"\s+")
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 72bd35f..88d37e0 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -23,14 +23,14 @@
 from fairseq.modules.multihead_attention import MultiheadAttention
 
 try:
-    from amp_C import multi_tensor_l2norm # type: ignore
+    from amp_C import multi_tensor_l2norm  # type: ignore
 
     multi_tensor_l2norm_available = True
 except ImportError:
     multi_tensor_l2norm_available = False
 
 try:
-    import torch_xla.core.xla_model as xm # type: ignore
+    import torch_xla.core.xla_model as xm  # type: ignore
 except ImportError:
     xm = None
@@ -128,7 +128,7 @@ def _move_to_cpu(tensor):
 
 def move_to_tpu(sample):
-    import torch_xla.core.xla_model as xm # type: ignore
+    import torch_xla.core.xla_model as xm  # type: ignore
 
     device = xm.xla_device()
@@ -714,8 +714,8 @@ def get_tpu_device():
 
 def tpu_data_loader(itr):
-    import torch_xla.core.xla_model as xm # type: ignore
-    import torch_xla.distributed.parallel_loader as pl # type: ignore
+    import torch_xla.core.xla_model as xm  # type: ignore
+    import torch_xla.distributed.parallel_loader as pl  # type: ignore
 
     from fairseq.data import iterators
@@ -746,7 +746,7 @@ def index_put(tensor, indices, value):
 
 def xla_device_to_cpu(dat):
-    import torch_xla.core.xla_model as xm # type: ignore
+    import torch_xla.core.xla_model as xm  # type: ignore
 
     return xm._maybe_convert_to_cpu(dat)
@@ -890,13 +890,14 @@ def train_step(self, sample ....):
         * Need to launch train.py locally (cannot submit jobs)
     """
     try:
-        import jurigged # type: ignore
+        import jurigged  # type: ignore
     except ImportError as e:
         logger.warning("Please install jurigged: pip install jurigged[develoop]")
         raise e
-    from fairseq.distributed import utils as distributed_utils
     import traceback
 
+    from fairseq.distributed import utils as distributed_utils
+
     def hotreload_decorator(func):
         assert callable(func), f"not callable: {func}"
         jname = name or func.__name__
diff --git a/fairseq/version.txt b/fairseq/version.txt
index a9066ea..8201cd9 100644
--- a/fairseq/version.txt
+++ b/fairseq/version.txt
@@ -1 +1 @@
-2024.07.09-espnet
\ No newline at end of file
+2024.07.09-espnet
diff --git a/setup.py b/setup.py
index 056ba45..b908cbe 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,3 @@
 import setuptools
 
-
 setuptools.setup()