Skip to content

Commit 5702be3

Browse files
Update config loader
1 parent 77f0462 commit 5702be3

3 files changed

Lines changed: 256 additions & 67 deletions

File tree

common_utility/configLoader.py

Lines changed: 82 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,57 +2,108 @@
22
# SPDX-FileCopyrightText: 2024 Attila Gombos <attila.gombos@effective-range.com>
33
# SPDX-License-Identifier: MIT
44

5-
import os
5+
import sys
6+
from argparse import ArgumentParser, Action, Namespace
67
from configparser import ConfigParser
78
from pathlib import Path
8-
from typing import Any
9+
from typing import Any, cast
910

1011
from context_logger import get_logger
1112

12-
from common_utility import copy_file
13-
14-
log = get_logger('ConfigLoader')
15-
1613

1714
class IConfigLoader(object):
1815

19-
def load(self, arguments: dict[str, Any]) -> dict[str, Any]:
16+
def load(self, argument_parser: ArgumentParser) -> Namespace:
2017
raise NotImplementedError()
2118

2219

2320
class ConfigLoader(IConfigLoader):
2421

25-
def __init__(self, default_config_file: Path, config_file_argument: str = 'config_file') -> None:
22+
def __init__(self, default_config_file: Path) -> None:
23+
self._config_parser = ConfigParser(interpolation=None)
2624
self._default_config_file = default_config_file
27-
self._config_file_argument = config_file_argument
25+
self.log = get_logger(type(self).__name__)
26+
27+
def load(self, argument_parser: ArgumentParser) -> Namespace:
28+
arguments = argument_parser.parse_known_args()[0]
29+
30+
self.log.info('Loading default configuration', config_file=str(self._default_config_file))
31+
loaded = self._config_parser.read(self._default_config_file)
32+
33+
if str(self._default_config_file) not in loaded:
34+
self.log.warn('Default configuration could not be loaded', config_file=str(self._default_config_file))
35+
36+
if custom_config_file := arguments.config:
37+
custom_config_file = Path(custom_config_file)
38+
39+
self.log.info('Loading custom configuration', config_file=str(custom_config_file))
40+
loaded = self._config_parser.read(custom_config_file)
41+
42+
if str(custom_config_file) not in loaded:
43+
self.log.warn('Custom configuration could not be loaded', config_file=str(custom_config_file))
44+
45+
configuration = dict(vars(arguments))
46+
47+
for section in self._config_parser.sections():
48+
configuration.update(dict(self._config_parser[section]))
49+
50+
self.log.info('Loading command line arguments', arguments=vars(arguments))
51+
cli_overrides = self._get_cli_overrides(argument_parser, arguments)
52+
configuration.update(cli_overrides)
53+
54+
self._sanitize_config(argument_parser, configuration)
55+
56+
self.log.info('Configuration loaded', configuration=configuration)
57+
58+
return Namespace(**configuration)
59+
60+
def _get_cli_overrides(self, parser: ArgumentParser, arguments: Namespace) -> dict[str, Any]:
61+
cli_overrides: dict[str, Any] = {}
62+
argv_tokens = set(sys.argv[1:])
63+
64+
if not argv_tokens:
65+
return cli_overrides
66+
67+
for action in parser._actions:
68+
if action.dest == 'help':
69+
continue
2870

29-
def load(self, arguments: dict[str, Any]) -> dict[str, Any]:
30-
parser = ConfigParser(interpolation=None)
71+
if action.option_strings: # has --flag or -f
72+
if any(opt in argv_tokens for opt in action.option_strings):
73+
cli_overrides[action.dest] = getattr(arguments, action.dest)
3174

32-
log.info('Loading default configuration', config_file=str(self._default_config_file))
33-
parser.read(self._default_config_file)
75+
return cli_overrides
3476

35-
if config_file := arguments.get(self._config_file_argument):
36-
custom_config_file = Path(config_file)
77+
def _sanitize_config(self, parser: ArgumentParser, config: dict[str, Any]) -> None:
78+
for config_key in config:
79+
action = self._find_action(parser, config_key)
80+
if action is None or action.default is None:
81+
continue
3782

