From 3b50fa30e883f68dde6b2e94623a56a065f09e81 Mon Sep 17 00:00:00 2001
From: Fhrozen
Date: Wed, 10 Jul 2024 07:06:51 +0900
Subject: [PATCH] fixing pre commit

---
 fairseq/checkpoint_utils.py                   |  6 +-
 fairseq/data/dictionary.py                    | 10 +--
 fairseq/dataclass/configs.py                  | 20 +++--
 fairseq/dataclass/utils.py                    |  3 +-
 fairseq/distributed/__init__.py               |  7 +-
 .../fully_sharded_data_parallel.py            |  5 +-
 fairseq/file_chunker_utils.py                 | 84 +++++++++++++++++++
 fairseq/models/fairseq_model.py               |  6 +-
 fairseq/models/hubert/hubert.py               | 18 ++--
 fairseq/models/wav2vec/wav2vec2.py            | 18 ++--
 fairseq/modules/__init__.py                   |  8 +-
 fairseq/modules/conformer_layer.py            | 11 ++-
 fairseq/modules/espnet_multihead_attention.py |  4 +-
 fairseq/modules/multihead_attention.py        |  3 +-
 fairseq/search.py                             |  8 +-
 fairseq/tasks/fairseq_task.py                 |  9 +-
 setup.cfg                                     |  8 ++
 17 files changed, 174 insertions(+), 54 deletions(-)
 create mode 100644 fairseq/file_chunker_utils.py
 create mode 100644 setup.cfg

diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index 33b6f2e..52e501a 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -22,8 +22,10 @@
 
 from fairseq.data import data_utils
 from fairseq.dataclass.configs import CheckpointConfig
-from fairseq.dataclass.utils import (convert_namespace_to_omegaconf,
-                                     overwrite_args_by_name)
+from fairseq.dataclass.utils import (
+    convert_namespace_to_omegaconf,
+    overwrite_args_by_name,
+)
 from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
 from fairseq.file_io import PathManager
 from fairseq.models import FairseqDecoder, FairseqEncoder
diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py
index d45dd71..bc585cf 100644
--- a/fairseq/data/dictionary.py
+++ b/fairseq/data/dictionary.py
@@ -9,11 +9,11 @@
 
 import torch
 
-# from fairseq import utils
-# from fairseq.data import data_utils
-# from fairseq.file_chunker_utils import Chunker, find_offsets
-# from fairseq.file_io import PathManager
-# from fairseq.tokenizer import tokenize_line
+from fairseq import utils
+from fairseq.data import data_utils
+from fairseq.file_chunker_utils import Chunker, find_offsets
+from fairseq.file_io import PathManager
+from fairseq.tokenizer import tokenize_line
 
 
 class Dictionary:
diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py
index f1194c3..3f3ea69 100644
--- a/fairseq/dataclass/configs.py
+++ b/fairseq/dataclass/configs.py
@@ -11,15 +11,17 @@
 import torch
 from omegaconf import II, MISSING
 
-from fairseq.dataclass.constants import (DATASET_IMPL_CHOICES,
-                                         DDP_BACKEND_CHOICES,
-                                         DDP_COMM_HOOK_CHOICES,
-                                         GENERATION_CONSTRAINTS_CHOICES,
-                                         GENERATION_DECODING_FORMAT_CHOICES,
-                                         LOG_FORMAT_CHOICES,
-                                         PIPELINE_CHECKPOINT_CHOICES,
-                                         PRINT_ALIGNMENT_CHOICES,
-                                         ZERO_SHARDING_CHOICES)
+from fairseq.dataclass.constants import (
+    DATASET_IMPL_CHOICES,
+    DDP_BACKEND_CHOICES,
+    DDP_COMM_HOOK_CHOICES,
+    GENERATION_CONSTRAINTS_CHOICES,
+    GENERATION_DECODING_FORMAT_CHOICES,
+    LOG_FORMAT_CHOICES,
+    PIPELINE_CHECKPOINT_CHOICES,
+    PRINT_ALIGNMENT_CHOICES,
+    ZERO_SHARDING_CHOICES,
+)
 
 
 @dataclass
diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py
index f59aa06..13a5561 100644
--- a/fairseq/dataclass/utils.py
+++ b/fairseq/dataclass/utils.py
@@ -345,8 +345,7 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
 
         no_dc = True
         if hasattr(args, "arch"):
-            from fairseq.models import (ARCH_MODEL_NAME_REGISTRY,
-                                        ARCH_MODEL_REGISTRY)
+            from fairseq.models import ARCH_MODEL_NAME_REGISTRY, ARCH_MODEL_REGISTRY
 
             if args.arch in ARCH_MODEL_REGISTRY:
                 m_cls = ARCH_MODEL_REGISTRY[args.arch]
diff --git a/fairseq/distributed/__init__.py b/fairseq/distributed/__init__.py
index 0f2feb7..743cc04 100644
--- a/fairseq/distributed/__init__.py
+++ b/fairseq/distributed/__init__.py
@@ -3,8 +3,11 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .fully_sharded_data_parallel import (FullyShardedDataParallel,
-                                          fsdp_enable_wrap, fsdp_wrap)
+from .fully_sharded_data_parallel import (
+    FullyShardedDataParallel,
+    fsdp_enable_wrap,
+    fsdp_wrap,
+)
 
 __all__ = [
     "fsdp_enable_wrap",
diff --git a/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/distributed/fully_sharded_data_parallel.py
index 50eeced..d939654 100644
--- a/fairseq/distributed/fully_sharded_data_parallel.py
+++ b/fairseq/distributed/fully_sharded_data_parallel.py
@@ -12,8 +12,9 @@
 from fairseq.distributed import utils as dist_utils
 
 try:
-    from fairscale.nn.data_parallel import \
-        FullyShardedDataParallel as FSDP  # type: ignore
+    from fairscale.nn.data_parallel import (
+        FullyShardedDataParallel as FSDP,  # type: ignore
+    )
 
     has_FSDP = True
 except ImportError:
diff --git a/fairseq/file_chunker_utils.py b/fairseq/file_chunker_utils.py
new file mode 100644
index 0000000..3f27549
--- /dev/null
+++ b/fairseq/file_chunker_utils.py
@@ -0,0 +1,84 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import typing as tp
+
+
+def _safe_readline(fd) -> str:
+    pos = fd.tell()
+    while True:
+        try:
+            return fd.readline()
+        except UnicodeDecodeError:
+            pos -= 1
+            fd.seek(pos)  # search where this character begins
+
+
+def find_offsets(filename: str, num_chunks: int) -> tp.List[int]:
+    """
+    given a file and a number of chunks, find the offsets in the file
+    to be able to chunk around full lines.
+    """
+    with open(filename, "r", encoding="utf-8") as f:
+        size = os.fstat(f.fileno()).st_size
+        chunk_size = size // num_chunks
+        offsets = [0 for _ in range(num_chunks + 1)]
+        for i in range(1, num_chunks):
+            f.seek(chunk_size * i)
+            _safe_readline(f)
+            offsets[i] = f.tell()
+        offsets[-1] = size
+        return offsets
+
+
+class ChunkLineIterator:
+    """
+    Iterator to properly iterate over lines of a file chunk.
+    """
+
+    def __init__(self, fd, start_offset: int, end_offset: int):
+        self._fd = fd
+        self._start_offset = start_offset
+        self._end_offset = end_offset
+
+    def __iter__(self) -> tp.Iterable[str]:
+        self._fd.seek(self._start_offset)
+        # next(f) breaks f.tell(), hence readline() must be used
+        line = _safe_readline(self._fd)
+        while line:
+            pos = self._fd.tell()
+            # f.tell() does not always give the byte position in the file
+            # sometimes it skips to a very large number
+            # it is unlikely that through a normal read we go from
+            # end bytes to end + 2**32 bytes (4 GB) and this makes it unlikely
+            # that the procedure breaks by the nondeterministic behavior of
+            # f.tell()
+            if (
+                self._end_offset > 0
+                and pos > self._end_offset
+                and pos < self._end_offset + 2**32
+            ):
+                break
+            yield line
+            line = self._fd.readline()
+
+
+class Chunker:
+    """
+    contextmanager to read a chunk of a file line by line.
+ """ + + def __init__(self, path: str, start_offset: int, end_offset: int): + self.path = path + self.start_offset = start_offset + self.end_offset = end_offset + + def __enter__(self) -> ChunkLineIterator: + self.fd = open(self.path, "r", encoding="utf-8") + return ChunkLineIterator(self.fd, self.start_offset, self.end_offset) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.fd.close() diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 6b606f4..cef7b40 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -18,8 +18,10 @@ from fairseq import utils from fairseq.data import Dictionary -from fairseq.dataclass.utils import (convert_namespace_to_omegaconf, - gen_parser_from_dataclass) +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + gen_parser_from_dataclass, +) from fairseq.models import FairseqDecoder, FairseqEncoder logger = logging.getLogger(__name__) diff --git a/fairseq/models/hubert/hubert.py b/fairseq/models/hubert/hubert.py index 958628c..d7910e0 100644 --- a/fairseq/models/hubert/hubert.py +++ b/fairseq/models/hubert/hubert.py @@ -17,14 +17,18 @@ from fairseq.data.dictionary import Dictionary from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.models import BaseFairseqModel, register_model -from fairseq.models.wav2vec.wav2vec2 import (EXTRACTOR_MODE_CHOICES, - LAYER_TYPE_CHOICES, - MASKING_DISTRIBUTION_CHOICES, - ConvFeatureExtractionModel, - TransformerEncoder) +from fairseq.models.wav2vec.wav2vec2 import ( + EXTRACTOR_MODE_CHOICES, + LAYER_TYPE_CHOICES, + MASKING_DISTRIBUTION_CHOICES, + ConvFeatureExtractionModel, + TransformerEncoder, +) from fairseq.modules import GradMultiply, LayerNorm -from fairseq.tasks.hubert_pretraining import (HubertPretrainingConfig, - HubertPretrainingTask) +from fairseq.tasks.hubert_pretraining import ( + HubertPretrainingConfig, + HubertPretrainingTask, +) logger = logging.getLogger(__name__) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index dad59e5..2327a8d 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -16,13 +16,19 @@ from fairseq.data.data_utils import compute_mask_indices from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.distributed import fsdp_wrap -from fairseq.distributed.fully_sharded_data_parallel import \ - FullyShardedDataParallel +from fairseq.distributed.fully_sharded_data_parallel import FullyShardedDataParallel from fairseq.models import BaseFairseqModel, register_model -from fairseq.modules import (Fp32GroupNorm, Fp32LayerNorm, GradMultiply, - GumbelVectorQuantizer, LayerNorm, - MultiheadAttention, RelPositionalEncoding, - SamePad, TransposeLast) +from fairseq.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GradMultiply, + GumbelVectorQuantizer, + LayerNorm, + MultiheadAttention, + RelPositionalEncoding, + SamePad, + TransposeLast, +) from fairseq.modules.checkpoint_activations import checkpoint_wrapper from fairseq.modules.conformer_layer import ConformerWav2Vec2EncoderLayer from fairseq.modules.transformer_sentence_encoder import init_bert_params diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py index 0dd7d3f..fbec610 100644 --- a/fairseq/modules/__init__.py +++ b/fairseq/modules/__init__.py @@ -3,9 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
 
-from .espnet_multihead_attention import (ESPNETMultiHeadedAttention,
-                                         RelPositionMultiHeadedAttention,
-                                         RotaryPositionMultiHeadedAttention)
+from .espnet_multihead_attention import (
+    ESPNETMultiHeadedAttention,
+    RelPositionMultiHeadedAttention,
+    RotaryPositionMultiHeadedAttention,
+)
 from .fp32_group_norm import Fp32GroupNorm
 from .grad_multiply import GradMultiply
 from .gumbel_vector_quantizer import GumbelVectorQuantizer
diff --git a/fairseq/modules/conformer_layer.py b/fairseq/modules/conformer_layer.py
index 9accbc7..964af24 100644
--- a/fairseq/modules/conformer_layer.py
+++ b/fairseq/modules/conformer_layer.py
@@ -8,10 +8,13 @@
 
 import torch
 
-from fairseq.modules import (ESPNETMultiHeadedAttention, LayerNorm,
-                             MultiheadAttention,
-                             RelPositionMultiHeadedAttention,
-                             RotaryPositionMultiHeadedAttention)
+from fairseq.modules import (
+    ESPNETMultiHeadedAttention,
+    LayerNorm,
+    MultiheadAttention,
+    RelPositionMultiHeadedAttention,
+    RotaryPositionMultiHeadedAttention,
+)
 from fairseq.utils import get_activation_fn
 
 
diff --git a/fairseq/modules/espnet_multihead_attention.py b/fairseq/modules/espnet_multihead_attention.py
index 20f13c3..82bc0d7 100644
--- a/fairseq/modules/espnet_multihead_attention.py
+++ b/fairseq/modules/espnet_multihead_attention.py
@@ -12,7 +12,9 @@
 from torch import nn
 
 from fairseq.modules.rotary_positional_embedding import (
-    RotaryPositionalEmbedding, apply_rotary_pos_emb)
+    RotaryPositionalEmbedding,
+    apply_rotary_pos_emb,
+)
 
 
 class ESPNETMultiHeadedAttention(nn.Module):
diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py
index c1bf690..0188095 100644
--- a/fairseq/modules/multihead_attention.py
+++ b/fairseq/modules/multihead_attention.py
@@ -20,8 +20,7 @@
     _xformers_available = False
 
 from fairseq import utils
-from fairseq.models.fairseq_incremental_decoder import \
-    FairseqIncrementalDecoder
+from fairseq.models.fairseq_incremental_decoder import FairseqIncrementalDecoder
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from fairseq.modules.quant_noise import quant_noise
 
diff --git a/fairseq/search.py b/fairseq/search.py
index 7a63306..25161d6 100644
--- a/fairseq/search.py
+++ b/fairseq/search.py
@@ -10,9 +10,11 @@
 import torch.nn as nn
 from torch import Tensor
 
-from fairseq.token_generation_constraints import (ConstraintState,
-                                                  OrderedConstraintState,
-                                                  UnorderedConstraintState)
+from fairseq.token_generation_constraints import (
+    ConstraintState,
+    OrderedConstraintState,
+    UnorderedConstraintState,
+)
 
 
 class Search(nn.Module):
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 7322c9e..24d7778 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -13,8 +13,7 @@
 from omegaconf import DictConfig
 
 from fairseq import search, tokenizer, utils
-from fairseq.data import (Dictionary, FairseqDataset, data_utils, encoders,
-                          iterators)
+from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators
 from fairseq.dataclass import FairseqDataclass
 from fairseq.dataclass.utils import gen_parser_from_dataclass
 from fairseq.logging import metrics
@@ -412,8 +411,10 @@ def build_generator(
                 compute_alignment=getattr(args, "print_alignment", False),
             )
 
-        from fairseq.sequence_generator import (SequenceGenerator,
-                                                SequenceGeneratorWithAlignment)
+        from fairseq.sequence_generator import (
+            SequenceGenerator,
+            SequenceGeneratorWithAlignment,
+        )
 
         # Choose search strategy. Defaults to Beam Search.
         sampling = getattr(args, "sampling", False)
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..19bb229
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,8 @@
+[flake8]
+ignore = H102,H103,W503,H238,E203,H301,H306,E231
+max-line-length = 130
+[pycodestyle]
+ignore = H102,H103,W503,H238,E203,H301,H306,E231
+max-line-length = 130
+[isort]
+profile = black
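
Usage note (illustrative sketch, not part of the patch): the helpers introduced in fairseq/file_chunker_utils.py are meant to be used together, with find_offsets() computing line-aligned byte offsets and Chunker iterating one chunk per worker. The file name "corpus.txt", the count_lines helper, and the chunk count below are hypothetical placeholders.

    from fairseq.file_chunker_utils import Chunker, find_offsets

    def count_lines(path: str, start: int, end: int) -> int:
        # Each worker opens its own file handle; Chunker yields the lines
        # that lie between the two line-aligned byte offsets.
        with Chunker(path, start, end) as lines:
            return sum(1 for _ in lines)

    # find_offsets returns num_chunks + 1 offsets aligned to line boundaries;
    # consecutive pairs delimit the chunk handed to each worker.
    offsets = find_offsets("corpus.txt", num_chunks=4)
    totals = [
        count_lines("corpus.txt", start, end)
        for start, end in zip(offsets, offsets[1:])
    ]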