diff --git a/src/azure-cli-core/azure/cli/core/__init__.py b/src/azure-cli-core/azure/cli/core/__init__.py index 5b348c393f0..aec4d2b0dd6 100644 --- a/src/azure-cli-core/azure/cli/core/__init__.py +++ b/src/azure-cli-core/azure/cli/core/__init__.py @@ -36,8 +36,9 @@ ALWAYS_LOADED_MODULES = [] # Extensions that will always be loaded if installed. They don't expose commands but hook into CLI core. ALWAYS_LOADED_EXTENSIONS = ['azext_ai_examples', 'azext_next'] -# Timeout (in seconds) for loading a single module. Acts as a safety valve to prevent indefinite hangs -MODULE_LOAD_TIMEOUT_SECONDS = 60 +# Timeout (in seconds) for loading command modules. +# Set via core.module_load_timeout. Use 0 or a negative value to disable timeout. +DEFAULT_MODULE_LOAD_TIMEOUT_SECONDS = 120 # Maximum number of worker threads for parallel module loading. MAX_WORKER_THREAD_COUNT = 4 @@ -203,8 +204,16 @@ def _configure_style(self): format_styled_text.theme = theme +class ModuleLoadTimeoutError(TimeoutError): + """Raised when a command module does not finish loading within timeout.""" + + def __init__(self, module_name, timeout_seconds): + super().__init__("Module '{}' load timeout after {} seconds".format(module_name, timeout_seconds)) + + class ModuleLoadResult: # pylint: disable=too-few-public-methods - def __init__(self, module_name, command_table, group_table, elapsed_time, error=None, traceback_str=None, command_loader=None): + def __init__(self, module_name, command_table, group_table, elapsed_time, + error=None, traceback_str=None, command_loader=None): self.module_name = module_name self.command_table = command_table self.group_table = group_table @@ -264,6 +273,8 @@ def load_command_table(self, args): from azure.cli.core.breaking_change import ( import_core_breaking_changes, import_extension_breaking_changes) + module_load_timed_out = False + def _update_command_table_from_modules(args, command_modules=None): """Loads command tables from modules and merge into the main command table. @@ -290,9 +301,12 @@ def _update_command_table_from_modules(args, command_modules=None): except ImportError as e: logger.warning(e) + nonlocal module_load_timed_out + start_time = timeit.default_timer() logger.debug("Loading command modules...") - results = self._load_modules(args, command_modules) + results, timed_out = self._load_modules(args, command_modules) + module_load_timed_out = module_load_timed_out or timed_out count, cumulative_group_count, cumulative_command_count = \ self._process_results_with_timing(results) @@ -524,9 +538,11 @@ def _get_extension_suppressions(mod_loaders): _update_command_table_from_extensions(ext_suppressions) logger.debug("Loaded %d groups, %d commands.", len(self.command_group_table), len(self.command_table)) - if use_command_index: + if use_command_index and not module_load_timed_out: command_index.update(self.command_table) self._cache_help_index(command_index) + elif use_command_index and module_load_timed_out: + logger.warning("Skip command index update because module loading timed out.") return self.command_table @@ -622,16 +638,21 @@ def load_arguments(self, command=None): loader._update_command_definitions() # pylint: disable=protected-access def _load_modules(self, args, command_modules): - """Load command modules using ThreadPoolExecutor with timeout protection.""" + """Load command modules using ThreadPoolExecutor.""" from azure.cli.core.commands import BLOCKED_MODS + from azure.cli.core import telemetry + timeout_seconds = self._get_module_load_timeout_seconds() results = [] + timed_out = False + processed_futures = set() with ThreadPoolExecutor(max_workers=MAX_WORKER_THREAD_COUNT) as executor: future_to_module = {executor.submit(self._load_single_module, mod, args): mod for mod in command_modules if mod not in BLOCKED_MODS} try: - for future in concurrent.futures.as_completed(future_to_module, timeout=MODULE_LOAD_TIMEOUT_SECONDS): + for future in concurrent.futures.as_completed(future_to_module, timeout=timeout_seconds): + processed_futures.add(future) try: result = future.result() results.append(result) @@ -644,7 +665,11 @@ def _load_modules(self, args, command_modules): logger.warning("Module '%s' load failed with unexpected exception: %s", mod, ex) results.append(ModuleLoadResult(mod, {}, {}, 0, ex)) except concurrent.futures.TimeoutError: + timed_out = True + telemetry.set_thread_timeout(timeout_seconds) for future, mod in future_to_module.items(): + if future in processed_futures: + continue if future.done(): try: result = future.result() @@ -653,11 +678,33 @@ def _load_modules(self, args, command_modules): logger.warning("Module '%s' load failed: %s", mod, ex) results.append(ModuleLoadResult(mod, {}, {}, 0, ex)) else: - logger.warning("Module '%s' load timeout after %s seconds", mod, MODULE_LOAD_TIMEOUT_SECONDS) - results.append(ModuleLoadResult(mod, {}, {}, 0, - Exception(f"Module '{mod}' load timeout"))) + logger.warning("Module '%s' load timeout after %s seconds", mod, timeout_seconds) + results.append(ModuleLoadResult( + mod, + {}, + {}, + 0, + ModuleLoadTimeoutError(mod, timeout_seconds) + )) + + return results, timed_out + + def _get_module_load_timeout_seconds(self): + raw_timeout = self.cli_ctx.config.get('core', 'module_load_timeout', str(DEFAULT_MODULE_LOAD_TIMEOUT_SECONDS)) + try: + timeout_seconds = int(raw_timeout) + except (TypeError, ValueError): + logger.warning("Invalid core.module_load_timeout value '%s'. Using default %s seconds.", + raw_timeout, DEFAULT_MODULE_LOAD_TIMEOUT_SECONDS) + timeout_seconds = DEFAULT_MODULE_LOAD_TIMEOUT_SECONDS + + if timeout_seconds <= 0: + logger.debug("Module load timeout disabled by core.module_load_timeout=%s", timeout_seconds) + return None - return results + logger.debug("Module load timeout set to %s seconds (core.module_load_timeout=%s).", + timeout_seconds, raw_timeout) + return timeout_seconds def _load_single_module(self, mod, args): from azure.cli.core.breaking_change import import_module_breaking_changes @@ -677,6 +724,10 @@ def _handle_module_load_error(self, result): """Handle errors that occurred during module loading.""" from azure.cli.core import telemetry + if isinstance(result.error, ModuleLoadTimeoutError): + logger.warning("Skip per-module timeout fault telemetry for '%s'.", result.module_name) + return + logger.error("Error loading command module '%s': %s", result.module_name, result.error) telemetry.set_exception(exception=result.error, fault_type='module-load-error-' + result.module_name, diff --git a/src/azure-cli-core/azure/cli/core/telemetry.py b/src/azure-cli-core/azure/cli/core/telemetry.py index 714bd751263..c1c057ef139 100644 --- a/src/azure-cli-core/azure/cli/core/telemetry.py +++ b/src/azure-cli-core/azure/cli/core/telemetry.py @@ -80,6 +80,7 @@ def __init__(self, correlation_id=None, application=None): self.enable_broker_on_windows = None self.msal_telemetry = None self.login_experience_v2 = None + self.thread_timeout = None def add_event(self, name, properties): for key in self.instrumentation_key: @@ -239,6 +240,11 @@ def _get_azure_cli_properties(self): set_custom_properties(result, 'EnableBrokerOnWindows', str(self.enable_broker_on_windows)) set_custom_properties(result, 'MsalTelemetry', self.msal_telemetry) set_custom_properties(result, 'LoginExperienceV2', str(self.login_experience_v2)) + set_custom_properties( + result, + 'ThreadTimeout', + self.thread_timeout + ) return result @@ -446,6 +452,11 @@ def set_command_index_rebuild_triggered(is_cmd_idx_rebuild_triggered=False): _session.is_cmd_idx_rebuild_triggered = is_cmd_idx_rebuild_triggered +@decorators.suppress_all_exceptions() +def set_thread_timeout(timeout_seconds): + _session.thread_timeout = 'module loading failed with timeout seconds: {}'.format(timeout_seconds) + + @decorators.suppress_all_exceptions() def set_command_details(command, output_type=None, parameters=None, extension_name=None, extension_version=None, command_preserve_casing=None): diff --git a/src/azure-cli-core/azure/cli/core/tests/test_command_registration.py b/src/azure-cli-core/azure/cli/core/tests/test_command_registration.py index dbab0f61eb4..963ea281acf 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_command_registration.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_command_registration.py @@ -9,7 +9,7 @@ import unittest from collections import namedtuple -from azure.cli.core import AzCommandsLoader, MainCommandsLoader +from azure.cli.core import AzCommandsLoader, MainCommandsLoader, ModuleLoadResult, ModuleLoadTimeoutError from azure.cli.core.commands import ExtensionCommandSource from azure.cli.core.extension import EXTENSIONS_MOD_PREFIX from azure.cli.core.mock import DummyCli @@ -274,22 +274,31 @@ def test_cmd_to_loader_map_populated_after_parallel_loading(self): # Load all commands (triggers parallel module loading) cmd_tbl = loader.load_command_table(None) - + # Verify EVERY command in command_table has an entry in cmd_to_loader_map # This is exactly what azdev does before it hits KeyError for cmd_name in cmd_tbl: # This should NOT raise KeyError - self.assertIn(cmd_name, loader.cmd_to_loader_map, - f"Command '{cmd_name}' missing from cmd_to_loader_map - " - f"would cause KeyError in azdev command-change meta-export") - + self.assertIn( + cmd_name, + loader.cmd_to_loader_map, + f"Command '{cmd_name}' missing from cmd_to_loader_map - " + f"would cause KeyError in azdev command-change meta-export" + ) + # Verify the entry is a list with at least one loader loaders = loader.cmd_to_loader_map[cmd_name] - self.assertIsInstance(loaders, list, - f"cmd_to_loader_map['{cmd_name}'] should be a list") - self.assertGreater(len(loaders), 0, - f"cmd_to_loader_map['{cmd_name}'] should have at least one loader") - + self.assertIsInstance( + loaders, + list, + f"cmd_to_loader_map['{cmd_name}'] should be a list" + ) + self.assertGreater( + len(loaders), + 0, + f"cmd_to_loader_map['{cmd_name}'] should have at least one loader" + ) + # Verify all expected commands are present expected_commands = {'hello mod-only', 'hello overridden', 'extra final', 'hello ext-only'} actual_commands = set(cmd_tbl.keys()) @@ -424,6 +433,28 @@ def update_and_check_index(): del INDEX[CommandIndex._COMMAND_INDEX_CLOUD_PROFILE] del INDEX[CommandIndex._COMMAND_INDEX] + @mock.patch('importlib.import_module', _mock_import_lib) + @mock.patch('pkgutil.iter_modules', _mock_iter_modules) + @mock.patch('azure.cli.core.commands._load_command_loader', _mock_load_command_loader) + @mock.patch('azure.cli.core.extension.get_extension_modname', _mock_get_extension_modname) + @mock.patch('azure.cli.core.extension.get_extensions', _mock_get_extensions) + def test_command_index_not_updated_on_module_load_timeout(self): + from azure.cli.core._session import INDEX + from azure.cli.core import CommandIndex, __version__ + + cli = DummyCli() + loader = cli.commands_loader + + sentinel_index = {'sentinel': ['azure.cli.command_modules.sentinel']} + INDEX[CommandIndex._COMMAND_INDEX_VERSION] = __version__ + INDEX[CommandIndex._COMMAND_INDEX_CLOUD_PROFILE] = cli.cloud.profile + INDEX[CommandIndex._COMMAND_INDEX] = sentinel_index.copy() + + with mock.patch.object(loader, '_load_modules', return_value=([], True)): + loader.load_command_table(None) + + self.assertDictEqual(INDEX[CommandIndex._COMMAND_INDEX], sentinel_index) + @mock.patch('importlib.import_module', _mock_import_lib) @mock.patch('pkgutil.iter_modules', _mock_iter_modules) @mock.patch('azure.cli.core.commands._load_command_loader', _mock_load_command_loader) @@ -439,14 +470,14 @@ def test_command_index_always_loaded_extensions(self): index.invalidate() # Test azext_always_loaded is loaded when command index is rebuilt - with mock.patch.object(azure.cli.core,'ALWAYS_LOADED_EXTENSIONS', ['azext_always_loaded']): + with mock.patch.object(azure.cli.core, 'ALWAYS_LOADED_EXTENSIONS', ['azext_always_loaded']): loader.load_command_table(["hello", "mod-only"]) self.assertEqual(TestCommandRegistration.test_hook, "FAKE_HANDLER") TestCommandRegistration.test_hook = [] # Test azext_always_loaded is loaded when command index is used - with mock.patch.object(azure.cli.core,'ALWAYS_LOADED_EXTENSIONS', ['azext_always_loaded']): + with mock.patch.object(azure.cli.core, 'ALWAYS_LOADED_EXTENSIONS', ['azext_always_loaded']): loader.load_command_table(["hello", "mod-only"]) self.assertEqual(TestCommandRegistration.test_hook, "FAKE_HANDLER") @@ -479,6 +510,20 @@ def test_command_index_positional_argument(self): self.assertDictEqual(INDEX[CommandIndex._COMMAND_INDEX], self.expected_command_index) self.assertEqual(list(cmd_tbl), ['extra final']) + @mock.patch('azure.cli.core.telemetry.set_exception') + def test_timeout_module_load_error_does_not_emit_fault_telemetry(self, mock_set_exception): + cli = DummyCli() + loader = cli.commands_loader + + timeout_result = ModuleLoadResult('timeout_mod', {}, {}, 0, + ModuleLoadTimeoutError('timeout_mod', 1)) + loader._handle_module_load_error(timeout_result) + mock_set_exception.assert_not_called() + + regular_error_result = ModuleLoadResult('bad_mod', {}, {}, 0, Exception('import failed')) + loader._handle_module_load_error(regular_error_result) + mock_set_exception.assert_called_once() + def test_argument_with_overrides(self): global_vm_name_type = CLIArgumentType( diff --git a/src/azure-cli-core/azure/cli/core/tests/test_telemetry.py b/src/azure-cli-core/azure/cli/core/tests/test_telemetry.py index 0077f2eeb28..60582a7b714 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_telemetry.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_telemetry.py @@ -8,6 +8,10 @@ class TestCoreTelemetry(unittest.TestCase): + def setUp(self): + from azure.cli.core import telemetry + telemetry._session.__init__() # pylint: disable=unnecessary-dunder-call + def test_suppress_all_exceptions(self): self._impl(Exception, 'fallback') self._impl(Exception, None) @@ -116,5 +120,23 @@ def test_command_preserve_casing_telemetry(self, mock_get_version): azure_cli_props = session._get_azure_cli_properties() self.assertIn('Context.Default.AzureCLI.CommandPreserveCasing', azure_cli_props) - self.assertEqual(azure_cli_props['Context.Default.AzureCLI.CommandPreserveCasing'], + self.assertEqual(azure_cli_props['Context.Default.AzureCLI.CommandPreserveCasing'], expected_casing) + + def test_thread_timeout_telemetry_property(self): + from azure.cli.core import telemetry + + original_value = telemetry._session.thread_timeout + try: + telemetry._session.thread_timeout = None + props = telemetry._session._get_azure_cli_properties() + self.assertNotIn('Context.Default.AzureCLI.ThreadTimeout', props) + + telemetry.set_thread_timeout(60) + props = telemetry._session._get_azure_cli_properties() + self.assertEqual( + props['Context.Default.AzureCLI.ThreadTimeout'], + 'module loading failed with timeout seconds: 60' + ) + finally: + telemetry._session.thread_timeout = original_value