Skip to content

Commit

Permalink
Fixed regression regarding multiple changes in one file.
Browse files Browse the repository at this point in the history
Changed the method of marking changes from a dict keyed by the file name to a list of FileChanges.

FileChanges encapsulate a single change to a file.
  • Loading branch information
coordt committed Dec 15, 2023
1 parent d1d19e3 commit e7a7629
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 51 deletions.
34 changes: 22 additions & 12 deletions bumpversion/config/models.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
"""Bump My Version configuration models."""
from __future__ import annotations

import logging
import re
from collections import defaultdict
from itertools import chain
from typing import TYPE_CHECKING, Dict, List, MutableMapping, Optional, Tuple, Union

from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict

from bumpversion.ui import get_indented_logger

if TYPE_CHECKING:
from bumpversion.scm import SCMInfo
from bumpversion.version_part import VersionConfig

logger = logging.getLogger(__name__)
logger = get_indented_logger(__name__)


class VersionPartConfig(BaseModel):
Expand Down Expand Up @@ -48,23 +51,29 @@ def get_search_pattern(self, context: MutableMapping) -> Tuple[re.Pattern, str]:
Returns:
A tuple of the compiled regex pattern and the raw pattern as a string.
"""
logger.debug("Rendering search pattern with context")
logger.indent()
# the default search pattern is escaped, so we can still use it in a regex
raw_pattern = self.search.format(**context)
default = re.compile(re.escape(raw_pattern), re.MULTILINE | re.DOTALL)
if not self.regex:
logger.debug("No RegEx flag detected. Searching for the default pattern: '%s'", default.pattern)
logger.dedent()
return default, raw_pattern

re_context = {key: re.escape(str(value)) for key, value in context.items()}
regex_pattern = self.search.format(**re_context)
try:
search_for_re = re.compile(regex_pattern, re.MULTILINE | re.DOTALL)
logger.debug("Searching for the regex: '%s'", search_for_re.pattern)
logger.dedent()
return search_for_re, raw_pattern
except re.error as e:
logger.error("Invalid regex '%s': %s.", default, e)

logger.debug("Invalid regex. Searching for the default pattern: '%s'", raw_pattern)
logger.dedent()

return default, raw_pattern


Expand Down Expand Up @@ -97,8 +106,6 @@ def add_files(self, filename: Union[str, List[str]]) -> None:
"""Add a filename to the list of files."""
filenames = [filename] if isinstance(filename, str) else filename
for name in filenames:
if name in self.resolved_filemap:
continue
self.files.append(
FileChange(
filename=name,
Expand All @@ -114,29 +121,32 @@ def add_files(self, filename: Union[str, List[str]]) -> None:
)

@property
def resolved_filemap(self) -> Dict[str, FileChange]:
def resolved_filemap(self) -> Dict[str, List[FileChange]]:
"""Return a map of filenames to file configs, expanding any globs."""
from bumpversion.config.utils import resolve_glob_files

output = defaultdict(list)
new_files = []
for file_cfg in self.files:
if file_cfg.glob:
new_files.extend(resolve_glob_files(file_cfg))
else:
new_files.append(file_cfg)

return {file_cfg.filename: file_cfg for file_cfg in new_files}
for file_cfg in new_files:
output[file_cfg.filename].append(file_cfg)
return output

@property
def files_to_modify(self) -> List[FileChange]:
"""Return a list of files to modify."""
files_not_excluded = [
file_cfg.filename
for file_cfg in self.resolved_filemap.values()
if file_cfg.filename not in self.excluded_paths
]
files_not_excluded = [filename for filename in self.resolved_filemap if filename not in self.excluded_paths]
inclusion_set = set(self.included_paths) | set(files_not_excluded)
return [file_cfg for file_cfg in self.resolved_filemap.values() if file_cfg.filename in inclusion_set]
return list(
chain.from_iterable(
file_cfg_list for key, file_cfg_list in self.resolved_filemap.items() if key in inclusion_set
)
)

@property
def version_config(self) -> "VersionConfig":
Expand Down
76 changes: 43 additions & 33 deletions bumpversion/files.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Methods for changing files."""
import logging
import re
from copy import deepcopy
from difflib import context_diff
Expand All @@ -8,9 +7,10 @@

