Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 62 additions & 11 deletions src/azure-cli-core/azure/cli/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@
ALWAYS_LOADED_MODULES = []
# Extensions that will always be loaded if installed. They don't expose commands but hook into CLI core.
ALWAYS_LOADED_EXTENSIONS = ['azext_ai_examples', 'azext_next']
# Timeout (in seconds) for loading a single module. Acts as a safety valve to prevent indefinite hangs
MODULE_LOAD_TIMEOUT_SECONDS = 60
# Timeout (in seconds) for loading command modules.
# Set via core.module_load_timeout. Use 0 or a negative value to disable timeout.
DEFAULT_MODULE_LOAD_TIMEOUT_SECONDS = 120
# Maximum number of worker threads for parallel module loading.
MAX_WORKER_THREAD_COUNT = 4

Expand Down Expand Up @@ -203,8 +204,16 @@ def _configure_style(self):
format_styled_text.theme = theme


class ModuleLoadTimeoutError(TimeoutError):
    """Raised when a command module does not finish loading within timeout.

    The module name and the exceeded timeout are kept as attributes so that
    handlers can inspect them without parsing the message.

    :param module_name: Name of the command module that did not finish loading.
    :param timeout_seconds: Timeout, in seconds, that was exceeded.
    """

    def __init__(self, module_name, timeout_seconds):
        super().__init__("Module '{}' load timeout after {} seconds".format(module_name, timeout_seconds))
        self.module_name = module_name
        self.timeout_seconds = timeout_seconds

    def __reduce__(self):
        # The default BaseException reduce replays ``self.args`` (the single formatted
        # message) into ``__init__``, which requires two arguments. Hand back the
        # original pair so copy/pickle round-trips rebuild the exception correctly.
        return self.__class__, (self.module_name, self.timeout_seconds)


class ModuleLoadResult: # pylint: disable=too-few-public-methods
def __init__(self, module_name, command_table, group_table, elapsed_time, error=None, traceback_str=None, command_loader=None):
def __init__(self, module_name, command_table, group_table, elapsed_time,
error=None, traceback_str=None, command_loader=None):
self.module_name = module_name
self.command_table = command_table
self.group_table = group_table
Expand Down Expand Up @@ -264,6 +273,8 @@ def load_command_table(self, args):
from azure.cli.core.breaking_change import (
import_core_breaking_changes, import_extension_breaking_changes)

module_load_timed_out = False

