Skip to content

Commit

Permalink
settings: safer writing (#3644)
Browse files Browse the repository at this point in the history
* settings: clean up imports

* settings: try to use atomic rename

* settings: flush, sync and validate new yaml

before replacing the old one

* settings: add test for Settings.save
  • Loading branch information
black-sliver authored Jul 25, 2024
1 parent deae524 commit 8949e21
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 7 deletions.
18 changes: 13 additions & 5 deletions settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This is different from player options.
"""

import os
import os.path
import shutil
import sys
Expand All @@ -11,7 +12,6 @@
from enum import IntEnum
from threading import Lock
from typing import cast, Any, BinaryIO, ClassVar, Dict, Iterator, List, Optional, TextIO, Tuple, Union, TypeVar
import os

__all__ = [
"get_settings", "fmt_doc", "no_gui",
Expand Down Expand Up @@ -798,6 +798,7 @@ def autosave() -> None:
atexit.register(autosave)

def save(self, location: Optional[str] = None) -> None: # as above
from Utils import parse_yaml
location = location or self._filename
assert location, "No file specified"
temp_location = location + ".tmp" # not using tempfile to test expected file access
Expand All @@ -807,10 +808,18 @@ def save(self, location: Optional[str] = None) -> None: # as above
# can't use utf-8-sig because it breaks backward compat: pyyaml on Windows with bytes does not strip the BOM
with open(temp_location, "w", encoding="utf-8") as f:
self.dump(f)
# replace old with new
if os.path.exists(location):
f.flush()
if hasattr(os, "fsync"):
os.fsync(f.fileno())
# validate new file is valid yaml
with open(temp_location, encoding="utf-8") as f:
parse_yaml(f.read())
# replace old with new, try atomic operation first
try:
os.rename(temp_location, location)
except (OSError, FileExistsError):
os.unlink(location)
os.rename(temp_location, location)
os.rename(temp_location, location)
self._filename = location

def dump(self, f: TextIO, level: int = 0) -> None:
Expand All @@ -832,7 +841,6 @@ def get_settings() -> Settings:
with _lock: # make sure we only have one instance
res = getattr(get_settings, "_cache", None)
if not res:
import os
from Utils import user_path, local_path
filenames = ("options.yaml", "host.yaml")
locations: List[str] = []
Expand Down
29 changes: 27 additions & 2 deletions test/general/test_host_yaml.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
import os.path
import unittest
from io import StringIO
from tempfile import TemporaryFile
from tempfile import TemporaryDirectory, TemporaryFile
from typing import Any, Dict, List, cast

import Utils
from settings import Settings, Group
from settings import Group, Settings, ServerOptions


class TestIDs(unittest.TestCase):
Expand Down Expand Up @@ -80,3 +81,27 @@ class AGroup(Group):
self.assertEqual(value_spaces[2], value_spaces[0]) # start of sub-list
self.assertGreater(value_spaces[3], value_spaces[0],
f"{value_lines[3]} should have more indentation than {value_lines[0]} in {lines}")


class TestSettingsSave(unittest.TestCase):
def test_save(self) -> None:
"""Test that saving and updating works"""
with TemporaryDirectory() as d:
filename = os.path.join(d, "host.yaml")
new_release_mode = ServerOptions.ReleaseMode("enabled")
# create default host.yaml
settings = Settings(None)
settings.save(filename)
self.assertTrue(os.path.exists(filename),
"Default settings could not be saved")
self.assertNotEqual(settings.server_options.release_mode, new_release_mode,
"Unexpected default release mode")
# update host.yaml
settings.server_options.release_mode = new_release_mode
settings.save(filename)
self.assertFalse(os.path.exists(filename + ".tmp"),
"Temp file was not removed during save")
# read back host.yaml
settings = Settings(filename)
self.assertEqual(settings.server_options.release_mode, new_release_mode,
"Settings were not overwritten")

0 comments on commit 8949e21

Please sign in to comment.