from bumpversion.config.models import FileChange, VersionPartConfig
from bumpversion.exceptions import VersionNotFoundError
from bumpversion.ui import get_indented_logger
from bumpversion.version_part import Version, VersionConfig

logger = logging.getLogger(__name__)
logger = get_indented_logger(__name__)


def contains_pattern(search: re.Pattern, contents: str) -> bool:
Expand All @@ -22,7 +22,7 @@ def contains_pattern(search: re.Pattern, contents: str) -> bool:
line_no = contents.count("\n", 0, m.start(0)) + 1
logger.info(
"Found '%s' at line %s: %s",
search,
search.pattern,
line_no,
m.string[m.start() : m.end(0)],
)
Expand All @@ -42,8 +42,11 @@ def log_changes(file_path: str, file_content_before: str, file_content_after: st
"""
if file_content_before != file_content_after:
logger.info("%s file %s:", "Would change" if dry_run else "Changing", file_path)
logger.indent()
indent_str = logger.indent_str

logger.info(
"\n".join(
f"\n{indent_str}".join(
list(
context_diff(
file_content_before.splitlines(),
Expand All @@ -53,8 +56,9 @@ def log_changes(file_path: str, file_content_before: str, file_content_after: st
lineterm="",
)
)
)
),
)
logger.dedent()
else:
logger.info("%s file %s", "Would not change" if dry_run else "Not changing", file_path)

Expand Down Expand Up @@ -104,12 +108,16 @@ def write_file_contents(self, contents: str) -> None:
with open(self.file_change.filename, "wt", encoding="utf-8", newline=self._newlines) as f:
f.write(contents)

def contains_version(self, version: Version, context: MutableMapping) -> bool:
def _contains_change_pattern(
self, search_expression: re.Pattern, raw_search_expression: str, version: Version, context: MutableMapping
) -> bool:
"""
Check whether the version is present in the file.
Does the file contain the change pattern?
Args:
version: The version to check
search_expression: The compiled search expression
raw_search_expression: The raw search expression
version: The version to check, in case it's not the same as the original
context: The context to use
Raises:
Expand All @@ -118,17 +126,15 @@ def contains_version(self, version: Version, context: MutableMapping) -> bool:
Returns:
True if the version number is in fact present.
"""
search_expression, raw_search_expression = self.file_change.get_search_pattern(context)
file_contents = self.get_file_contents()
if contains_pattern(search_expression, file_contents):
return True

# the `search` pattern did not match, but the original supplied
# The `search` pattern did not match, but the original supplied
# version number (representing the same version part values) might
# match instead.
# match instead. This is probably the case if environment variables are used.

# check whether `search` isn't customized, i.e. should match only
# very specific parts of the file
# check whether `search` isn't customized
search_pattern_is_default = self.file_change.search == self.version_config.search

if search_pattern_is_default and contains_pattern(re.compile(re.escape(version.original)), file_contents):
Expand All @@ -141,19 +147,36 @@ def contains_version(self, version: Version, context: MutableMapping) -> bool:
return False
raise VersionNotFoundError(f"Did not find '{raw_search_expression}' in file: '{self.file_change.filename}'")

def replace_version(
def make_file_change(
self, current_version: Version, new_version: Version, context: MutableMapping, dry_run: bool = False
) -> None:
"""Replace the current version with the new version."""
file_content_before = self.get_file_contents()

"""Make the change to the file."""
logger.info(
"\n%sFile %s: replace `%s` with `%s`",
logger.indent_str,
self.file_change.filename,
self.file_change.search,
self.file_change.replace,
)
logger.indent()
logger.debug("Serializing the current version")
logger.indent()
context["current_version"] = self.version_config.serialize(current_version, context)
logger.dedent()
if new_version:
logger.debug("Serializing the new version")
logger.indent()
context["new_version"] = self.version_config.serialize(new_version, context)
logger.dedent()

search_for, raw_search_pattern = self.file_change.get_search_pattern(context)
replace_with = self.version_config.replace.format(**context)

if not self._contains_change_pattern(search_for, raw_search_pattern, current_version, context):
return

file_content_before = self.get_file_contents()

file_content_after = search_for.sub(replace_with, file_content_before)

if file_content_before == file_content_after and current_version.original:
Expand All @@ -163,7 +186,7 @@ def replace_version(
file_content_after = search_for_og.sub(replace_with, file_content_before)

log_changes(self.file_change.filename, file_content_before, file_content_after, dry_run)

logger.dedent()
if not dry_run: # pragma: no-coverage
self.write_file_contents(file_content_after)

Expand Down Expand Up @@ -209,22 +232,9 @@ def modify_files(
context: The context used for rendering the version
dry_run: True if this should be a report-only job
"""
_check_files_contain_version(files, current_version, context)
for f in files:
f.replace_version(current_version, new_version, context, dry_run)


def _check_files_contain_version(
files: List[ConfiguredFile], current_version: Version, context: MutableMapping
) -> None:
"""Make sure files exist and contain version string."""
logger.info(
"Asserting files %s contain the version string...",
", ".join({str(f.file_change.filename) for f in files}),
)
# _check_files_contain_version(files, current_version, context)
for f in files:
context["current_version"] = f.version_config.serialize(current_version, context)
f.contains_version(current_version, context)
f.make_file_change(current_version, new_version, context, dry_run)


class FileUpdater:
Expand Down
12 changes: 6 additions & 6 deletions tests/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_single_file_processed_twice(tmp_path: Path):
assert len(conf.files) == 2
for file_cfg in conf.files:
cfg_file = files.ConfiguredFile(file_cfg, version_config)
cfg_file.replace_version(current_version, new_version, ctx)
cfg_file.make_file_change(current_version, new_version, ctx)

assert filepath.read_text() == "dots: 0.10.3\ndashes: 0-10-3"

Expand Down Expand Up @@ -106,7 +106,7 @@ def test_multi_file_configuration(tmp_path: Path):

for file_cfg in conf.files:
cfg_file = files.ConfiguredFile(file_cfg, version_config)
cfg_file.replace_version(current_version, major_version, ctx)
cfg_file.make_file_change(current_version, major_version, ctx)

assert full_vers_path.read_text() == "2.0.0"
assert maj_vers_path.read_text() == "2"
Expand All @@ -123,7 +123,7 @@ def test_multi_file_configuration(tmp_path: Path):
major_patch_version = major_version.bump("patch", version_config.order)
for file_cfg in conf.files:
cfg_file = files.ConfiguredFile(file_cfg, version_config)
cfg_file.replace_version(major_version, major_patch_version, ctx)
cfg_file.make_file_change(major_version, major_patch_version, ctx)

assert full_vers_path.read_text() == "2.0.1"
assert maj_vers_path.read_text() == "2"
Expand Down Expand Up @@ -220,7 +220,7 @@ def test_search_replace_to_avoid_updating_unconcerned_lines(tmp_path: Path, capl

for file_cfg in conf.files:
cfg_file = files.ConfiguredFile(file_cfg, version_config)
cfg_file.replace_version(current_version, new_version, get_context(conf))
cfg_file.make_file_change(current_version, new_version, get_context(conf))

utc_today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
expected_chglog = dedent(
Expand Down Expand Up @@ -291,7 +291,7 @@ def test_simple_replacement_in_utf8_file(tmp_path: Path):
# Act
for file_cfg in conf.files:
cfg_file = files.ConfiguredFile(file_cfg, version_config)
cfg_file.replace_version(current_version, new_version, get_context(conf))
cfg_file.make_file_change(current_version, new_version, get_context(conf))

# Assert
out = version_path.read_text()
Expand All @@ -317,7 +317,7 @@ def test_multi_line_search_is_found(tmp_path: Path) -> None:
# Act
for file_cfg in conf.files:
cfg_file = files.ConfiguredFile(file_cfg, version_config)
cfg_file.replace_version(current_version, new_version, get_context(conf))
cfg_file.make_file_change(current_version, new_version, get_context(conf))

# Assert
assert alphabet_path.read_text() == "A\nB\nC\n10.0.0\n"
Expand Down

0 comments on commit e7a7629

Please sign in to comment.