def _update_command_table_from_modules(args, command_modules=None):
"""Loads command tables from modules and merge into the main command table.

Expand All @@ -290,9 +301,12 @@ def _update_command_table_from_modules(args, command_modules=None):
except ImportError as e:
logger.warning(e)

nonlocal module_load_timed_out

start_time = timeit.default_timer()
logger.debug("Loading command modules...")
results = self._load_modules(args, command_modules)
results, timed_out = self._load_modules(args, command_modules)
module_load_timed_out = module_load_timed_out or timed_out

count, cumulative_group_count, cumulative_command_count = \
self._process_results_with_timing(results)
Expand Down Expand Up @@ -524,9 +538,11 @@ def _get_extension_suppressions(mod_loaders):
_update_command_table_from_extensions(ext_suppressions)
logger.debug("Loaded %d groups, %d commands.", len(self.command_group_table), len(self.command_table))

if use_command_index:
if use_command_index and not module_load_timed_out:
command_index.update(self.command_table)
self._cache_help_index(command_index)
elif use_command_index and module_load_timed_out:
logger.warning("Skip command index update because module loading timed out.")

return self.command_table

Expand Down Expand Up @@ -622,16 +638,21 @@ def load_arguments(self, command=None):
loader._update_command_definitions() # pylint: disable=protected-access

def _load_modules(self, args, command_modules):
"""Load command modules using ThreadPoolExecutor with timeout protection."""
"""Load command modules using ThreadPoolExecutor."""
from azure.cli.core.commands import BLOCKED_MODS
from azure.cli.core import telemetry

timeout_seconds = self._get_module_load_timeout_seconds()
results = []
timed_out = False
processed_futures = set()
with ThreadPoolExecutor(max_workers=MAX_WORKER_THREAD_COUNT) as executor:
future_to_module = {executor.submit(self._load_single_module, mod, args): mod
for mod in command_modules if mod not in BLOCKED_MODS}

try:
for future in concurrent.futures.as_completed(future_to_module, timeout=MODULE_LOAD_TIMEOUT_SECONDS):
for future in concurrent.futures.as_completed(future_to_module, timeout=timeout_seconds):
processed_futures.add(future)
try:
result = future.result()
results.append(result)
Expand All @@ -644,7 +665,11 @@ def _load_modules(self, args, command_modules):
logger.warning("Module '%s' load failed with unexpected exception: %s", mod, ex)
results.append(ModuleLoadResult(mod, {}, {}, 0, ex))
except concurrent.futures.TimeoutError:
timed_out = True
telemetry.set_thread_timeout(timeout_seconds)
for future, mod in future_to_module.items():
if future in processed_futures:
continue
if future.done():
try:
result = future.result()
Expand All @@ -653,11 +678,33 @@ def _load_modules(self, args, command_modules):
logger.warning("Module '%s' load failed: %s", mod, ex)
results.append(ModuleLoadResult(mod, {}, {}, 0, ex))
else:
logger.warning("Module '%s' load timeout after %s seconds", mod, MODULE_LOAD_TIMEOUT_SECONDS)
results.append(ModuleLoadResult(mod, {}, {}, 0,
Exception(f"Module '{mod}' load timeout")))
logger.warning("Module '%s' load timeout after %s seconds", mod, timeout_seconds)
results.append(ModuleLoadResult(
mod,
{},
{},
0,
ModuleLoadTimeoutError(mod, timeout_seconds)
))

return results, timed_out

def _get_module_load_timeout_seconds(self):
    """Resolve the module load timeout from the ``core.module_load_timeout`` setting.

    :return: The configured timeout in seconds as a positive int, or ``None`` when
        the configured value is 0 or negative (timeout disabled). A value that cannot
        be converted to int falls back to ``DEFAULT_MODULE_LOAD_TIMEOUT_SECONDS``
        after logging a warning.
    """
    raw_timeout = self.cli_ctx.config.get(
        'core', 'module_load_timeout', str(DEFAULT_MODULE_LOAD_TIMEOUT_SECONDS))
    try:
        timeout_seconds = int(raw_timeout)
    except (TypeError, ValueError):
        logger.warning("Invalid core.module_load_timeout value '%s'. Using default %s seconds.",
                       raw_timeout, DEFAULT_MODULE_LOAD_TIMEOUT_SECONDS)
        timeout_seconds = DEFAULT_MODULE_LOAD_TIMEOUT_SECONDS

    if timeout_seconds <= 0:
        logger.debug("Module load timeout disabled by core.module_load_timeout=%s", timeout_seconds)
        # None is the "no deadline" value for the timeout argument this feeds.
        return None

    logger.debug("Module load timeout set to %s seconds (core.module_load_timeout=%s).",
                 timeout_seconds, raw_timeout)
    return timeout_seconds

def _load_single_module(self, mod, args):
from azure.cli.core.breaking_change import import_module_breaking_changes
Expand All @@ -677,6 +724,10 @@ def _handle_module_load_error(self, result):
"""Handle errors that occurred during module loading."""
from azure.cli.core import telemetry

if isinstance(result.error, ModuleLoadTimeoutError):
logger.warning("Skip per-module timeout fault telemetry for '%s'.", result.module_name)
return

logger.error("Error loading command module '%s': %s", result.module_name, result.error)
telemetry.set_exception(exception=result.error,
fault_type='module-load-error-' + result.module_name,
Expand Down
11 changes: 11 additions & 0 deletions src/azure-cli-core/azure/cli/core/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(self, correlation_id=None, application=None):
self.enable_broker_on_windows = None
self.msal_telemetry = None
self.login_experience_v2 = None
self.thread_timeout = None

def add_event(self, name, properties):
for key in self.instrumentation_key:
Expand Down Expand Up @@ -239,6 +240,11 @@ def _get_azure_cli_properties(self):
set_custom_properties(result, 'EnableBrokerOnWindows', str(self.enable_broker_on_windows))
set_custom_properties(result, 'MsalTelemetry', self.msal_telemetry)
set_custom_properties(result, 'LoginExperienceV2', str(self.login_experience_v2))
set_custom_properties(
result,
'ThreadTimeout',
self.thread_timeout
)

return result

Expand Down Expand Up @@ -446,6 +452,11 @@ def set_command_index_rebuild_triggered(is_cmd_idx_rebuild_triggered=False):
_session.is_cmd_idx_rebuild_triggered = is_cmd_idx_rebuild_triggered


@decorators.suppress_all_exceptions()
def set_thread_timeout(timeout_seconds):
    """Record on the telemetry session that module loading hit its timeout.

    :param timeout_seconds: Timeout value, in seconds, embedded in the recorded message.
    """
    message_template = 'module loading failed with timeout seconds: {}'
    _session.thread_timeout = message_template.format(timeout_seconds)


@decorators.suppress_all_exceptions()
def set_command_details(command, output_type=None, parameters=None, extension_name=None,
extension_version=None, command_preserve_casing=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import unittest
from collections import namedtuple

from azure.cli.core import AzCommandsLoader, MainCommandsLoader
from azure.cli.core import AzCommandsLoader, MainCommandsLoader, ModuleLoadResult, ModuleLoadTimeoutError
from azure.cli.core.commands import ExtensionCommandSource
from azure.cli.core.extension import EXTENSIONS_MOD_PREFIX
from azure.cli.core.mock import DummyCli
Expand Down Expand Up @@ -274,22 +274,31 @@ def test_cmd_to_loader_map_populated_after_parallel_loading(self):

# Load all commands (triggers parallel module loading)
cmd_tbl = loader.load_command_table(None)

# Verify EVERY command in command_table has an entry in cmd_to_loader_map
# This is exactly what azdev does before it hits KeyError
for cmd_name in cmd_tbl:
# This should NOT raise KeyError
self.assertIn(cmd_name, loader.cmd_to_loader_map,
f"Command '{cmd_name}' missing from cmd_to_loader_map - "
f"would cause KeyError in azdev command-change meta-export")

self.assertIn(
cmd_name,
loader.cmd_to_loader_map,
f"Command '{cmd_name}' missing from cmd_to_loader_map - "
f"would cause KeyError in azdev command-change meta-export"
)

# Verify the entry is a list with at least one loader
loaders = loader.cmd_to_loader_map[cmd_name]
self.assertIsInstance(loaders, list,
f"cmd_to_loader_map['{cmd_name}'] should be a list")
self.assertGreater(len(loaders), 0,
f"cmd_to_loader_map['{cmd_name}'] should have at least one loader")

self.assertIsInstance(
loaders,
list,
f"cmd_to_loader_map['{cmd_name}'] should be a list"
)
self.assertGreater(
len(loaders),
0,
f"cmd_to_loader_map['{cmd_name}'] should have at least one loader"
)

# Verify all expected commands are present
expected_commands = {'hello mod-only', 'hello overridden', 'extra final', 'hello ext-only'}
actual_commands = set(cmd_tbl.keys())
Expand Down Expand Up @@ -424,6 +433,28 @@ def update_and_check_index():
del INDEX[CommandIndex._COMMAND_INDEX_CLOUD_PROFILE]
del INDEX[CommandIndex._COMMAND_INDEX]

@mock.patch('importlib.import_module', _mock_import_lib)
@mock.patch('pkgutil.iter_modules', _mock_iter_modules)
@mock.patch('azure.cli.core.commands._load_command_loader', _mock_load_command_loader)
@mock.patch('azure.cli.core.extension.get_extension_modname', _mock_get_extension_modname)
@mock.patch('azure.cli.core.extension.get_extensions', _mock_get_extensions)
def test_command_index_not_updated_on_module_load_timeout(self):
    """The stored command index must be left untouched when module loading reports a timeout."""
    from azure.cli.core._session import INDEX
    from azure.cli.core import CommandIndex, __version__

    cli = DummyCli()
    loader = cli.commands_loader

    # Seed the index with a recognisable payload so any rewrite is detectable.
    sentinel_index = {'sentinel': ['azure.cli.command_modules.sentinel']}
    INDEX[CommandIndex._COMMAND_INDEX_VERSION] = __version__
    INDEX[CommandIndex._COMMAND_INDEX_CLOUD_PROFILE] = cli.cloud.profile
    INDEX[CommandIndex._COMMAND_INDEX] = sentinel_index.copy()

    try:
        # ``_load_modules`` returns ``(results, timed_out)``; report a timeout with no results.
        with mock.patch.object(loader, '_load_modules', return_value=([], True)):
            loader.load_command_table(None)

        self.assertDictEqual(INDEX[CommandIndex._COMMAND_INDEX], sentinel_index)
    finally:
        # INDEX is a module-level object imported from azure.cli.core._session;
        # remove the sentinel entries so they cannot leak into other tests.
        for index_key in (CommandIndex._COMMAND_INDEX_VERSION,
                          CommandIndex._COMMAND_INDEX_CLOUD_PROFILE,
                          CommandIndex._COMMAND_INDEX):
            del INDEX[index_key]

@mock.patch('importlib.import_module', _mock_import_lib)
@mock.patch('pkgutil.iter_modules', _mock_iter_modules)
@mock.patch('azure.cli.core.commands._load_command_loader', _mock_load_command_loader)
Expand All @@ -439,14 +470,14 @@ def test_command_index_always_loaded_extensions(self):
index.invalidate()

# Test azext_always_loaded is loaded when command index is rebuilt
with mock.patch.object(azure.cli.core,'ALWAYS_LOADED_EXTENSIONS', ['azext_always_loaded']):
with mock.patch.object(azure.cli.core, 'ALWAYS_LOADED_EXTENSIONS', ['azext_always_loaded']):
loader.load_command_table(["hello", "mod-only"])
self.assertEqual(TestCommandRegistration.test_hook, "FAKE_HANDLER")

TestCommandRegistration.test_hook = []

# Test azext_always_loaded is loaded when command index is used
with mock.patch.object(azure.cli.core,'ALWAYS_LOADED_EXTENSIONS', ['azext_always_loaded']):
with mock.patch.object(azure.cli.core, 'ALWAYS_LOADED_EXTENSIONS', ['azext_always_loaded']):
loader.load_command_table(["hello", "mod-only"])
self.assertEqual(TestCommandRegistration.test_hook, "FAKE_HANDLER")

Expand Down Expand Up @@ -479,6 +510,20 @@ def test_command_index_positional_argument(self):
self.assertDictEqual(INDEX[CommandIndex._COMMAND_INDEX], self.expected_command_index)
self.assertEqual(list(cmd_tbl), ['extra final'])

@mock.patch('azure.cli.core.telemetry.set_exception')
def test_timeout_module_load_error_does_not_emit_fault_telemetry(self, mock_set_exception):
    """Timeout results must not reach telemetry.set_exception; other load errors must."""
    dummy_cli = DummyCli()
    commands_loader = dummy_cli.commands_loader

    # A result carrying ModuleLoadTimeoutError is handled without fault telemetry.
    timeout_error = ModuleLoadTimeoutError('timeout_mod', 1)
    timed_out = ModuleLoadResult('timeout_mod', {}, {}, 0, timeout_error)
    commands_loader._handle_module_load_error(timed_out)
    mock_set_exception.assert_not_called()

    # Any other exception is reported exactly once.
    import_failure = Exception('import failed')
    failed = ModuleLoadResult('bad_mod', {}, {}, 0, import_failure)
    commands_loader._handle_module_load_error(failed)
    mock_set_exception.assert_called_once()

def test_argument_with_overrides(self):

global_vm_name_type = CLIArgumentType(
Expand Down
24 changes: 23 additions & 1 deletion src/azure-cli-core/azure/cli/core/tests/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@


class TestCoreTelemetry(unittest.TestCase):
def setUp(self):
    """Re-run the telemetry session initializer before each test."""
    from azure.cli.core import telemetry
    session = telemetry._session
    session.__init__()  # pylint: disable=unnecessary-dunder-call

def test_suppress_all_exceptions(self):
self._impl(Exception, 'fallback')
self._impl(Exception, None)
Expand Down Expand Up @@ -116,5 +120,23 @@ def test_command_preserve_casing_telemetry(self, mock_get_version):
azure_cli_props = session._get_azure_cli_properties()

self.assertIn('Context.Default.AzureCLI.CommandPreserveCasing', azure_cli_props)
self.assertEqual(azure_cli_props['Context.Default.AzureCLI.CommandPreserveCasing'],
self.assertEqual(azure_cli_props['Context.Default.AzureCLI.CommandPreserveCasing'],
expected_casing)

def test_thread_timeout_telemetry_property(self):
    """ThreadTimeout is absent while unset and carries the message once set_thread_timeout runs."""
    from azure.cli.core import telemetry

    session = telemetry._session
    property_key = 'Context.Default.AzureCLI.ThreadTimeout'
    saved_value = session.thread_timeout
    try:
        # Unset value: the property must not be emitted.
        session.thread_timeout = None
        self.assertNotIn(property_key, session._get_azure_cli_properties())

        # After recording a timeout the property holds the formatted message.
        telemetry.set_thread_timeout(60)
        emitted = session._get_azure_cli_properties()
        self.assertEqual(
            emitted[property_key],
            'module loading failed with timeout seconds: 60'
        )
    finally:
        session.thread_timeout = saved_value