38-
if os.path.exists(custom_config_file):
39-
log.info('Loading custom configuration', config_file=str(custom_config_file))
40-
parser.read(custom_config_file)
41-
else:
42-
try:
43-
log.info('Creating custom configuration using default', config_file=str(custom_config_file))
44-
copy_file(self._default_config_file, custom_config_file)
45-
except Exception as exception:
46-
log.warn('Failed to create custom configuration file', error=str(exception))
83+
if isinstance(action.default, bool):
84+
self._convert_bool(config, config_key)
85+
elif isinstance(action.default, int):
86+
self._convert_int(config, config_key)
87+
elif isinstance(action.default, float):
88+
self._convert_float(config, config_key)
4789

48-
configuration = {}
90+
for config_key in config:
91+
self.log.debug('Config', key=config_key, value=config[config_key], type=type(config[config_key]))
4992

50-
for section in parser.sections():
51-
configuration.update(dict(parser[section]))
93+
def _find_action(self, parser: ArgumentParser, config_key: str) -> Action:
94+
return next((a for a in parser._actions if a.dest == config_key), cast(Action, cast(object, None)))
5295

53-
log.info('Loading command line arguments', arguments=arguments)
54-
configuration.update(arguments)
96+
def _convert_bool(self, config: dict[str, Any], config_key: str) -> None:
97+
config[config_key] = str(config[config_key]).lower() in ('true', '1', 'yes')
5598

56-
log.info('Configuration loaded', configuration=configuration)
99+
def _convert_int(self, config: dict[str, Any], config_key: str) -> None:
100+
try:
101+
config[config_key] = int(config[config_key])
102+
except (TypeError, ValueError):
103+
pass
57104

58-
return configuration
105+
def _convert_float(self, config: dict[str, Any], config_key: str) -> None:
106+
try:
107+
config[config_key] = float(config[config_key])
108+
except (TypeError, ValueError):
109+
pass

tests/configLoaderTest.py

Lines changed: 174 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,215 @@
1-
import os.path
1+
import sys
22
import unittest
3+
from argparse import ArgumentParser
34
from pathlib import Path
45
from unittest import TestCase
6+
from unittest.mock import patch
57

68
from context_logger import setup_logging
79

810
from common_utility import delete_file, ConfigLoader, copy_file
911
from tests import TEST_RESOURCE_ROOT, TEST_FILE_SYSTEM_ROOT
1012

13+
DEFAULT_CONFIG_FILE = f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf.default'
14+
1115

1216
class ConfigLoaderTest(TestCase):
1317

1418
@classmethod
1519
def setUpClass(cls):
1620
setup_logging('python-common-utility', 'DEBUG', warn_on_overwrite=False)
21+
copy_file(f'{TEST_RESOURCE_ROOT}/config/example.conf.default', DEFAULT_CONFIG_FILE)
1722

1823
def setUp(self):
1924
print()
2025
delete_file(f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf')
26+
delete_file(f'{TEST_FILE_SYSTEM_ROOT}/etc/example.types.conf')
27+
delete_file(f'{TEST_FILE_SYSTEM_ROOT}/etc/example.types.invalid.conf')
28+
29+
def _create_argument_parser(self) -> ArgumentParser:
30+
argument_parser = ArgumentParser()
31+
argument_parser.add_argument('--config', default=None)
32+
argument_parser.add_argument('--config-key1', default=None)
33+
argument_parser.add_argument('--config-key2', default=None)
34+
argument_parser.add_argument('--example-key1', default=None)
35+
argument_parser.add_argument('--example-key2', default=None)
36+
return argument_parser
37+
38+
def test_load_config_when_default_config_file_could_not_be_loaded(self):
39+
# Given
40+
config_loader = ConfigLoader(Path('invalid/path/example.conf.default'))
41+
argument_parser = self._create_argument_parser()
42+
43+
# When
44+
with patch.object(sys, 'argv', ['test', '--config-key1', 'new_value1']):
45+
result = config_loader.load(argument_parser)
46+
47+
# Then
48+
self.assertEqual('new_value1', result.config_key1)
49+
50+
def test_load_config_when_no_custom_config_file_specified(self):
51+
# Given
52+
config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE))
53+
argument_parser = self._create_argument_parser()
54+
55+
# When
56+
with patch.object(sys, 'argv', ['test', '--config-key1', 'new_value1']):
57+
result = config_loader.load(argument_parser)
58+
59+
# Then
60+
self.assertEqual('new_value1', result.config_key1)
61+
self.assertEqual('value2', result.config_key2)
62+
self.assertEqual('example1', result.example_key1)
63+
self.assertEqual('example2', result.example_key2)
64+
65+
def test_load_config_when_custom_config_file_specified(self):
66+
# Given
67+
config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE))
68+
argument_parser = self._create_argument_parser()
69+
config_file = f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf'
70+
71+
copy_file(f'{TEST_RESOURCE_ROOT}/config/example.conf', config_file)
72+
73+
# When
74+
with patch.object(sys, 'argv', ['test', '--config', config_file, '--example-key1', 'new_example1']):
75+
result = config_loader.load(argument_parser)
2176

