diff --git a/podman/domain/config.py b/podman/domain/config.py index 173f3724..631f4f57 100644 --- a/podman/domain/config.py +++ b/podman/domain/config.py @@ -4,6 +4,7 @@ import urllib from pathlib import Path from typing import Dict, Optional +import json import xdg.BaseDirectory @@ -48,12 +49,16 @@ def id(self): # pylint: disable=invalid-name @cached_property def url(self): """urllib.parse.ParseResult: Returns URL for service connection.""" - return urllib.parse.urlparse(self.attrs.get("uri")) + if self.attrs.get("uri"): + return urllib.parse.urlparse(self.attrs.get("uri")) + return urllib.parse.urlparse(self.attrs.get("URI")) @cached_property def identity(self): """Path: Returns Path to identity file for service connection.""" - return Path(self.attrs.get("identity")) + if self.attrs.get("identity"): + return Path(self.attrs.get("identity")) + return Path(self.attrs.get("Identity")) class PodmanConfig: @@ -62,17 +67,45 @@ class PodmanConfig: def __init__(self, path: Optional[str] = None): """Read Podman configuration from users XDG_CONFIG_HOME.""" + self.is_default = False if path is None: home = Path(xdg.BaseDirectory.xdg_config_home) - self.path = home / "containers" / "containers.conf" + self.path = home / "containers" / "podman-connections.json" + old_toml_file = home / "containers" / "containers.conf" + self.is_default = True + # this elif is only for testing purposes + elif "@@is_test@@" in path: + test_path = path.replace("@@is_test@@", '') + self.path = Path(test_path) / "podman-connections.json" + old_toml_file = Path(test_path) / "containers.conf" + self.is_default = True else: self.path = Path(path) self.attrs = {} if self.path.exists(): - with self.path.open(encoding='utf-8') as file: + try: + with open(self.path, encoding='utf-8') as file: + self.attrs = json.load(file) + except: # pylint: disable=bare-except + # if the user specifies a path, it can either be a JSON file + # or a TOML file - so try TOML next + try: + with self.path.open(encoding='utf-8') as file: + buffer = file.read() + loaded_toml = toml_loads(buffer) + self.attrs.update(loaded_toml) + except Exception as e: + raise AttributeError( + "The path given is neither a JSON nor a TOML connections file" + ) from e + + # Read the old toml file configuration + if self.is_default and old_toml_file.exists(): + with old_toml_file.open(encoding='utf-8') as file: buffer = file.read() - self.attrs = toml_loads(buffer) + loaded_toml = toml_loads(buffer) + self.attrs.update(loaded_toml) def __hash__(self) -> int: return hash(tuple(self.path.name)) @@ -98,6 +131,7 @@ def services(self): """ services: Dict[str, ServiceConnection] = {} + # read the keys of the toml file first engine = self.attrs.get("engine") if engine: destinations = engine.get("service_destinations") @@ -105,17 +139,35 @@ def services(self): connection = ServiceConnection(key, attrs=destinations[key]) services[key] = connection + # read the keys of the json file next + # this will ensure that if the new json file and the old toml file + # has a connection with the same name defined, we always pick the + # json one + connection = self.attrs.get("Connection") + if connection: + destinations = connection.get("Connections") + for key in destinations: + connection = ServiceConnection(key, attrs=destinations[key]) + services[key] = connection + return services @cached_property def active_service(self): """Optional[ServiceConnection]: Returns active connection.""" + # read the new json file format + connection = self.attrs.get("Connection") + if connection: + active = connection.get("Default") + destinations = connection.get("Connections") + return ServiceConnection(active, attrs=destinations[active]) + + # if we are here, that means there was no default in the new json file engine = self.attrs.get("engine") if engine: active = engine.get("active_service") destinations = engine.get("service_destinations") - for key in destinations: - if key == active: - return ServiceConnection(key, attrs=destinations[key]) + return ServiceConnection(active, attrs=destinations[active]) + return None diff --git a/podman/tests/unit/test_config.py b/podman/tests/unit/test_config.py index 7ecb475a..e2d61d7e 100644 --- a/podman/tests/unit/test_config.py +++ b/podman/tests/unit/test_config.py @@ -1,13 +1,85 @@ import unittest import urllib.parse +import json +import os +import tempfile from pathlib import Path from unittest import mock from unittest.mock import MagicMock - from podman.domain.config import PodmanConfig -class PodmanConfigTestCase(unittest.TestCase): +class PodmanConfigTestCaseDefault(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + + # Data to be written to the JSON file + self.data_json = """ +{ + "Connection": { + "Default": "testing_json", + "Connections": { + "testing_json": { + "URI": "ssh://qe@localhost:2222/run/podman/podman.sock", + "Identity": "/home/qe/.ssh/id_rsa" + }, + "production": { + "URI": "ssh://root@localhost:22/run/podman/podman.sock", + "Identity": "/home/root/.ssh/id_rsajson" + } + } + }, + "Farm": {} +} +""" + + # Data to be written to the TOML file + self.data_toml = """ +[containers] + log_size_max = -1 + pids_limit = 2048 + userns_size = 65536 + +[engine] + num_locks = 2048 + active_service = "testing" + stop_timeout = 10 + [engine.service_destinations] + [engine.service_destinations.production] + uri = "ssh://root@localhost:22/run/podman/podman.sock" + identity = "/home/root/.ssh/id_rsa" + [engine.service_destinations.testing] + uri = "ssh://qe@localhost:2222/run/podman/podman.sock" + identity = "/home/qe/.ssh/id_rsa" + +[network] +""" + + # Define the file path + self.path_json = os.path.join(self.temp_dir, 'podman-connections.json') + self.path_toml = os.path.join(self.temp_dir, 'containers.conf') + + # Write data to the JSON file + j_data = json.loads(self.data_json) + with open(self.path_json, 'w+') as file_json: + json.dump(j_data, file_json) + + # Write data to the TOML file + with open(self.path_toml, 'w+') as file_toml: + # toml.dump(self.data_toml, file_toml) + file_toml.write(self.data_toml) + + def test_connections(self): + config = PodmanConfig("@@is_test@@" + self.temp_dir) + + self.assertEqual(config.active_service.id, "testing_json") + + expected = urllib.parse.urlparse("ssh://qe@localhost:2222/run/podman/podman.sock") + self.assertEqual(config.active_service.url, expected) + self.assertEqual(config.services["production"].identity, Path("/home/root/.ssh/id_rsajson")) + + +class PodmanConfigTestCaseTOML(unittest.TestCase): opener = mock.mock_open( read_data=""" [containers] @@ -35,7 +107,7 @@ def setUp(self) -> None: super().setUp() def mocked_open(self, *args, **kwargs): - return PodmanConfigTestCase.opener(self, *args, **kwargs) + return PodmanConfigTestCaseTOML.opener(self, *args, **kwargs) self.mocked_open = mocked_open @@ -49,10 +121,50 @@ def test_connections(self): self.assertEqual(config.active_service.url, expected) self.assertEqual(config.services["production"].identity, Path("/home/root/.ssh/id_rsa")) - PodmanConfigTestCase.opener.assert_called_with( + PodmanConfigTestCaseTOML.opener.assert_called_with( Path("/home/developer/containers.conf"), encoding='utf-8' ) +class PodmanConfigTestCaseJSON(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + + self.temp_dir = tempfile.mkdtemp() + self.data = """ +{ + "Connection": { + "Default": "testing", + "Connections": { + "testing": { + "URI": "ssh://qe@localhost:2222/run/podman/podman.sock", + "Identity": "/home/qe/.ssh/id_rsa" + }, + "production": { + "URI": "ssh://root@localhost:22/run/podman/podman.sock", + "Identity": "/home/root/.ssh/id_rsa" + } + } + }, + "Farm": {} +} +""" + + self.path = os.path.join(self.temp_dir, 'podman-connections.json') + # Write data to the JSON file + data = json.loads(self.data) + with open(self.path, 'w+') as file: + json.dump(data, file) + + def test_connections(self): + config = PodmanConfig(self.path) + + self.assertEqual(config.active_service.id, "testing") + + expected = urllib.parse.urlparse("ssh://qe@localhost:2222/run/podman/podman.sock") + self.assertEqual(config.active_service.url, expected) + self.assertEqual(config.services["production"].identity, Path("/home/root/.ssh/id_rsa")) + + if __name__ == '__main__': unittest.main()