Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 9, 2024
1 parent 1945414 commit 59b2bb5
Showing 36 changed files with 161 additions and 204 deletions.
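Most of this diff is mechanical: imports are regrouped and re-wrapped, stray blank lines are removed, and a few statements are reflowed without any change in behavior. The commit message does not list which hooks ran, but the re-wrapped imports are consistent with isort's default grid style. As an illustration, both layouts below are equivalent at runtime (the names are taken from the first hunk; running this requires fairseq to be importable):

# Pre-existing layout: vertical hanging indent with a trailing comma.
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    overwrite_args_by_name,
)

# Layout after this commit: grid-wrapped, matching isort's default mode.
# Importing the same names twice is redundant but harmless.
from fairseq.dataclass.utils import (convert_namespace_to_omegaconf,
                                     overwrite_args_by_name)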
19 changes: 10 additions & 9 deletions fairseq/checkpoint_utils.py
@@ -18,16 +18,15 @@

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf, open_dict

from fairseq.data import data_utils
from fairseq.dataclass.configs import CheckpointConfig
from fairseq.dataclass.utils import (
convert_namespace_to_omegaconf,
overwrite_args_by_name,
)
from fairseq.dataclass.utils import (convert_namespace_to_omegaconf,
overwrite_args_by_name)
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig, OmegaConf, open_dict

logger = logging.getLogger(__name__)

@@ -100,9 +99,9 @@ def is_better(a, b):
cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
)
] = worst_best is None or is_better(val_loss, worst_best)
checkpoint_conds[
"checkpoint_last{}.pt".format(suffix)
] = not cfg.no_last_checkpoints
checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = (
not cfg.no_last_checkpoints
)

extra_state = {
"train_iterator": epoch_itr.state_dict(),
@@ -116,7 +115,9 @@ def is_better(a, b):
# attributes
if hasattr(trainer.task, "get_checkpoint_dict"):
extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint")
logger.info(
f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint"
)

if hasattr(save_checkpoint, "best"):
extra_state.update({"best": save_checkpoint.best})
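The reflowed assignment above is purely cosmetic: checkpoint_conds still maps candidate checkpoint filenames to booleans that decide whether each file is written. A minimal, self-contained sketch of that pattern (cfg and suffix are hypothetical stand-ins, not fairseq's objects):

from types import SimpleNamespace

cfg = SimpleNamespace(no_last_checkpoints=False)  # stand-in for the real config
suffix = ""

checkpoint_conds = {}
checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = (
    not cfg.no_last_checkpoints
)

# Only names whose condition is True would actually be saved.
to_save = [name for name, keep in checkpoint_conds.items() if keep]
print(to_save)  # ['checkpoint_last.pt']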
3 changes: 2 additions & 1 deletion fairseq/data/dictionary.py
@@ -8,6 +8,7 @@
from multiprocessing import Pool

import torch

# from fairseq import utils
# from fairseq.data import data_utils
# from fairseq.file_chunker_utils import Chunker, find_offsets
@@ -358,7 +359,7 @@ def merge_result(counter):
chunks = zip(offsets, offsets[1:])
pool = Pool(processes=num_workers)
results = []
for (start_offset, end_offset) in chunks:
for start_offset, end_offset in chunks:
results.append(
pool.apply_async(
Dictionary._add_file_to_dictionary_single_worker,
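The only visible code change here is dropping the parentheses in for (start_offset, end_offset) in chunks:, which is equivalent tuple unpacking. For context, a self-contained sketch of the chunked multiprocessing pattern this hunk comes from, simplified to counting words over line ranges (this is not fairseq's actual worker, which operates on byte offsets of a file):

from collections import Counter
from multiprocessing import Pool


def count_words(lines, start, end):
    # Worker: count whitespace-separated tokens in lines[start:end].
    counts = Counter()
    for line in lines[start:end]:
        counts.update(line.split())
    return counts


def parallel_word_count(lines, num_workers=2):
    step = max(1, len(lines) // num_workers)
    offsets = list(range(0, len(lines), step)) + [len(lines)]
    chunks = zip(offsets, offsets[1:])

    total = Counter()
    with Pool(processes=num_workers) as pool:
        results = [
            pool.apply_async(count_words, (lines, start, end))
            for start, end in chunks  # parentheses around the tuple are optional
        ]
        for r in results:
            total.update(r.get())  # plays the role of merge_result above
    return total


if __name__ == "__main__":
    print(parallel_word_count(["a b a", "c a b", "d"]))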
1 change: 0 additions & 1 deletion fairseq/data/encoders/__init__.py
@@ -9,7 +9,6 @@

from fairseq import registry


build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry(
"--tokenizer",
default=None,
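The change above only removes a blank line around the setup_registry call. For readers unfamiliar with the registry idiom it configures, a heavily simplified, self-contained sketch follows (fairseq's real setup_registry also wires the choice into the command line and config system; the WhitespaceTokenizer is purely illustrative):

TOKENIZER_REGISTRY = {}


def register_tokenizer(name):
    # Decorator that records a tokenizer class under a CLI-style name.
    def decorator(cls):
        TOKENIZER_REGISTRY[name] = cls
        return cls

    return decorator


def build_tokenizer(name, **kwargs):
    return TOKENIZER_REGISTRY[name](**kwargs)


@register_tokenizer("whitespace")
class WhitespaceTokenizer:
    def encode(self, text):
        return text.split()


print(build_tokenizer("whitespace").encode("hello world"))  # ['hello', 'world']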
2 changes: 2 additions & 0 deletions fairseq/data/fairseq_dataset.py
@@ -4,8 +4,10 @@
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch.utils.data

from fairseq.data import data_utils

logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion fairseq/data/iterators.py
@@ -15,8 +15,8 @@

import numpy as np
import torch
from fairseq.data import data_utils

from fairseq.data import data_utils

logger = logging.getLogger(__name__)

1 change: 0 additions & 1 deletion fairseq/dataclass/__init__.py
@@ -6,7 +6,6 @@
from .configs import FairseqDataclass
from .constants import ChoiceEnum


__all__ = [
"FairseqDataclass",
"ChoiceEnum",
38 changes: 18 additions & 20 deletions fairseq/dataclass/configs.py
@@ -11,17 +11,15 @@
import torch
from omegaconf import II, MISSING

from fairseq.dataclass.constants import (
DATASET_IMPL_CHOICES,
DDP_BACKEND_CHOICES,
DDP_COMM_HOOK_CHOICES,
GENERATION_CONSTRAINTS_CHOICES,
GENERATION_DECODING_FORMAT_CHOICES,
LOG_FORMAT_CHOICES,
PIPELINE_CHECKPOINT_CHOICES,
PRINT_ALIGNMENT_CHOICES,
ZERO_SHARDING_CHOICES,
)
from fairseq.dataclass.constants import (DATASET_IMPL_CHOICES,
DDP_BACKEND_CHOICES,
DDP_COMM_HOOK_CHOICES,
GENERATION_CONSTRAINTS_CHOICES,
GENERATION_DECODING_FORMAT_CHOICES,
LOG_FORMAT_CHOICES,
PIPELINE_CHECKPOINT_CHOICES,
PRINT_ALIGNMENT_CHOICES,
ZERO_SHARDING_CHOICES)


@dataclass
@@ -108,7 +106,7 @@ class CommonConfig(FairseqDataclass):
"help": "log progress every N batches (when progress bar is disabled)"
},
)
log_format: Optional[LOG_FORMAT_CHOICES] = field( # type: ignore
log_format: Optional[LOG_FORMAT_CHOICES] = field( # type: ignore
default=None, metadata={"help": "log format to use"}
)
log_file: Optional[str] = field(
@@ -298,10 +296,10 @@ class DistributedTrainingConfig(FairseqDataclass):
"help": "do not spawn multiple processes even if multiple GPUs are visible"
},
)
ddp_backend: DDP_BACKEND_CHOICES = field( # type: ignore
ddp_backend: DDP_BACKEND_CHOICES = field( # type: ignore
default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"}
)
ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field( # type: ignore
ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field( # type: ignore
default="none", metadata={"help": "communication hook"}
)
bucket_cap_mb: int = field(
@@ -428,11 +426,11 @@ class DistributedTrainingConfig(FairseqDataclass):
"equal the length of the --pipeline-decoder-balance argument"
},
)
pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field( # type: ignore
pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field( # type: ignore
default="never",
metadata={"help": "checkpointing mode for pipeline model parallelism"},
)
zero_sharding: ZERO_SHARDING_CHOICES = field( # type: ignore
zero_sharding: ZERO_SHARDING_CHOICES = field( # type: ignore
default="none", metadata={"help": "ZeRO sharding"}
)
fp16: bool = II("common.fp16")
@@ -488,7 +486,7 @@ class DatasetConfig(FairseqDataclass):
"help": "maximum sequence length in batch will be a multiplier of this value"
},
)
dataset_impl: Optional[DATASET_IMPL_CHOICES] = field( # type: ignore
dataset_impl: Optional[DATASET_IMPL_CHOICES] = field( # type: ignore
default=None, metadata={"help": "output dataset implementation"}
)
data_buffer_size: int = field(
@@ -921,7 +919,7 @@ class GenerationConfig(FairseqDataclass):
"help": "sample from the smallest set whose cumulative probability mass exceeds p for next words"
},
)
constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field( # type: ignore
constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field( # type: ignore
default=None,
metadata={
"help": "enables lexically constrained decoding",
@@ -944,7 +942,7 @@ class GenerationConfig(FairseqDataclass):
default=-1.0,
metadata={"help": "strength of diversity penalty for Diverse Siblings Search"},
)
print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field( # type: ignore
print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field( # type: ignore
default=None,
metadata={
"help": "if set, uses attention feedback to compute and print alignment to source tokens "
@@ -1012,7 +1010,7 @@ class GenerationConfig(FairseqDataclass):
},
)
# special decoding format for advanced decoding.
decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field( # type: ignore
decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field( # type: ignore
default=None,
metadata={"help": "special decoding format for advanced decoding."},
)
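Every edit in this file is layout-only: the constants import is grid-wrapped and the spacing of inline # type: ignore comments is normalized. The field declarations themselves follow the standard dataclasses pattern; a minimal stand-alone sketch (the LogFormat enum is an illustrative substitute for fairseq's LOG_FORMAT_CHOICES):

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional


class LogFormat(str, Enum):
    JSON = "json"
    SIMPLE = "simple"
    TQDM = "tqdm"


@dataclass
class CommonConfigSketch:
    log_interval: int = field(
        default=100,
        metadata={"help": "log progress every N batches"},
    )
    log_format: Optional[LogFormat] = field(
        default=None, metadata={"help": "log format to use"}
    )


cfg = CommonConfigSketch(log_format=LogFormat.JSON)
print(cfg.log_interval, cfg.log_format)  # 100 LogFormat.JSON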
10 changes: 6 additions & 4 deletions fairseq/dataclass/utils.py
@@ -13,11 +13,12 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type

from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import FairseqConfig
from hydra.core.global_hydra import GlobalHydra
from hydra.experimental import compose, initialize
from omegaconf import DictConfig, OmegaConf, open_dict, _utils
from omegaconf import DictConfig, OmegaConf, _utils, open_dict

from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import FairseqConfig

logger = logging.getLogger(__name__)

@@ -344,7 +345,8 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:

no_dc = True
if hasattr(args, "arch"):
from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_MODEL_NAME_REGISTRY
from fairseq.models import (ARCH_MODEL_NAME_REGISTRY,
ARCH_MODEL_REGISTRY)

if args.arch in ARCH_MODEL_REGISTRY:
m_cls = ARCH_MODEL_REGISTRY[args.arch]
9 changes: 3 additions & 6 deletions fairseq/distributed/__init__.py
@@ -3,14 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .fully_sharded_data_parallel import (
fsdp_enable_wrap,
fsdp_wrap,
FullyShardedDataParallel,
)
from .fully_sharded_data_parallel import (FullyShardedDataParallel,
fsdp_enable_wrap, fsdp_wrap)

__all__ = [
"fsdp_enable_wrap",
"fsdp_wrap",
"FullyShardedDataParallel",
]
]
9 changes: 5 additions & 4 deletions fairseq/distributed/fully_sharded_data_parallel.py
@@ -7,12 +7,13 @@
from typing import Optional

import torch

from fairseq.dataclass.configs import DistributedTrainingConfig
from fairseq.distributed import utils as dist_utils


try:
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP # type: ignore
from fairscale.nn.data_parallel import \
FullyShardedDataParallel as FSDP # type: ignore

has_FSDP = True
except ImportError:
@@ -91,7 +92,7 @@ def size(self) -> int:
@contextlib.contextmanager
def fsdp_enable_wrap(cfg: DistributedTrainingConfig):
try:
from fairscale.nn import enable_wrap # type: ignore
from fairscale.nn import enable_wrap # type: ignore
except ImportError:
raise ImportError(
"Cannot find FullyShardedDataParallel. "
@@ -131,7 +132,7 @@ def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
min_num_params (int, Optional): minimum number of layer params to wrap
"""
try:
from fairscale.nn import wrap # type: ignore
from fairscale.nn import wrap # type: ignore

if min_num_params is not None:
num_params = sum(p.numel() for p in module.parameters())
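Beyond the import re-wrapping, the hunks above preserve the try/except ImportError guard around fairscale, which keeps FSDP support optional at import time. A generic, self-contained sketch of the same guard (simplified; fairseq's real module builds wrapper classes and helpers on top of it):

try:
    # Optional dependency: only available when fairscale is installed.
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    has_FSDP = True
except ImportError:
    FSDP = None
    has_FSDP = False


def maybe_wrap(module):
    # Wrap with FSDP when available, otherwise return the module unchanged.
    if has_FSDP:
        return FSDP(module)
    return module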
1 change: 0 additions & 1 deletion fairseq/file_io.py
@@ -10,7 +10,6 @@
import shutil
from typing import List, Optional


logger = logging.getLogger(__file__)


3 changes: 1 addition & 2 deletions fairseq/logging/metrics.py
@@ -18,7 +18,6 @@

from .meters import *


# Aggregation contexts are considered "active" when inside the scope
# created by the :func:`aggregate` context manager.
_aggregators = OrderedDict()
@@ -329,7 +328,7 @@ def load_state_dict(state_dict):

def xla_metrics_report():
try:
import torch_xla.debug.metrics as met # type: ignore
import torch_xla.debug.metrics as met # type: ignore

print(met.metrics_report())
except ImportError:
3 changes: 2 additions & 1 deletion fairseq/models/fairseq_decoder.py
@@ -6,9 +6,10 @@
from typing import Dict, List, Optional, Tuple

import torch.nn as nn
from fairseq import utils
from torch import Tensor

from fairseq import utils


class FairseqDecoder(nn.Module):
"""Base class for decoders."""
1 change: 0 additions & 1 deletion fairseq/models/fairseq_encoder.py
@@ -9,7 +9,6 @@
import torch.nn as nn
from torch import Tensor


EncoderOut = NamedTuple(
"EncoderOut",
[
11 changes: 5 additions & 6 deletions fairseq/models/fairseq_model.py
@@ -13,15 +13,14 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig
from torch import Tensor

from fairseq import utils
from fairseq.data import Dictionary
from fairseq.dataclass.utils import (
convert_namespace_to_omegaconf,
gen_parser_from_dataclass,
)
from fairseq.dataclass.utils import (convert_namespace_to_omegaconf,
gen_parser_from_dataclass)
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig
from torch import Tensor

logger = logging.getLogger(__name__)

(Diffs for the remaining changed files were not loaded.)