22-
def test_load_config_when_custom_configuration_not_exists(self):
77+
# Then
78+
self.assertEqual('value1', result.config_key1)
79+
self.assertEqual('value3', result.config_key2)
80+
self.assertEqual('new_example1', result.example_key1)
81+
self.assertEqual('example4', result.example_key2)
82+
83+
def test_load_config_when_custom_config_file_could_not_be_loaded(self):
2384
# Given
24-
config_loader = ConfigLoader(Path(TEST_RESOURCE_ROOT) / 'config' / 'example.default.conf')
25-
arguments = {
26-
'config_file': f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf',
27-
'config_key1': 'new_value1',
28-
}
85+
config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE))
86+
argument_parser = self._create_argument_parser()
87+
# When
88+
with patch.object(sys, 'argv',
89+
['test', '--config', 'invalid/path/example.conf', '--example-key1', 'new_example1']):
90+
result = config_loader.load(argument_parser)
91+
92+
# Then
93+
self.assertEqual('value2', result.config_key2)
94+
self.assertEqual('new_example1', result.example_key1)
95+
self.assertEqual('example2', result.example_key2)
96+
97+
def test_load_config_when_parser_default_values_defined_but_not_passed(self):
98+
# Given
99+
config_loader = ConfigLoader(Path(TEST_RESOURCE_ROOT) / 'config' / 'example.conf.default')
100+
argument_parser = ArgumentParser()
101+
argument_parser.add_argument('--config', default=None)
102+
argument_parser.add_argument('--config-key1', default='cli_default_value1')
103+
argument_parser.add_argument('--example-key1', default='cli_default_example1')
104+
config_file = f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf'
105+
106+
copy_file(f'{TEST_RESOURCE_ROOT}/config/example.conf', config_file)
29107

30108
# When
31-
result = config_loader.load(arguments)
109+
with patch.object(sys, 'argv', ['test', '--config', config_file]):
110+
result = config_loader.load(argument_parser)
32111

