
Commit

fixing pre commit
Fhrozen committed Jul 9, 2024
1 parent e8c0598 commit 3b50fa3
Showing 17 changed files with 174 additions and 54 deletions.
6 changes: 4 additions & 2 deletions fairseq/checkpoint_utils.py
@@ -22,8 +22,10 @@

from fairseq.data import data_utils
from fairseq.dataclass.configs import CheckpointConfig
from fairseq.dataclass.utils import (convert_namespace_to_omegaconf,
overwrite_args_by_name)
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    overwrite_args_by_name,
)
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder
10 changes: 5 additions & 5 deletions fairseq/data/dictionary.py
@@ -9,11 +9,11 @@

import torch

# from fairseq import utils
# from fairseq.data import data_utils
# from fairseq.file_chunker_utils import Chunker, find_offsets
# from fairseq.file_io import PathManager
# from fairseq.tokenizer import tokenize_line
from fairseq import utils
from fairseq.data import data_utils
from fairseq.file_chunker_utils import Chunker, find_offsets
from fairseq.file_io import PathManager
from fairseq.tokenizer import tokenize_line


class Dictionary:
20 changes: 11 additions & 9 deletions fairseq/dataclass/configs.py
@@ -11,15 +11,17 @@
import torch
from omegaconf import II, MISSING

from fairseq.dataclass.constants import (DATASET_IMPL_CHOICES,
DDP_BACKEND_CHOICES,
DDP_COMM_HOOK_CHOICES,
GENERATION_CONSTRAINTS_CHOICES,
GENERATION_DECODING_FORMAT_CHOICES,
LOG_FORMAT_CHOICES,
PIPELINE_CHECKPOINT_CHOICES,
PRINT_ALIGNMENT_CHOICES,
ZERO_SHARDING_CHOICES)
from fairseq.dataclass.constants import (
    DATASET_IMPL_CHOICES,
    DDP_BACKEND_CHOICES,
    DDP_COMM_HOOK_CHOICES,
    GENERATION_CONSTRAINTS_CHOICES,
    GENERATION_DECODING_FORMAT_CHOICES,
    LOG_FORMAT_CHOICES,
    PIPELINE_CHECKPOINT_CHOICES,
    PRINT_ALIGNMENT_CHOICES,
    ZERO_SHARDING_CHOICES,
)


@dataclass
3 changes: 1 addition & 2 deletions fairseq/dataclass/utils.py
@@ -345,8 +345,7 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:

no_dc = True
if hasattr(args, "arch"):
from fairseq.models import (ARCH_MODEL_NAME_REGISTRY,
ARCH_MODEL_REGISTRY)
from fairseq.models import ARCH_MODEL_NAME_REGISTRY, ARCH_MODEL_REGISTRY

if args.arch in ARCH_MODEL_REGISTRY:
m_cls = ARCH_MODEL_REGISTRY[args.arch]
7 changes: 5 additions & 2 deletions fairseq/distributed/__init__.py
@@ -3,8 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .fully_sharded_data_parallel import (FullyShardedDataParallel,
fsdp_enable_wrap, fsdp_wrap)
from .fully_sharded_data_parallel import (
    FullyShardedDataParallel,
    fsdp_enable_wrap,
    fsdp_wrap,
)

__all__ = [
"fsdp_enable_wrap",
5 changes: 3 additions & 2 deletions fairseq/distributed/fully_sharded_data_parallel.py
@@ -12,8 +12,9 @@
from fairseq.distributed import utils as dist_utils

try:
    from fairscale.nn.data_parallel import \
        FullyShardedDataParallel as FSDP  # type: ignore
    from fairscale.nn.data_parallel import (
        FullyShardedDataParallel as FSDP,  # type: ignore
    )

    has_FSDP = True
except ImportError:
84 changes: 84 additions & 0 deletions fairseq/file_chunker_utils.py
@@ -0,0 +1,84 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import typing as tp


def _safe_readline(fd) -> str:
    pos = fd.tell()
    while True:
        try:
            return fd.readline()
        except UnicodeDecodeError:
            pos -= 1
            fd.seek(pos)  # search where this character begins


def find_offsets(filename: str, num_chunks: int) -> tp.List[int]:
    """
    Given a file and a number of chunks, find the offsets in the file
    that allow chunking around full lines.
    """
    with open(filename, "r", encoding="utf-8") as f:
        size = os.fstat(f.fileno()).st_size
        chunk_size = size // num_chunks
        offsets = [0 for _ in range(num_chunks + 1)]
        for i in range(1, num_chunks):
            f.seek(chunk_size * i)
            _safe_readline(f)
            offsets[i] = f.tell()
        offsets[-1] = size
        return offsets


class ChunkLineIterator:
    """
    Iterator to properly iterate over lines of a file chunk.
    """

    def __init__(self, fd, start_offset: int, end_offset: int):
        self._fd = fd
        self._start_offset = start_offset
        self._end_offset = end_offset

    def __iter__(self) -> tp.Iterable[str]:
        self._fd.seek(self._start_offset)
        # next(f) breaks f.tell(), hence readline() must be used
        line = _safe_readline(self._fd)
        while line:
            pos = self._fd.tell()
            # f.tell() does not always give the byte position in the file;
            # sometimes it skips to a very large number. It is unlikely that
            # a normal read moves from end_offset bytes to end_offset + 2**32
            # bytes (4 GB), which makes it unlikely that the procedure breaks
            # due to this nondeterministic behavior of f.tell().
            if (
                self._end_offset > 0
                and pos > self._end_offset
                and pos < self._end_offset + 2**32
            ):
                break
            yield line
            line = self._fd.readline()


class Chunker:
    """
    Context manager to read a chunk of a file line by line.
    """

    def __init__(self, path: str, start_offset: int, end_offset: int):
        self.path = path
        self.start_offset = start_offset
        self.end_offset = end_offset

    def __enter__(self) -> ChunkLineIterator:
        self.fd = open(self.path, "r", encoding="utf-8")
        return ChunkLineIterator(self.fd, self.start_offset, self.end_offset)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.fd.close()
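
For context, a minimal usage sketch of the helpers added above (the file path and worker count are illustrative, not part of the commit): find_offsets returns num_chunks + 1 byte offsets aligned on line boundaries, and consecutive pairs delimit chunks that Chunker can read independently, which the restored imports in fairseq/data/dictionary.py suggest is used for multi-worker dictionary building.

# Hypothetical usage, not part of this commit.
from fairseq.file_chunker_utils import Chunker, find_offsets

path = "corpus.txt"  # assumed input file
num_workers = 4      # assumed number of chunks/workers

offsets = find_offsets(path, num_workers)      # num_workers + 1 byte offsets
chunks = list(zip(offsets[:-1], offsets[1:]))  # (start, end) range per worker

for start, end in chunks:
    with Chunker(path, start, end) as lines:
        for line in lines:
            pass  # e.g. count tokens or update a Dictionary here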
6 changes: 4 additions & 2 deletions fairseq/models/fairseq_model.py
@@ -18,8 +18,10 @@

from fairseq import utils
from fairseq.data import Dictionary
from fairseq.dataclass.utils import (convert_namespace_to_omegaconf,
gen_parser_from_dataclass)
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    gen_parser_from_dataclass,
)
from fairseq.models import FairseqDecoder, FairseqEncoder

logger = logging.getLogger(__name__)
18 changes: 11 additions & 7 deletions fairseq/models/hubert/hubert.py
@@ -17,14 +17,18 @@
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec.wav2vec2 import (EXTRACTOR_MODE_CHOICES,
LAYER_TYPE_CHOICES,
MASKING_DISTRIBUTION_CHOICES,
ConvFeatureExtractionModel,
TransformerEncoder)
from fairseq.models.wav2vec.wav2vec2 import (
    EXTRACTOR_MODE_CHOICES,
    LAYER_TYPE_CHOICES,
    MASKING_DISTRIBUTION_CHOICES,
    ConvFeatureExtractionModel,
    TransformerEncoder,
)
from fairseq.modules import GradMultiply, LayerNorm
from fairseq.tasks.hubert_pretraining import (HubertPretrainingConfig,
HubertPretrainingTask)
from fairseq.tasks.hubert_pretraining import (
    HubertPretrainingConfig,
    HubertPretrainingTask,
)

logger = logging.getLogger(__name__)

18 changes: 12 additions & 6 deletions fairseq/models/wav2vec/wav2vec2.py
@@ -16,13 +16,19 @@
from fairseq.data.data_utils import compute_mask_indices
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.distributed import fsdp_wrap
from fairseq.distributed.fully_sharded_data_parallel import \
FullyShardedDataParallel
from fairseq.distributed.fully_sharded_data_parallel import FullyShardedDataParallel
from fairseq.models import BaseFairseqModel, register_model
from fairseq.modules import (Fp32GroupNorm, Fp32LayerNorm, GradMultiply,
GumbelVectorQuantizer, LayerNorm,
MultiheadAttention, RelPositionalEncoding,
SamePad, TransposeLast)
from fairseq.modules import (
    Fp32GroupNorm,
    Fp32LayerNorm,
    GradMultiply,
    GumbelVectorQuantizer,
    LayerNorm,
    MultiheadAttention,
    RelPositionalEncoding,
    SamePad,
    TransposeLast,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.conformer_layer import ConformerWav2Vec2EncoderLayer
from fairseq.modules.transformer_sentence_encoder import init_bert_params
8 changes: 5 additions & 3 deletions fairseq/modules/__init__.py
@@ -3,9 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .espnet_multihead_attention import (ESPNETMultiHeadedAttention,
RelPositionMultiHeadedAttention,
RotaryPositionMultiHeadedAttention)
from .espnet_multihead_attention import (
    ESPNETMultiHeadedAttention,
    RelPositionMultiHeadedAttention,
    RotaryPositionMultiHeadedAttention,
)
from .fp32_group_norm import Fp32GroupNorm
from .grad_multiply import GradMultiply
from .gumbel_vector_quantizer import GumbelVectorQuantizer
11 changes: 7 additions & 4 deletions fairseq/modules/conformer_layer.py
@@ -8,10 +8,13 @@

import torch

from fairseq.modules import (ESPNETMultiHeadedAttention, LayerNorm,
MultiheadAttention,
RelPositionMultiHeadedAttention,
RotaryPositionMultiHeadedAttention)
from fairseq.modules import (
    ESPNETMultiHeadedAttention,
    LayerNorm,
    MultiheadAttention,
    RelPositionMultiHeadedAttention,
    RotaryPositionMultiHeadedAttention,
)
from fairseq.utils import get_activation_fn


4 changes: 3 additions & 1 deletion fairseq/modules/espnet_multihead_attention.py
@@ -12,7 +12,9 @@
from torch import nn

from fairseq.modules.rotary_positional_embedding import (
RotaryPositionalEmbedding, apply_rotary_pos_emb)
    RotaryPositionalEmbedding,
    apply_rotary_pos_emb,
)


class ESPNETMultiHeadedAttention(nn.Module):
3 changes: 1 addition & 2 deletions fairseq/modules/multihead_attention.py
@@ -20,8 +20,7 @@
_xformers_available = False

from fairseq import utils
from fairseq.models.fairseq_incremental_decoder import \
FairseqIncrementalDecoder
from fairseq.models.fairseq_incremental_decoder import FairseqIncrementalDecoder
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise

8 changes: 5 additions & 3 deletions fairseq/search.py
@@ -10,9 +10,11 @@
import torch.nn as nn
from torch import Tensor

from fairseq.token_generation_constraints import (ConstraintState,
OrderedConstraintState,
UnorderedConstraintState)
from fairseq.token_generation_constraints import (
    ConstraintState,
    OrderedConstraintState,
    UnorderedConstraintState,
)


class Search(nn.Module):
9 changes: 5 additions & 4 deletions fairseq/tasks/fairseq_task.py
@@ -13,8 +13,7 @@
from omegaconf import DictConfig

from fairseq import search, tokenizer, utils
from fairseq.data import (Dictionary, FairseqDataset, data_utils, encoders,
iterators)
from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.logging import metrics
@@ -412,8 +411,10 @@ def build_generator(
compute_alignment=getattr(args, "print_alignment", False),
)

from fairseq.sequence_generator import (SequenceGenerator,
SequenceGeneratorWithAlignment)
from fairseq.sequence_generator import (
    SequenceGenerator,
    SequenceGeneratorWithAlignment,
)

# Choose search strategy. Defaults to Beam Search.
sampling = getattr(args, "sampling", False)
8 changes: 8 additions & 0 deletions setup.cfg
@@ -0,0 +1,8 @@
[flake8]
ignore = H102,H103,W503,H238,E203,H301,H306,E231
max-line-length = 130
[pycodestyle]
ignore = H102,H103,W503,H238,E203,H301,H306,E231
max-line-length = 130
[isort]
profile = black
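
For reference, E203 (whitespace before ':') and W503 (line break before binary operator) are the pycodestyle/flake8 checks most commonly ignored in black-formatted projects, since black deliberately emits code that triggers them; a small illustrative snippet, not taken from the repository:

# Illustrative only: black-style formatting that bare flake8 would flag.
values = [1, 2, 3, 4]
tail = values[1 : len(values)]  # E203: black spaces the colon in complex slices

total = (
    tail[0]
    + tail[1]  # W503: black breaks lines before binary operators
)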
