diff --git a/.cardboardlint.yml b/.cardboardlint.yml deleted file mode 100644 index 4a115a37cd..0000000000 --- a/.cardboardlint.yml +++ /dev/null @@ -1,5 +0,0 @@ -linters: -- pylint: - # pylintrc: pylintrc - filefilter: ['- test_*.py', '+ *.py', '- *.npy'] - # exclude: \ No newline at end of file diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index b7c6393baa..f6018a3f19 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -29,18 +29,7 @@ jobs: architecture: x64 cache: 'pip' cache-dependency-path: 'requirements*' - - name: check OS - run: cat /etc/os-release - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y git make gcc - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Install TTS - run: | - python3 -m pip install .[all] - python3 setup.py egg_info - - name: Style check - run: make style + - name: Install/upgrade dev dependencies + run: python3 -m pip install -r requirements.dev.txt + - name: Lint check + run: make lint diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 911f2a838e..af408ed551 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,27 +1,25 @@ repos: - - repo: 'https://github.com/pre-commit/pre-commit-hooks' - rev: v2.3.0 + - repo: "https://github.com/pre-commit/pre-commit-hooks" + rev: v4.5.0 hooks: - id: check-yaml - - id: end-of-file-fixer - - id: trailing-whitespace - - repo: 'https://github.com/psf/black' - rev: 22.3.0 + # TODO: enable these later; there are plenty of violating + # files that need to be fixed first + # - id: end-of-file-fixer + # - id: trailing-whitespace + - repo: "https://github.com/psf/black" + rev: 23.12.0 hooks: - id: black language_version: python3 - repo: https://github.com/pycqa/isort - rev: 5.8.0 + rev: 5.13.1 hooks: - - id: isort - name: isort (python) - id: isort name: isort (cython) types: [cython] - - id: isort - name: isort (pyi) - types: [pyi] - - repo: https://github.com/pycqa/pylint - rev: v2.8.2 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.7 hooks: - - id: pylint + - id: ruff + args: [--fix, --exit-non-zero-on-fix] diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 49a9dbdd2c..0000000000 --- a/.pylintrc +++ /dev/null @@ -1,599 +0,0 @@ -[MASTER] - -# A comma-separated list of package or module names from where C extensions may -# be loaded. Extensions are loading into the active Python interpreter and may -# run arbitrary code. -extension-pkg-whitelist= - -# Add files or directories to the blacklist. They should be base names, not -# paths. -ignore=CVS - -# Add files or directories matching the regex patterns to the blacklist. The -# regex matches against base names, not paths. -ignore-patterns= - -# Python code to execute, usually for sys.path manipulation such as -# pygtk.require(). -#init-hook= - -# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the -# number of processors available to use. -jobs=1 - -# Control the amount of potential inferred values when inferring a single -# object. This can help the performance when dealing with large functions or -# complex, nested conditions. -limit-inference-results=100 - -# List of plugins (as comma separated values of python modules names) to load, -# usually to register additional checkers. -load-plugins= - -# Pickle collected data for later comparisons. -persistent=yes - -# Specify a configuration file. -#rcfile= - -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. -confidence= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once). You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use "--disable=all --enable=classes -# --disable=W". -disable=missing-docstring, - too-many-public-methods, - too-many-lines, - bare-except, - ## for avoiding weird p3.6 CI linter error - ## TODO: see later if we can remove this - assigning-non-slot, - unsupported-assignment-operation, - ## end - line-too-long, - fixme, - wrong-import-order, - ungrouped-imports, - wrong-import-position, - import-error, - invalid-name, - too-many-instance-attributes, - arguments-differ, - arguments-renamed, - no-name-in-module, - no-member, - unsubscriptable-object, - print-statement, - parameter-unpacking, - unpacking-in-except, - old-raise-syntax, - backtick, - long-suffix, - old-ne-operator, - old-octal-literal, - import-star-module-level, - non-ascii-bytes-literal, - raw-checker-failed, - bad-inline-option, - locally-disabled, - file-ignored, - suppressed-message, - useless-suppression, - deprecated-pragma, - use-symbolic-message-instead, - useless-object-inheritance, - too-few-public-methods, - too-many-branches, - too-many-arguments, - too-many-locals, - too-many-statements, - apply-builtin, - basestring-builtin, - buffer-builtin, - cmp-builtin, - coerce-builtin, - execfile-builtin, - file-builtin, - long-builtin, - raw_input-builtin, - reduce-builtin, - standarderror-builtin, - unicode-builtin, - xrange-builtin, - coerce-method, - delslice-method, - getslice-method, - setslice-method, - no-absolute-import, - old-division, - dict-iter-method, - dict-view-method, - next-method-called, - metaclass-assignment, - indexing-exception, - raising-string, - reload-builtin, - oct-method, - hex-method, - nonzero-method, - cmp-method, - input-builtin, - round-builtin, - intern-builtin, - unichr-builtin, - map-builtin-not-iterating, - zip-builtin-not-iterating, - range-builtin-not-iterating, - filter-builtin-not-iterating, - using-cmp-argument, - eq-without-hash, - div-method, - idiv-method, - rdiv-method, - exception-message-attribute, - invalid-str-codec, - sys-max-int, - bad-python3-import, - deprecated-string-function, - deprecated-str-translate-call, - deprecated-itertools-function, - deprecated-types-field, - next-method-defined, - dict-items-not-iterating, - dict-keys-not-iterating, - dict-values-not-iterating, - deprecated-operator-function, - deprecated-urllib-function, - xreadlines-attribute, - deprecated-sys-function, - exception-escape, - comprehension-escape, - duplicate-code, - not-callable, - import-outside-toplevel, - logging-fstring-interpolation, - logging-not-lazy - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -enable=c-extension-no-member - - -[REPORTS] - -# Python expression which should return a note less than 10 (10 is the highest -# note). You have access to the variables errors warning, statement which -# respectively contain the number of errors / warnings messages and the total -# number of statements analyzed. This is used by the global evaluation report -# (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details. -#msg-template= - -# Set the output format. Available formats are text, parseable, colorized, json -# and msvs (visual studio). You can also give a reporter class, e.g. -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages. -reports=no - -# Activate the evaluation score. -score=yes - - -[REFACTORING] - -# Maximum number of nested blocks for function / method body -max-nested-blocks=5 - -# Complete name of functions that never returns. When checking for -# inconsistent-return-statements if a never returning function is called then -# it will be considered as an explicit return statement and no message will be -# printed. -never-returning-functions=sys.exit - - -[LOGGING] - -# Format style used to check logging format string. `old` means using % -# formatting, while `new` is for `{}` formatting. -logging-format-style=old - -# Logging modules to check that the string format arguments are in logging -# function parameter format. -logging-modules=logging - - -[SPELLING] - -# Limits count of emitted suggestions for spelling mistakes. -max-spelling-suggestions=4 - -# Spelling dictionary name. Available dictionaries: none. To make it working -# install python-enchant package.. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to indicated private dictionary in -# --spelling-private-dict-file option instead of raising a message. -spelling-store-unknown-words=no - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=FIXME, - XXX, - TODO - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members=numpy.*,torch.* - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# Tells whether to warn about missing members when the owner of the attribute -# is inferred to be None. -ignore-none=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis. It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. -missing-member-hint=yes - -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. -missing-member-hint-distance=1 - -# The total number of similar names that should be taken in consideration when -# showing a hint for a missing member. -missing-member-max-choices=1 - - -[VARIABLES] - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid defining new builtins when possible. -additional-builtins= - -# Tells whether unused global variables should be treated as a violation. -allow-global-unused-variables=yes - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_, - _cb - -# A regular expression matching the name of dummy variables (i.e. expected to -# not be used). -dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ - -# Argument names that match this expression will be ignored. Default to name -# with leading underscore. -ignored-argument-names=_.*|^ignored_|^unused_ - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io - - -[FORMAT] - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). -indent-string=' ' - -# Maximum number of characters on a single line. -max-line-length=120 - -# Maximum number of lines in a module. -max-module-lines=1000 - -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check=trailing-comma, - dict-separator - -# Allow the body of a class to be on the same line as the declaration if body -# contains single statement. -single-line-class-stmt=no - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=no - - -[SIMILARITIES] - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - -# Minimum lines number of a similarity. -min-similarity-lines=4 - - -[BASIC] - -# Naming style matching correct argument names. -argument-naming-style=snake_case - -# Regular expression matching correct argument names. Overrides argument- -# naming-style. -argument-rgx=[a-z_][a-z0-9_]{0,30}$ - -# Naming style matching correct attribute names. -attr-naming-style=snake_case - -# Regular expression matching correct attribute names. Overrides attr-naming- -# style. -#attr-rgx= - -# Bad variable names which should always be refused, separated by a comma. -bad-names= - -# Naming style matching correct class attribute names. -class-attribute-naming-style=any - -# Regular expression matching correct class attribute names. Overrides class- -# attribute-naming-style. -#class-attribute-rgx= - -# Naming style matching correct class names. -class-naming-style=PascalCase - -# Regular expression matching correct class names. Overrides class-naming- -# style. -#class-rgx= - -# Naming style matching correct constant names. -const-naming-style=UPPER_CASE - -# Regular expression matching correct constant names. Overrides const-naming- -# style. -#const-rgx= - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=-1 - -# Naming style matching correct function names. -function-naming-style=snake_case - -# Regular expression matching correct function names. Overrides function- -# naming-style. -#function-rgx= - -# Good variable names which should always be accepted, separated by a comma. -good-names=i, - j, - k, - x, - ex, - Run, - _ - -# Include a hint for the correct naming format with invalid-name. -include-naming-hint=no - -# Naming style matching correct inline iteration names. -inlinevar-naming-style=any - -# Regular expression matching correct inline iteration names. Overrides -# inlinevar-naming-style. -#inlinevar-rgx= - -# Naming style matching correct method names. -method-naming-style=snake_case - -# Regular expression matching correct method names. Overrides method-naming- -# style. -#method-rgx= - -# Naming style matching correct module names. -module-naming-style=snake_case - -# Regular expression matching correct module names. Overrides module-naming- -# style. -#module-rgx= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=^_ - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -# These decorators are taken in consideration only for invalid-name. -property-classes=abc.abstractproperty - -# Naming style matching correct variable names. -variable-naming-style=snake_case - -# Regular expression matching correct variable names. Overrides variable- -# naming-style. -variable-rgx=[a-z_][a-z0-9_]{0,30}$ - - -[STRING] - -# This flag controls whether the implicit-str-concat-in-sequence should -# generate a warning on implicit string concatenation in sequences defined over -# several lines. -check-str-concat-over-line-jumps=no - - -[IMPORTS] - -# Allow wildcard imports from modules that define __all__. -allow-wildcard-with-all=no - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - -# Deprecated modules which should not be used, separated by a comma. -deprecated-modules=optparse,tkinter.tix - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled). -ext-import-graph= - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled). -import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled). -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=cls - - -[DESIGN] - -# Maximum number of arguments for function / method. -max-args=5 - -# Maximum number of attributes for a class (see R0902). -max-attributes=7 - -# Maximum number of boolean expressions in an if statement. -max-bool-expr=5 - -# Maximum number of branch for function / method body. -max-branches=12 - -# Maximum number of locals for function / method body. -max-locals=15 - -# Maximum number of parents for a class (see R0901). -max-parents=15 - -# Maximum number of public methods for a class (see R0904). -max-public-methods=20 - -# Maximum number of return / yield for function / method body. -max-returns=6 - -# Maximum number of statements in function / method body. -max-statements=50 - -# Minimum number of public methods for a class (see R0903). -min-public-methods=2 - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "BaseException, Exception". -overgeneral-exceptions=BaseException, - Exception diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ae0ce46048..5fbed84397 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -88,7 +88,7 @@ The following steps are tested on an Ubuntu system. $ make style ``` -10. Run the linter and correct the issues raised. We use ```pylint``` for linting. It helps to enforce a coding standard, offers simple refactoring suggestions. +10. Run the linter and correct the issues raised. We use ```ruff``` for linting. It helps to enforce a coding standard, offers simple refactoring suggestions. ```bash $ make lint diff --git a/Makefile b/Makefile index 7446848f46..d1a1db8a75 100644 --- a/Makefile +++ b/Makefile @@ -48,8 +48,8 @@ style: ## update code style. black ${target_dirs} isort ${target_dirs} -lint: ## run pylint linter. - pylint ${target_dirs} +lint: ## run linters. + ruff ${target_dirs} black ${target_dirs} --check isort ${target_dirs} --check-only diff --git a/TTS/api.py b/TTS/api.py index 7abc188e74..2ef6f3a085 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -6,10 +6,10 @@ import numpy as np from torch import nn +from TTS.config import load_config from TTS.utils.audio.numpy_transforms import save_wav from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer -from TTS.config import load_config class TTS(nn.Module): @@ -168,9 +168,7 @@ def load_tts_model_by_name(self, model_name: str, gpu: bool = False): self.synthesizer = None self.model_name = model_name - model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name( - model_name - ) + model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name) # init synthesizer # None values are fetch from the model @@ -231,7 +229,7 @@ def _check_arguments( raise ValueError("Model is not multi-speaker but `speaker` is provided.") if not self.is_multi_lingual and language is not None: raise ValueError("Model is not multi-lingual but `language` is provided.") - if not emotion is None and not speed is None: + if emotion is not None and speed is not None: raise ValueError("Emotion and speed can only be used with Coqui Studio models. Which is discontinued.") def tts( diff --git a/TTS/bin/collect_env_info.py b/TTS/bin/collect_env_info.py index 662fcd02ec..e76f6a757b 100644 --- a/TTS/bin/collect_env_info.py +++ b/TTS/bin/collect_env_info.py @@ -1,4 +1,5 @@ """Get detailed info about the working environment.""" +import json import os import platform import sys @@ -6,11 +7,10 @@ import numpy import torch -sys.path += [os.path.abspath(".."), os.path.abspath(".")] -import json - import TTS +sys.path += [os.path.abspath(".."), os.path.abspath(".")] + def system_info(): return { diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index 9ab520be7d..faadf6901d 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -70,7 +70,7 @@ # if the vocabulary was passed, replace the default if "characters" in C.keys(): - symbols, phonemes = make_symbols(**C.characters) + symbols, phonemes = make_symbols(**C.characters) # noqa: F811 # load the model num_chars = len(phonemes) if C.use_phonemes else len(symbols) diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py index 4bd7a78eef..2df0700676 100644 --- a/TTS/bin/find_unique_phonemes.py +++ b/TTS/bin/find_unique_phonemes.py @@ -13,7 +13,7 @@ def compute_phonemes(item): text = item["text"] ph = phonemizer.phonemize(text).replace("|", "") - return set(list(ph)) + return set(ph) def main(): diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index b86252ab67..b06c93f7d1 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -224,7 +224,7 @@ def main(): const=True, default=False, ) - + # args for multi-speaker synthesis parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None) @@ -379,10 +379,8 @@ def main(): if model_item["model_type"] == "tts_models": tts_path = model_path tts_config_path = config_path - if "default_vocoder" in model_item: - args.vocoder_name = ( - model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name - ) + if args.vocoder_name is None and "default_vocoder" in model_item: + args.vocoder_name = model_item["default_vocoder"] # voice conversion model if model_item["model_type"] == "voice_conversion_models": diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py index c5a6dd68e2..5103f200b0 100644 --- a/TTS/config/__init__.py +++ b/TTS/config/__init__.py @@ -17,9 +17,12 @@ def read_json_with_comments(json_path): with fsspec.open(json_path, "r", encoding="utf-8") as f: input_str = f.read() # handle comments but not urls with // - input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str) + input_str = re.sub( + r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str + ) return json.loads(input_str) + def register_config(model_name: str) -> Coqpit: """Find the right config for the given model name. diff --git a/TTS/demos/xtts_ft_demo/utils/formatter.py b/TTS/demos/xtts_ft_demo/utils/formatter.py index 536faa0108..40e8b8ed32 100644 --- a/TTS/demos/xtts_ft_demo/utils/formatter.py +++ b/TTS/demos/xtts_ft_demo/utils/formatter.py @@ -1,23 +1,17 @@ -import os import gc -import torchaudio +import os + import pandas +import torch +import torchaudio from faster_whisper import WhisperModel -from glob import glob - from tqdm import tqdm -import torch -import torchaudio # torch.set_num_threads(1) - from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners torch.set_num_threads(16) - -import os - audio_types = (".wav", ".mp3", ".flac") @@ -25,9 +19,10 @@ def list_audios(basePath, contains=None): # return the set of files that are valid return list_files(basePath, validExts=audio_types, contains=contains) + def list_files(basePath, validExts=None, contains=None): # loop over the directory structure - for (rootDir, dirNames, filenames) in os.walk(basePath): + for rootDir, dirNames, filenames in os.walk(basePath): # loop over the filenames in the current directory for filename in filenames: # if the contains string is not none and the filename does not contain @@ -36,7 +31,7 @@ def list_files(basePath, validExts=None, contains=None): continue # determine the file extension of the current file - ext = filename[filename.rfind("."):].lower() + ext = filename[filename.rfind(".") :].lower() # check to see if the file is an audio and should be processed if validExts is None or ext.endswith(validExts): @@ -44,13 +39,22 @@ def list_files(basePath, validExts=None, contains=None): audioPath = os.path.join(rootDir, filename) yield audioPath -def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None): + +def format_audio_list( + audio_files, + target_language="en", + out_path=None, + buffer=0.2, + eval_percentage=0.15, + speaker_name="coqui", + gradio_progress=None, +): audio_total_size = 0 # make sure that ooutput file exists os.makedirs(out_path, exist_ok=True) # Loading Whisper - device = "cuda" if torch.cuda.is_available() else "cpu" + device = "cuda" if torch.cuda.is_available() else "cpu" print("Loading Whisper Model!") asr_model = WhisperModel("large-v2", device=device, compute_type="float16") @@ -69,7 +73,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0 wav = torch.mean(wav, dim=0, keepdim=True) wav = wav.squeeze() - audio_total_size += (wav.size(-1) / sr) + audio_total_size += wav.size(-1) / sr segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language) segments = list(segments) @@ -94,7 +98,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0 # get previous sentence end previous_word_end = words_list[word_idx - 1].end # add buffer or get the silence midle between the previous sentence and the current one - sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2) + sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2) sentence = word.word first_word = False @@ -118,19 +122,16 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0 # Average the current word end and next word start word_end = min((word.end + next_word_start) / 2, word.end + buffer) - + absoulte_path = os.path.join(out_path, audio_file) os.makedirs(os.path.dirname(absoulte_path), exist_ok=True) i += 1 first_word = True - audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0) + audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0) # if the audio is too short ignore it (i.e < 0.33 seconds) - if audio.size(-1) >= sr/3: - torchaudio.save(absoulte_path, - audio, - sr - ) + if audio.size(-1) >= sr / 3: + torchaudio.save(absoulte_path, audio, sr) else: continue @@ -140,21 +141,21 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0 df = pandas.DataFrame(metadata) df = df.sample(frac=1) - num_val_samples = int(len(df)*eval_percentage) + num_val_samples = int(len(df) * eval_percentage) df_eval = df[:num_val_samples] df_train = df[num_val_samples:] - df_train = df_train.sort_values('audio_file') + df_train = df_train.sort_values("audio_file") train_metadata_path = os.path.join(out_path, "metadata_train.csv") df_train.to_csv(train_metadata_path, sep="|", index=False) eval_metadata_path = os.path.join(out_path, "metadata_eval.csv") - df_eval = df_eval.sort_values('audio_file') + df_eval = df_eval.sort_values("audio_file") df_eval.to_csv(eval_metadata_path, sep="|", index=False) # deallocate VRAM and RAM del asr_model, df_train, df_eval, df, metadata gc.collect() - return train_metadata_path, eval_metadata_path, audio_total_size \ No newline at end of file + return train_metadata_path, eval_metadata_path, audio_total_size diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py index a98765c3e7..7b41966b8f 100644 --- a/TTS/demos/xtts_ft_demo/utils/gpt_train.py +++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py @@ -1,5 +1,5 @@ -import os import gc +import os from trainer import Trainer, TrainerArgs @@ -25,7 +25,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, BATCH_SIZE = batch_size # set here the batch size GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps - # Define here the dataset that you want to use for the fine-tuning on. config_dataset = BaseDatasetConfig( formatter="coqui", @@ -43,7 +42,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/") os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True) - # DVAE files DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth" MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth" @@ -55,8 +53,9 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, # download DVAE files if needed if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE): print(" > Downloading DVAE files!") - ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True) - + ModelManager._download_model_files( + [MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True + ) # Download XTTS v2.0 checkpoint if needed TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json" @@ -160,7 +159,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, # get the longest text audio file to use as speaker reference samples_len = [len(item["text"].split(" ")) for item in train_samples] - longest_text_idx = samples_len.index(max(samples_len)) + longest_text_idx = samples_len.index(max(samples_len)) speaker_ref = train_samples[longest_text_idx]["audio_file"] trainer_out_path = trainer.output_path diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py index ebb11f29d1..85168c641d 100644 --- a/TTS/demos/xtts_ft_demo/xtts_demo.py +++ b/TTS/demos/xtts_ft_demo/xtts_demo.py @@ -1,19 +1,16 @@ import argparse +import logging import os import sys import tempfile +import traceback import gradio as gr -import librosa.display -import numpy as np - -import os import torch import torchaudio -import traceback + from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt - from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.models.xtts import Xtts @@ -23,7 +20,10 @@ def clear_gpu_cache(): if torch.cuda.is_available(): torch.cuda.empty_cache() + XTTS_MODEL = None + + def load_model(xtts_checkpoint, xtts_config, xtts_vocab): global XTTS_MODEL clear_gpu_cache() @@ -40,17 +40,23 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab): print("Model Loaded!") return "Model Loaded!" + def run_tts(lang, tts_text, speaker_audio_file): if XTTS_MODEL is None or not speaker_audio_file: return "You need to run the previous step to load the model !!", None, None - gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs) + gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents( + audio_path=speaker_audio_file, + gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, + max_ref_length=XTTS_MODEL.config.max_ref_len, + sound_norm_refs=XTTS_MODEL.config.sound_norm_refs, + ) out = XTTS_MODEL.inference( text=tts_text, language=lang, gpt_cond_latent=gpt_cond_latent, speaker_embedding=speaker_embedding, - temperature=XTTS_MODEL.config.temperature, # Add custom parameters here + temperature=XTTS_MODEL.config.temperature, # Add custom parameters here length_penalty=XTTS_MODEL.config.length_penalty, repetition_penalty=XTTS_MODEL.config.repetition_penalty, top_k=XTTS_MODEL.config.top_k, @@ -65,9 +71,7 @@ def run_tts(lang, tts_text, speaker_audio_file): return "Speech generated !", out_path, speaker_audio_file - - -# define a logger to redirect +# define a logger to redirect class Logger: def __init__(self, filename="log.out"): self.log_file = filename @@ -85,21 +89,19 @@ def flush(self): def isatty(self): return False + # redirect stdout and stderr to a file sys.stdout = Logger() sys.stderr = sys.stdout # logging.basicConfig(stream=sys.stdout, level=logging.INFO) -import logging + logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", - handlers=[ - logging.StreamHandler(sys.stdout) - ] + level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)] ) + def read_logs(): sys.stdout.flush() with open(sys.stdout.log_file, "r") as f: @@ -107,12 +109,11 @@ def read_logs(): if __name__ == "__main__": - parser = argparse.ArgumentParser( description="""XTTS fine-tuning demo\n\n""" """ Example runs: - python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port + python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port """, formatter_class=argparse.RawTextHelpFormatter, ) @@ -190,12 +191,10 @@ def read_logs(): "zh", "hu", "ko", - "ja" + "ja", ], ) - progress_data = gr.Label( - label="Progress:" - ) + progress_data = gr.Label(label="Progress:") logs = gr.Textbox( label="Logs:", interactive=False, @@ -203,20 +202,30 @@ def read_logs(): demo.load(read_logs, None, logs, every=1) prompt_compute_btn = gr.Button(value="Step 1 - Create dataset") - + def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)): clear_gpu_cache() out_path = os.path.join(out_path, "dataset") os.makedirs(out_path, exist_ok=True) if audio_path is None: - return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", "" + return ( + "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", + "", + "", + ) else: try: - train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress) + train_meta, eval_meta, audio_total_size = format_audio_list( + audio_path, target_language=language, out_path=out_path, gradio_progress=progress + ) except: traceback.print_exc() error = traceback.format_exc() - return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", "" + return ( + f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", + "", + "", + ) clear_gpu_cache() @@ -236,7 +245,7 @@ def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(trac eval_csv = gr.Textbox( label="Eval CSV:", ) - num_epochs = gr.Slider( + num_epochs = gr.Slider( label="Number of epochs:", minimum=1, maximum=100, @@ -264,9 +273,7 @@ def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(trac step=1, value=args.max_audio_length, ) - progress_train = gr.Label( - label="Progress:" - ) + progress_train = gr.Label(label="Progress:") logs_tts_train = gr.Textbox( label="Logs:", interactive=False, @@ -274,18 +281,41 @@ def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(trac demo.load(read_logs, None, logs_tts_train, every=1) train_btn = gr.Button(value="Step 2 - Run the training") - def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length): + def train_model( + language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length + ): clear_gpu_cache() if not train_csv or not eval_csv: - return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", "" + return ( + "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", + "", + "", + "", + "", + ) try: # convert seconds to waveform frames max_audio_length = int(max_audio_length * 22050) - config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length) + config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt( + language, + num_epochs, + batch_size, + grad_acumm, + train_csv, + eval_csv, + output_path=output_path, + max_audio_length=max_audio_length, + ) except: traceback.print_exc() error = traceback.format_exc() - return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", "" + return ( + f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", + "", + "", + "", + "", + ) # copy original files to avoid parameters changes issues os.system(f"cp {config_path} {exp_path}") @@ -312,9 +342,7 @@ def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acum label="XTTS vocab path:", value="", ) - progress_load = gr.Label( - label="Progress:" - ) + progress_load = gr.Label(label="Progress:") load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model") with gr.Column() as col2: @@ -342,7 +370,7 @@ def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acum "hu", "ko", "ja", - ] + ], ) tts_text = gr.Textbox( label="Input Text.", @@ -351,9 +379,7 @@ def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acum tts_btn = gr.Button(value="Step 4 - Inference") with gr.Column() as col3: - progress_gen = gr.Label( - label="Progress:" - ) + progress_gen = gr.Label(label="Progress:") tts_output_audio = gr.Audio(label="Generated Audio.") reference_audio = gr.Audio(label="Reference audio used.") @@ -371,7 +397,6 @@ def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acum ], ) - train_btn.click( fn=train_model, inputs=[ @@ -386,14 +411,10 @@ def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acum ], outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio], ) - + load_btn.click( fn=load_model, - inputs=[ - xtts_checkpoint, - xtts_config, - xtts_vocab - ], + inputs=[xtts_checkpoint, xtts_config, xtts_vocab], outputs=[progress_load], ) @@ -407,9 +428,4 @@ def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acum outputs=[progress_gen, tts_output_audio, reference_audio], ) - demo.launch( - share=True, - debug=False, - server_port=args.port, - server_name="0.0.0.0" - ) + demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0") diff --git a/TTS/encoder/configs/emotion_encoder_config.py b/TTS/encoder/configs/emotion_encoder_config.py index 5eda2671be..1d12325cf2 100644 --- a/TTS/encoder/configs/emotion_encoder_config.py +++ b/TTS/encoder/configs/emotion_encoder_config.py @@ -1,4 +1,4 @@ -from dataclasses import asdict, dataclass +from dataclasses import dataclass from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig diff --git a/TTS/encoder/configs/speaker_encoder_config.py b/TTS/encoder/configs/speaker_encoder_config.py index 6dceb00277..0588527a68 100644 --- a/TTS/encoder/configs/speaker_encoder_config.py +++ b/TTS/encoder/configs/speaker_encoder_config.py @@ -1,4 +1,4 @@ -from dataclasses import asdict, dataclass +from dataclasses import dataclass from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig diff --git a/TTS/encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py index 236d6fe937..88ed71d3f4 100644 --- a/TTS/encoder/utils/generic_utils.py +++ b/TTS/encoder/utils/generic_utils.py @@ -34,7 +34,7 @@ def __init__(self, ap, augmentation_config): # ignore not listed directories if noise_dir not in self.additive_noise_types: continue - if not noise_dir in self.noise_list: + if noise_dir not in self.noise_list: self.noise_list[noise_dir] = [] self.noise_list[noise_dir].append(wav_file) diff --git a/TTS/tts/layers/bark/hubert/kmeans_hubert.py b/TTS/tts/layers/bark/hubert/kmeans_hubert.py index a6a3b9aeb1..9e487b1e9d 100644 --- a/TTS/tts/layers/bark/hubert/kmeans_hubert.py +++ b/TTS/tts/layers/bark/hubert/kmeans_hubert.py @@ -7,8 +7,6 @@ # Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py -import logging -from pathlib import Path import torch from einops import pack, unpack diff --git a/TTS/tts/layers/delightful_tts/acoustic_model.py b/TTS/tts/layers/delightful_tts/acoustic_model.py index c906b882e5..74ec204281 100644 --- a/TTS/tts/layers/delightful_tts/acoustic_model.py +++ b/TTS/tts/layers/delightful_tts/acoustic_model.py @@ -362,7 +362,7 @@ def forward( pos_encoding = positional_encoding( self.emb_dim, - max(token_embeddings.shape[1], max(mel_lens)), + max(token_embeddings.shape[1], *mel_lens), device=token_embeddings.device, ) encoder_outputs = self.encoder( diff --git a/TTS/tts/layers/overflow/plotting_utils.py b/TTS/tts/layers/overflow/plotting_utils.py index a63aeb370a..d9d3e3d141 100644 --- a/TTS/tts/layers/overflow/plotting_utils.py +++ b/TTS/tts/layers/overflow/plotting_utils.py @@ -71,7 +71,7 @@ def plot_transition_probabilities_to_numpy(states, transition_probabilities, out ax.set_title("Transition probability of state") ax.set_xlabel("hidden state") ax.set_ylabel("probability") - ax.set_xticks([i for i in range(len(transition_probabilities))]) # pylint: disable=unnecessary-comprehension + ax.set_xticks(list(range(len(transition_probabilities)))) ax.set_xticklabels([int(x) for x in states], rotation=90) plt.tight_layout() if not output_fig: diff --git a/TTS/tts/layers/tortoise/arch_utils.py b/TTS/tts/layers/tortoise/arch_utils.py index dad1814369..c79ef31b0c 100644 --- a/TTS/tts/layers/tortoise/arch_utils.py +++ b/TTS/tts/layers/tortoise/arch_utils.py @@ -1,6 +1,5 @@ import functools import math -import os import fsspec import torch diff --git a/TTS/tts/layers/tortoise/clvp.py b/TTS/tts/layers/tortoise/clvp.py index 69b8c17c3f..241dfdd4f4 100644 --- a/TTS/tts/layers/tortoise/clvp.py +++ b/TTS/tts/layers/tortoise/clvp.py @@ -126,7 +126,7 @@ def forward(self, text, speech_tokens, return_loss=False): text_latents = self.to_text_latent(text_latents) speech_latents = self.to_speech_latent(speech_latents) - text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents)) + text_latents, speech_latents = (F.normalize(t, p=2, dim=-1) for t in (text_latents, speech_latents)) temp = self.temperature.exp() diff --git a/TTS/tts/layers/tortoise/diffusion.py b/TTS/tts/layers/tortoise/diffusion.py index 7bea02ca08..2b29091b44 100644 --- a/TTS/tts/layers/tortoise/diffusion.py +++ b/TTS/tts/layers/tortoise/diffusion.py @@ -972,7 +972,7 @@ def autoregressive_training_losses( assert False # not currently supported for this type of diffusion. elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs) - terms.update({k: o for k, o in zip(model_output_keys, model_outputs)}) + terms.update(dict(zip(model_output_keys, model_outputs))) model_output = terms[gd_out_key] if self.model_var_type in [ ModelVarType.LEARNED, diff --git a/TTS/tts/layers/tortoise/transformer.py b/TTS/tts/layers/tortoise/transformer.py index 70d46aa3e0..6cb1bab96a 100644 --- a/TTS/tts/layers/tortoise/transformer.py +++ b/TTS/tts/layers/tortoise/transformer.py @@ -37,7 +37,7 @@ def route_args(router, args, depth): for key in matched_keys: val = args[key] for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): - new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) + new_f_args, new_g_args = (({key: val} if route else {}) for route in routes) routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) return routed_args @@ -152,7 +152,7 @@ def forward(self, x, mask=None): softmax = torch.softmax qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv) + q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in qkv) q = q * self.scale diff --git a/TTS/tts/layers/tortoise/xtransformers.py b/TTS/tts/layers/tortoise/xtransformers.py index 1eb3f77269..9325b8c720 100644 --- a/TTS/tts/layers/tortoise/xtransformers.py +++ b/TTS/tts/layers/tortoise/xtransformers.py @@ -84,7 +84,7 @@ def init_zero_(layer): def pick_and_pop(keys, d): - values = list(map(lambda key: d.pop(key), keys)) + values = [d.pop(key) for key in keys] return dict(zip(keys, values)) @@ -107,7 +107,7 @@ def group_by_key_prefix(prefix, d): def groupby_prefix_and_trim(prefix, d): kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) - kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))) + kwargs_without_prefix = {x[0][len(prefix) :]: x[1] for x in tuple(kwargs_with_prefix.items())} return kwargs_without_prefix, kwargs @@ -428,7 +428,7 @@ def forward(self, x, **kwargs): feats_per_shift = x.shape[-1] // segments splitted = x.split(feats_per_shift, dim=-1) segments_to_shift, rest = splitted[:segments], splitted[segments:] - segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts))) + segments_to_shift = [shift(*args, mask=mask) for args in zip(segments_to_shift, shifts)] x = torch.cat((*segments_to_shift, *rest), dim=-1) return self.fn(x, **kwargs) @@ -635,7 +635,7 @@ def forward( v = self.to_v(v_input) if not collab_heads: - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v)) else: q = einsum("b i d, h d -> b h i d", q, self.collab_mixing) k = rearrange(k, "b n d -> b () n d") @@ -650,9 +650,9 @@ def forward( if exists(rotary_pos_emb) and not has_context: l = rotary_pos_emb.shape[-1] - (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) - ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl)) - q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) + (ql, qr), (kl, kr), (vl, vr) = ((t[..., :l], t[..., l:]) for t in (q, k, v)) + ql, kl, vl = (apply_rotary_pos_emb(t, rotary_pos_emb) for t in (ql, kl, vl)) + q, k, v = (torch.cat(t, dim=-1) for t in ((ql, qr), (kl, kr), (vl, vr))) input_mask = None if any(map(exists, (mask, context_mask))): @@ -664,7 +664,7 @@ def forward( input_mask = q_mask * k_mask if self.num_mem_kv > 0: - mem_k, mem_v = map(lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v)) + mem_k, mem_v = (repeat(t, "h n d -> b h n d", b=b) for t in (self.mem_k, self.mem_v)) k = torch.cat((mem_k, k), dim=-2) v = torch.cat((mem_v, v), dim=-2) if exists(input_mask): @@ -964,9 +964,7 @@ def forward( seq_len = x.shape[1] if past_key_values is not None: seq_len += past_key_values[0][0].shape[-2] - max_rotary_emb_length = max( - list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len] - ) + max_rotary_emb_length = max([(m.shape[1] if exists(m) else 0) + seq_len for m in mems] + [expected_seq_len]) rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) present_key_values = [] @@ -1200,7 +1198,7 @@ def forward( res = [out] if return_attn: - attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates] res.append(attn_maps) if use_cache: res.append(intermediates.past_key_values) @@ -1249,7 +1247,7 @@ def forward(self, x, return_embeddings=False, mask=None, return_attn=False, mems res = [out] if return_attn: - attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates] res.append(attn_maps) if use_cache: res.append(intermediates.past_key_values) diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py index c27d11bef6..3449739fdc 100644 --- a/TTS/tts/layers/vits/discriminator.py +++ b/TTS/tts/layers/vits/discriminator.py @@ -2,7 +2,7 @@ from torch import nn from torch.nn.modules.conv import Conv1d -from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator +from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP class DiscriminatorS(torch.nn.Module): diff --git a/TTS/tts/layers/xtts/dvae.py b/TTS/tts/layers/xtts/dvae.py index bdd7a9d09f..8598f0b47a 100644 --- a/TTS/tts/layers/xtts/dvae.py +++ b/TTS/tts/layers/xtts/dvae.py @@ -260,7 +260,7 @@ def __init__( dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0] dec_chans = [dec_init_chan, *dec_chans] - enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans)) + enc_chans_io, dec_chans_io = (list(zip(t[:-1], t[1:])) for t in (enc_chans, dec_chans)) pad = (kernel_size - 1) // 2 for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): @@ -306,9 +306,9 @@ def norm(self, images): if not self.normalization is not None: return images - means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization) + means, stds = (torch.as_tensor(t).to(images) for t in self.normalization) arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()" - means, stds = map(lambda t: rearrange(t, arrange), (means, stds)) + means, stds = (rearrange(t, arrange) for t in (means, stds)) images = images.clone() images.sub_(means).div_(stds) return images diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index e7b186b858..ca0dc7cc74 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -1,7 +1,6 @@ # ported from: https://github.com/neonbjb/tortoise-tts import functools -import math import random import torch diff --git a/TTS/tts/layers/xtts/gpt_inference.py b/TTS/tts/layers/xtts/gpt_inference.py index d44bd3decd..4625ae1ba9 100644 --- a/TTS/tts/layers/xtts/gpt_inference.py +++ b/TTS/tts/layers/xtts/gpt_inference.py @@ -1,5 +1,3 @@ -import math - import torch from torch import nn from transformers import GPT2PreTrainedModel diff --git a/TTS/tts/layers/xtts/perceiver_encoder.py b/TTS/tts/layers/xtts/perceiver_encoder.py index 7b7ee79b50..d1aa16c456 100644 --- a/TTS/tts/layers/xtts/perceiver_encoder.py +++ b/TTS/tts/layers/xtts/perceiver_encoder.py @@ -155,10 +155,6 @@ def Sequential(*mods): return nn.Sequential(*filter(exists, mods)) -def exists(x): - return x is not None - - def default(val, d): if exists(val): return val diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py index e12f8995cf..b7e07589c5 100644 --- a/TTS/tts/layers/xtts/stream_generator.py +++ b/TTS/tts/layers/xtts/stream_generator.py @@ -43,7 +43,7 @@ def __init__(self, **kwargs): class NewGenerationMixin(GenerationMixin): @torch.no_grad() - def generate( + def generate( # noqa: PLR0911 self, inputs: Optional[torch.Tensor] = None, generation_config: Optional[StreamGenerationConfig] = None, @@ -885,10 +885,10 @@ def init_stream_support(): if __name__ == "__main__": - from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel + from transformers import AutoModelForCausalLM, AutoTokenizer + + init_stream_support() - PreTrainedModel.generate = NewGenerationMixin.generate - PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16) tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py index 2f958cb5a5..4d6d6ede6e 100644 --- a/TTS/tts/layers/xtts/trainer/dataset.py +++ b/TTS/tts/layers/xtts/trainer/dataset.py @@ -1,4 +1,3 @@ -import os import random import sys diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 9a7a1d7783..daf9fc7e4f 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -5,7 +5,6 @@ import torch.nn as nn import torchaudio from coqpit import Coqpit -from torch.nn import functional as F from torch.utils.data import DataLoader from trainer.torch import DistributedSampler from trainer.trainer_utils import get_optimizer, get_scheduler @@ -391,7 +390,7 @@ def get_data_loader( loader = DataLoader( dataset, sampler=sampler, - batch_size = config.eval_batch_size if is_eval else config.batch_size, + batch_size=config.eval_batch_size if is_eval else config.batch_size, collate_fn=dataset.collate_fn, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=False, diff --git a/TTS/tts/layers/xtts/xtts_manager.py b/TTS/tts/layers/xtts/xtts_manager.py index 3e7d0f6c91..5a28d2a8a6 100644 --- a/TTS/tts/layers/xtts/xtts_manager.py +++ b/TTS/tts/layers/xtts/xtts_manager.py @@ -1,34 +1,35 @@ import torch -class SpeakerManager(): + +class SpeakerManager: def __init__(self, speaker_file_path=None): self.speakers = torch.load(speaker_file_path) @property def name_to_id(self): return self.speakers.keys() - + @property def num_speakers(self): return len(self.name_to_id) - + @property def speaker_names(self): return list(self.name_to_id.keys()) - -class LanguageManager(): + +class LanguageManager: def __init__(self, config): self.langs = config["languages"] @property def name_to_id(self): return self.langs - + @property def num_languages(self): return len(self.name_to_id) - + @property def language_names(self): return list(self.name_to_id) diff --git a/TTS/tts/layers/xtts/zh_num2words.py b/TTS/tts/layers/xtts/zh_num2words.py index e59ccb6630..7d8f658160 100644 --- a/TTS/tts/layers/xtts/zh_num2words.py +++ b/TTS/tts/layers/xtts/zh_num2words.py @@ -4,13 +4,11 @@ import argparse import csv -import os import re import string import sys # fmt: off - # ================================================================================ # # basic constant # ================================================================================ # @@ -491,8 +489,6 @@ class NumberSystem(object): 中文数字系统 """ - pass - class MathSymbol(object): """ diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index b2e51de7d6..18b9cde385 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -415,7 +415,7 @@ def _set_phase(config, global_step): """Decide AlignTTS training phase""" if isinstance(config.phase_start_steps, list): vals = [i < global_step for i in config.phase_start_steps] - if not True in vals: + if True not in vals: phase = 0 else: phase = ( diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 7871cc38c3..be76f6c2d3 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -14,7 +14,7 @@ from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.data import get_length_balancer_weights from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights -from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index b6e9ac8a14..1d3a13d433 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -299,7 +299,7 @@ def init_multispeaker(self, config: Coqpit): if config.use_d_vector_file: self.embedded_speaker_dim = config.d_vector_dim if self.args.d_vector_dim != self.args.hidden_channels: - #self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) + # self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels) # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: @@ -404,13 +404,13 @@ def _forward_encoder( # [B, T, C] x_emb = self.emb(x) # encoder pass - #o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) + # o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g) # speaker conditioning # TODO: try different ways of conditioning - if g is not None: + if g is not None: if hasattr(self, "proj_g"): - g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1) + g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1) o_en = o_en + g return o_en, x_mask, g, x_emb diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index d9b1f59618..2c60ece789 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1880,7 +1880,7 @@ def onnx_inference(text, text_lengths, scales, sid=None, langid=None): self.forward = _forward if training: self.train() - if not disc is None: + if disc is not None: self.disc = disc def load_onnx(self, model_path: str, cuda=False): @@ -1914,9 +1914,9 @@ def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None): dtype=np.float32, ) input_params = {"input": x, "input_lengths": x_lengths, "scales": scales} - if not speaker_id is None: + if speaker_id is not None: input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy() - if not language_id is None: + if language_id is not None: input_params["langid"] = torch.tensor([language_id]).cpu().numpy() audio = self.onnx_sess.run( @@ -1948,8 +1948,7 @@ def __init__( def _create_vocab(self): self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank] self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} - # pylint: disable=unnecessary-comprehension - self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)} + self._id_to_char = dict(enumerate(self.vocab)) @staticmethod def init_from_config(config: Coqpit): @@ -1996,4 +1995,4 @@ def vocab(self, vocab_file): self.blank = self._vocab[0] self.pad = " " self._char_to_id = {s: i for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension - self._id_to_char = {i: s for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension + self._id_to_char = dict(enumerate(self._vocab)) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 83812f377f..a6e9aefa5d 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -11,7 +11,7 @@ from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence -from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager +from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager from TTS.tts.models.base_tts import BaseTTS from TTS.utils.io import load_fsspec @@ -274,7 +274,7 @@ def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = for i in range(0, audio.shape[1], 22050 * chunk_length): audio_chunk = audio[:, i : i + 22050 * chunk_length] - # if the chunk is too short ignore it + # if the chunk is too short ignore it if audio_chunk.size(-1) < 22050 * 0.33: continue @@ -410,12 +410,14 @@ def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwa if speaker_id is not None: gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values() return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings) - settings.update({ - "gpt_cond_len": config.gpt_cond_len, - "gpt_cond_chunk_len": config.gpt_cond_chunk_len, - "max_ref_len": config.max_ref_len, - "sound_norm_refs": config.sound_norm_refs, - }) + settings.update( + { + "gpt_cond_len": config.gpt_cond_len, + "gpt_cond_chunk_len": config.gpt_cond_chunk_len, + "max_ref_len": config.max_ref_len, + "sound_norm_refs": config.sound_norm_refs, + } + ) return self.full_inference(text, speaker_wav, language, **settings) @torch.inference_mode() diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index 1e1836b32c..89e5e1911e 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -59,7 +59,7 @@ def parse_language_ids_from_config(c: Coqpit) -> Dict: languages.add(dataset["language"]) else: raise ValueError(f"Dataset {dataset['name']} has no language specified.") - return {name: i for i, name in enumerate(sorted(list(languages)))} + return {name: i for i, name in enumerate(sorted(languages))} def set_language_ids_from_config(self, c: Coqpit) -> None: """Set language IDs from config samples. diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py index 1f94c5332d..23aa52a8a2 100644 --- a/TTS/tts/utils/managers.py +++ b/TTS/tts/utils/managers.py @@ -193,7 +193,7 @@ def read_embeddings_from_file(file_path: str): embeddings = load_file(file_path) speakers = sorted({x["name"] for x in embeddings.values()}) name_to_id = {name: i for i, name in enumerate(speakers)} - clip_ids = list(set(sorted(clip_name for clip_name in embeddings.keys()))) + clip_ids = list(set(clip_name for clip_name in embeddings.keys())) # cache embeddings_by_names for fast inference using a bigger speakers.json embeddings_by_names = {} for x in embeddings.values(): diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index 8fa45ed84b..37c7a7ca23 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -87,9 +87,7 @@ def vocab(self, vocab): if vocab is not None: self._vocab = vocab self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)} - self._id_to_char = { - idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension - } + self._id_to_char = dict(enumerate(self._vocab)) @staticmethod def init_from_config(config, **kwargs): @@ -269,9 +267,7 @@ def vocab(self): def vocab(self, vocab): self._vocab = vocab self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} - self._id_to_char = { - idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension - } + self._id_to_char = dict(enumerate(self.vocab)) @property def num_chars(self): diff --git a/TTS/tts/utils/text/japanese/phonemizer.py b/TTS/tts/utils/text/japanese/phonemizer.py index c3111067e1..30072ae501 100644 --- a/TTS/tts/utils/text/japanese/phonemizer.py +++ b/TTS/tts/utils/text/japanese/phonemizer.py @@ -350,8 +350,8 @@ def hira2kata(text: str) -> str: return text.replace("う゛", "ヴ") -_SYMBOL_TOKENS = set(list("・、。?!")) -_NO_YOMI_TOKENS = set(list("「」『』―()[][] …")) +_SYMBOL_TOKENS = set("・、。?!") +_NO_YOMI_TOKENS = set("「」『』―()[][] …") _TAGGER = MeCab.Tagger() diff --git a/TTS/tts/utils/text/phonemizers/__init__.py b/TTS/tts/utils/text/phonemizers/__init__.py index f9a0340c55..744ccb3e70 100644 --- a/TTS/tts/utils/text/phonemizers/__init__.py +++ b/TTS/tts/utils/text/phonemizers/__init__.py @@ -10,7 +10,6 @@ from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer except ImportError: JA_JP_Phonemizer = None - pass PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, KO_KR_Phonemizer, BN_Phonemizer)} diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 3a527f4609..d724cc87ec 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -516,7 +516,7 @@ def _update_path(field_name, new_path, config_path): sub_conf[field_names[-1]] = new_path else: # field name points to a top-level field - if not field_name in config: + if field_name not in config: return if isinstance(config[field_name], list): config[field_name] = [new_path] diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index b98647c30c..6165fb5e8a 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -335,7 +335,7 @@ def tts( # handle multi-lingual language_id = None if self.tts_languages_file or ( - hasattr(self.tts_model, "language_manager") + hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None and not self.tts_config.model == "xtts" ): diff --git a/TTS/vc/configs/shared_configs.py b/TTS/vc/configs/shared_configs.py index 74164a7444..b2fe63d29d 100644 --- a/TTS/vc/configs/shared_configs.py +++ b/TTS/vc/configs/shared_configs.py @@ -1,7 +1,5 @@ -from dataclasses import asdict, dataclass, field -from typing import Dict, List - -from coqpit import Coqpit, check_argument +from dataclasses import dataclass, field +from typing import List from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py index 8bb9989224..a5a340f2aa 100644 --- a/TTS/vc/models/freevc.py +++ b/TTS/vc/models/freevc.py @@ -164,7 +164,7 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): super(DiscriminatorP, self).__init__() self.period = period self.use_spectral_norm = use_spectral_norm - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList( [ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), @@ -201,7 +201,7 @@ def forward(self, x): class DiscriminatorS(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList( [ norm_f(Conv1d(1, 16, 15, 1, padding=7)), @@ -468,7 +468,7 @@ def inference(self, c, g=None, mel=None, c_lengths=None): Returns: torch.Tensor: Output tensor. """ - if c_lengths == None: + if c_lengths is None: c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) if not self.use_spk: g = self.enc_spk.embed_utterance(mel) diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/modules/freevc/commons.py index e799cc2a5b..e5fb13c11c 100644 --- a/TTS/vc/modules/freevc/commons.py +++ b/TTS/vc/modules/freevc/commons.py @@ -1,8 +1,6 @@ import math -import numpy as np import torch -from torch import nn from torch.nn import functional as F diff --git a/TTS/vc/modules/freevc/speaker_encoder/audio.py b/TTS/vc/modules/freevc/speaker_encoder/audio.py index 52f6fd0893..5b23a4dbb6 100644 --- a/TTS/vc/modules/freevc/speaker_encoder/audio.py +++ b/TTS/vc/modules/freevc/speaker_encoder/audio.py @@ -1,13 +1,17 @@ -import struct from pathlib import Path from typing import Optional, Union # import webrtcvad import librosa import numpy as np -from scipy.ndimage.morphology import binary_dilation -from TTS.vc.modules.freevc.speaker_encoder.hparams import * +from TTS.vc.modules.freevc.speaker_encoder.hparams import ( + audio_norm_target_dBFS, + mel_n_channels, + mel_window_length, + mel_window_step, + sampling_rate, +) int16_max = (2**15) - 1 diff --git a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py b/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py index 2e21a14fd8..7f811ac3ab 100644 --- a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py +++ b/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py @@ -1,4 +1,3 @@ -from pathlib import Path from time import perf_counter as timer from typing import List, Union @@ -8,7 +7,15 @@ from TTS.utils.io import load_fsspec from TTS.vc.modules.freevc.speaker_encoder import audio -from TTS.vc.modules.freevc.speaker_encoder.hparams import * +from TTS.vc.modules.freevc.speaker_encoder.hparams import ( + mel_n_channels, + mel_window_step, + model_embedding_size, + model_hidden_size, + model_num_layers, + partials_n_frames, + sampling_rate, +) class SpeakerEncoder(nn.Module): diff --git a/TTS/vc/modules/freevc/wavlm/wavlm.py b/TTS/vc/modules/freevc/wavlm/wavlm.py index fc93bd4f50..d2f28d19c2 100644 --- a/TTS/vc/modules/freevc/wavlm/wavlm.py +++ b/TTS/vc/modules/freevc/wavlm/wavlm.py @@ -387,7 +387,7 @@ def make_conv(): nn.init.kaiming_normal_(conv.weight) return conv - assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive" + assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive" if is_layer_norm: return nn.Sequential( diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py index 74cfc7262b..1f977755cc 100644 --- a/TTS/vocoder/layers/losses.py +++ b/TTS/vocoder/layers/losses.py @@ -298,7 +298,7 @@ def forward( adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss # Feature Matching Loss - if self.use_feat_match_loss and not feats_fake is None: + if self.use_feat_match_loss and feats_fake is not None: feat_match_loss = self.feat_match_loss(feats_fake, feats_real) return_dict["G_feat_match_loss"] = feat_match_loss adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 63a0af4445..113240fd75 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -40,7 +40,7 @@ def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_ Returns: Dict: output figures keyed by the name of the figures. - """ """Plot vocoder model results""" + """ if name_prefix is None: name_prefix = "" diff --git a/pyproject.toml b/pyproject.toml index 922575305c..934e0c2ebd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,8 +7,59 @@ requires = [ "packaging", ] -[flake8] -max-line-length=120 +[tool.ruff] +line-length = 120 +extend-select = [ + "B033", # duplicate-value + "C416", # unnecessary-comprehension + "D419", # empty-docstring + "E999", # syntax-error + "F401", # unused-import + "F704", # yield-outside-function + "F706", # return-outside-function + "F841", # unused-variable + "I", # import sorting + "PIE790", # unnecessary-pass + "PLC", + "PLE", + "PLR0124", # comparison-with-itself + "PLR0206", # property-with-parameters + "PLR0911", # too-many-return-statements + "PLR1711", # useless-return + "PLW", + "W291", # trailing-whitespace +] + +ignore = [ + "E501", # line too long + "E722", # bare except (TODO: fix these) + "E731", # don't use lambdas + "E741", # ambiguous variable name + "PLR0912", # too-many-branches + "PLR0913", # too-many-arguments + "PLR0915", # too-many-statements + "UP004", # useless-object-inheritance + "F821", # TODO: enable + "F841", # TODO: enable + "PLW0602", # TODO: enable + "PLW2901", # TODO: enable + "PLW0127", # TODO: enable + "PLW0603", # TODO: enable +] + +[tool.ruff.pylint] +max-args = 5 +max-public-methods = 20 +max-returns = 7 + +[tool.ruff.per-file-ignores] +"**/__init__.py" = [ + "F401", # init files may have "unused" imports for now + "F403", # init files may have star imports for now +] +"hubconf.py" = [ + "E402", # module level import not at top of file +] [tool.black] line-length = 120 diff --git a/recipes/bel-alex73/train_hifigan.py b/recipes/bel-alex73/train_hifigan.py index 3e740b2ff4..78221a9f2b 100644 --- a/recipes/bel-alex73/train_hifigan.py +++ b/recipes/bel-alex73/train_hifigan.py @@ -1,11 +1,8 @@ -import os - -from coqpit import Coqpit from trainer import Trainer, TrainerArgs from TTS.tts.configs.shared_configs import BaseAudioConfig from TTS.utils.audio import AudioProcessor -from TTS.vocoder.configs.hifigan_config import * +from TTS.vocoder.configs.hifigan_config import HifiganConfig from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.models.gan import GAN diff --git a/recipes/multilingual/cml_yourtts/train_yourtts.py b/recipes/multilingual/cml_yourtts/train_yourtts.py index 25a2fd0a4b..02f901fe73 100644 --- a/recipes/multilingual/cml_yourtts/train_yourtts.py +++ b/recipes/multilingual/cml_yourtts/train_yourtts.py @@ -4,7 +4,6 @@ from trainer import Trainer, TrainerArgs from TTS.bin.compute_embeddings import compute_embeddings -from TTS.bin.resample import resample_files from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples diff --git a/requirements.dev.txt b/requirements.dev.txt index 8c674727d3..21c4c3d21e 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -2,4 +2,4 @@ black coverage isort nose2 -pylint==2.10.2 +ruff==0.1.3 diff --git a/setup.py b/setup.py index df14b41adc..b01b655877 100644 --- a/setup.py +++ b/setup.py @@ -23,12 +23,12 @@ import os import subprocess import sys -from packaging.version import Version import numpy import setuptools.command.build_py import setuptools.command.develop from Cython.Build import cythonize +from packaging.version import Version from setuptools import Extension, find_packages, setup python_version = sys.version.split()[0] diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index cbd98fc0c5..172ee7cef3 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -8,7 +8,8 @@ from tests import get_tests_data_path, get_tests_output_path from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig -from TTS.tts.datasets import TTSDataset, load_tts_samples +from TTS.tts.datasets import load_tts_samples +from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor diff --git a/tests/tts_tests/test_tacotron2_model.py b/tests/tts_tests/test_tacotron2_model.py index b1bdeb9fd1..72b6bcd46b 100644 --- a/tests/tts_tests/test_tacotron2_model.py +++ b/tests/tts_tests/test_tacotron2_model.py @@ -278,7 +278,7 @@ def test_train_step(): }, ) - batch = dict({}) + batch = {} batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device) batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device) batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0] diff --git a/tests/tts_tests/test_tacotron_model.py b/tests/tts_tests/test_tacotron_model.py index 906ec3d09f..2ca068f6fe 100644 --- a/tests/tts_tests/test_tacotron_model.py +++ b/tests/tts_tests/test_tacotron_model.py @@ -266,7 +266,7 @@ def test_train_step(): }, ) - batch = dict({}) + batch = {} batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device) batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device) batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0] diff --git a/tests/vc_tests/test_freevc.py b/tests/vc_tests/test_freevc.py index a4a4f72679..3755ab3f06 100644 --- a/tests/vc_tests/test_freevc.py +++ b/tests/vc_tests/test_freevc.py @@ -4,8 +4,7 @@ import torch from tests import get_tests_input_path -from TTS.vc.configs.freevc_config import FreeVCConfig -from TTS.vc.models.freevc import FreeVC +from TTS.vc.models.freevc import FreeVC, FreeVCConfig # pylint: disable=unused-variable # pylint: disable=no-self-use