33112
# Then
34-
self.assertTrue(os.path.exists(f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf'))
35-
self.assertEqual('new_value1', result['config_key1'])
36-
self.assertEqual('value2', result['config_key2'])
37-
self.assertEqual('example1', result['example_key1'])
38-
self.assertEqual('example2', result['example_key2'])
113+
self.assertEqual('value1', result.config_key1)
114+
self.assertEqual('example3', result.example_key1)
39115

40-
def test_load_config_when_custom_configuration_exists(self):
116+
def test_load_config_when_short_option_cli_override_is_passed(self):
41117
# Given
42-
config_loader = ConfigLoader(Path(TEST_RESOURCE_ROOT) / 'config' / 'example.default.conf')
43-
arguments = {
44-
'config_file': f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf',
45-
'example_key1': 'new_example1',
46-
}
118+
config_loader = ConfigLoader(Path(TEST_RESOURCE_ROOT) / 'config' / 'example.conf.default')
119+
argument_parser = ArgumentParser()
120+
argument_parser.add_argument('--config', default=None)
121+
argument_parser.add_argument('--example-key2', '-e2', default=None)
122+
config_file = f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf'
47123

48-
copy_file(f'{TEST_RESOURCE_ROOT}/config/example.conf', f'{TEST_FILE_SYSTEM_ROOT}/etc/example.conf')
124+
copy_file(f'{TEST_RESOURCE_ROOT}/config/example.conf', config_file)
49125

50126
# When
51-
result = config_loader.load(arguments)
127+
with patch.object(sys, 'argv', ['test', '--config', config_file, '-e2', 'new_example2']):
128+
result = config_loader.load(argument_parser)
52129

53130
# Then
54-
self.assertEqual('value1', result['config_key1'])
55-
self.assertEqual('value3', result['config_key2'])
56-
self.assertEqual('new_example1', result['example_key1'])
57-
self.assertEqual('example4', result['example_key2'])
131+
self.assertEqual('new_example2', result.example_key2)
58132

59-
def test_load_config_when_fail_to_create_custom_configuration(self):
133+
def test_load_config_when_long_option_cli_override_passed_using_equal_sign(self):
60134
# Given
61-
config_loader = ConfigLoader(Path(TEST_RESOURCE_ROOT) / 'config' / 'example.default.conf')
62-
arguments = {
63-
'config_file': '/invalid/path/to/example.conf',
64-
'config_key1': 'new_value1',
65-
}
135+
config_loader = ConfigLoader(Path(TEST_RESOURCE_ROOT) / 'config' / 'example.conf.default')
136+
argument_parser = self._create_argument_parser()
137+
138+
# When
139+
with patch.object(sys, 'argv', ['test', '--config-key1=new_value1']):
140+
result = config_loader.load(argument_parser)
141+
142+
# Then
143+
self.assertEqual('value1', result.config_key1)
144+
145+
def test_get_cli_overrides_when_no_cli_arguments(self):
146+
# Given
147+
config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE))
148+
argument_parser = self._create_argument_parser()
149+
arguments = argument_parser.parse_args([])
150+
151+
# When
152+
with patch.object(sys, 'argv', ['test']):
153+
result = config_loader._get_cli_overrides(argument_parser, arguments)
154+
155+
# Then
156+
self.assertEqual({}, result)
157+
158+
def test_get_cli_overrides_when_argument_has_no_option_string(self):
159+
# Given
160+
config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE))
161+
argument_parser = ArgumentParser()
162+
argument_parser.add_argument('--config', default=None)
163+
argument_parser.add_argument('input_file')
164+
arguments = argument_parser.parse_args(['input.txt'])
165+
166+
# When
167+
with patch.object(sys, 'argv', ['test', 'input.txt']):
168+
result = config_loader._get_cli_overrides(argument_parser, arguments)
169+
170+
# Then
171+
self.assertEqual({}, result)
172+
173+
def test_load_config_when_type_values_present_then_sanitize(self):
174+
# Given
175+
config_file = f'{TEST_FILE_SYSTEM_ROOT}/etc/example.types.conf'
176+
Path(config_file).write_text('[types]\nfeature_enabled = true\nretry_count = 7\ntimeout = 1.5\n',
177+
encoding='utf-8')
178+
179+
config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE))
180+
argument_parser = ArgumentParser()
181+
argument_parser.add_argument('--config', default=None)
182+
argument_parser.add_argument('--feature-enabled', default=False)
183+
argument_parser.add_argument('--retry-count', default=0)
184+
argument_parser.add_argument('--timeout', default=0.0)
185+
186+
# When
187+
with patch.object(sys, 'argv', ['test', '--config', config_file]):
188+
result = config_loader.load(argument_parser)
189+
190+
# Then
191+
self.assertTrue(result.feature_enabled)
192+
self.assertEqual(7, result.retry_count)
193+
self.assertEqual(1.5, result.timeout)
194+
195+
def test_load_config_when_invalid_numeric_values_present_then_keep_original(self):
196+
# Given
197+
config_file = f'{TEST_FILE_SYSTEM_ROOT}/etc/example.types.invalid.conf'
198+
Path(config_file).write_text('[types]\nretry_count = invalid\ntimeout = invalid\n', encoding='utf-8')
199+
200+
config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE))
201+
argument_parser = ArgumentParser()
202+
argument_parser.add_argument('--config', default=None)
203+
argument_parser.add_argument('--retry-count', default=0)
204+
argument_parser.add_argument('--timeout', default=0.0)
66205

67206
# When
68-
result = config_loader.load(arguments)
207+
with patch.object(sys, 'argv', ['test', '--config', config_file]):
208+
result = config_loader.load(argument_parser)
69209

70210
# Then
71-
self.assertEqual('new_value1', result['config_key1'])
72-
self.assertEqual('value2', result['config_key2'])
73-
self.assertEqual('example1', result['example_key1'])
74-
self.assertEqual('example2', result['example_key2'])
211+
self.assertEqual('invalid', result.retry_count)
212+
self.assertEqual('invalid', result.timeout)
75213

76214

77215
if __name__ == '__main__':

0 commit comments

Comments
 (0)