diff --git a/aws_advanced_python_wrapper/cleanup.py b/aws_advanced_python_wrapper/cleanup.py
index a25392d2..67ce284c 100644
--- a/aws_advanced_python_wrapper/cleanup.py
+++ b/aws_advanced_python_wrapper/cleanup.py
@@ -18,6 +18,8 @@
     MonitoringThreadContainer
 from aws_advanced_python_wrapper.thread_pool_container import \
     ThreadPoolContainer
+from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
+    SlidingExpirationCacheContainer
 
 
 def release_resources() -> None:
@@ -25,3 +27,4 @@ def release_resources() -> None:
     MonitoringThreadContainer.clean_up()
     ThreadPoolContainer.release_resources()
     OpenedConnectionTracker.release_resources()
+    SlidingExpirationCacheContainer.release_resources()
diff --git a/aws_advanced_python_wrapper/connection_provider.py b/aws_advanced_python_wrapper/connection_provider.py
index 16ba12fe..56d6fc7b 100644
--- a/aws_advanced_python_wrapper/connection_provider.py
+++ b/aws_advanced_python_wrapper/connection_provider.py
@@ -14,7 +14,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Callable, Dict, Optional, Protocol, Tuple
+from typing import (TYPE_CHECKING, Callable, ClassVar, Dict, Optional,
+                    Protocol, Tuple)
 
 if TYPE_CHECKING:
     from aws_advanced_python_wrapper.database_dialect import DatabaseDialect
@@ -131,8 +132,8 @@ def connect(
 
 
 class ConnectionProviderManager:
-    _lock: Lock = Lock()
-    _conn_provider: Optional[ConnectionProvider] = None
+    _lock: ClassVar[Lock] = Lock()
+    _conn_provider: ClassVar[Optional[ConnectionProvider]] = None
 
     def __init__(self, default_provider: ConnectionProvider = DriverConnectionProvider()):
         self._default_provider: ConnectionProvider = default_provider
diff --git a/aws_advanced_python_wrapper/custom_endpoint_plugin.py b/aws_advanced_python_wrapper/custom_endpoint_plugin.py
index 2db46ae5..2ce683b6 100644
--- a/aws_advanced_python_wrapper/custom_endpoint_plugin.py
+++ b/aws_advanced_python_wrapper/custom_endpoint_plugin.py
@@ -42,8 +42,8 @@
 from aws_advanced_python_wrapper.utils.log import Logger
 from aws_advanced_python_wrapper.utils.properties import WrapperProperties
 from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
-from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
-    SlidingExpirationCacheWithCleanupThread
+from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
+    SlidingExpirationCacheContainer
 from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
     TelemetryCounter, TelemetryFactory)
 
@@ -232,11 +232,8 @@ class CustomEndpointPlugin(Plugin):
     or removing an instance in the custom endpoint.
     """
    _SUBSCRIBED_METHODS: ClassVar[Set[str]] = {DbApiMethod.CONNECT.method_name}
-    _CACHE_CLEANUP_RATE_NS: ClassVar[int] = 6 * 10 ^ 10  # 1 minute
-    _monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, CustomEndpointMonitor]] = \
-        SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_RATE_NS,
-                                                should_dispose_func=lambda _: True,
-                                                item_disposal_func=lambda monitor: monitor.close())
+    _CACHE_CLEANUP_RATE_NS: ClassVar[int] = 60_000_000_000  # 1 minute
+    _MONITOR_CACHE_NAME: ClassVar[str] = "custom_endpoint_monitors"
 
     def __init__(self, plugin_service: PluginService, props: Properties):
         self._plugin_service = plugin_service
@@ -255,6 +252,13 @@ def __init__(self, plugin_service: PluginService, props: Properties):
         telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory()
         self._wait_for_info_counter: TelemetryCounter | None = telemetry_factory.create_counter("customEndpoint.waitForInfo.counter")
 
+        self._monitors = SlidingExpirationCacheContainer.get_or_create_cache(
+            name=CustomEndpointPlugin._MONITOR_CACHE_NAME,
+            cleanup_interval_ns=CustomEndpointPlugin._CACHE_CLEANUP_RATE_NS,
+            should_dispose_func=lambda _: True,
+            item_disposal_func=lambda monitor: monitor.close()
+        )
+
         CustomEndpointPlugin._SUBSCRIBED_METHODS.update(self._plugin_service.network_bound_methods)
 
     @property
@@ -298,7 +302,7 @@ def _create_monitor_if_absent(self, props: Properties) -> CustomEndpointMonitor:
         host_info = cast('HostInfo', self._custom_endpoint_host_info)
         endpoint_id = cast('str', self._custom_endpoint_id)
         region = cast('str', self._region)
-        monitor = CustomEndpointPlugin._monitors.compute_if_absent(
+        monitor = self._monitors.compute_if_absent(
             host_info.host,
             lambda key: CustomEndpointMonitor(
                 self._plugin_service,
diff --git a/aws_advanced_python_wrapper/database_dialect.py b/aws_advanced_python_wrapper/database_dialect.py
index 96e90af0..1815abe5 100644
--- a/aws_advanced_python_wrapper/database_dialect.py
+++ b/aws_advanced_python_wrapper/database_dialect.py
@@ -695,6 +695,7 @@ def __init__(self, props: Properties, rds_helper: Optional[RdsUtils] = None):
         self._can_update: bool = False
         self._dialect: DatabaseDialect = UnknownDatabaseDialect()
         self._dialect_code: DialectCode = DialectCode.UNKNOWN
+        self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name)
 
     @staticmethod
     def get_custom_dialect():
@@ -814,7 +815,7 @@ def query_for_dialect(self, url: str, host_info: Optional[HostInfo], conn: Conne
         timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get(self._props)
         try:
             cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
-                ThreadPoolContainer.get_thread_pool(DatabaseDialectManager._executor_name),
+                self._thread_pool,
                 timeout_sec,
                 driver_dialect,
                 conn)(dialect_candidate.is_dialect)
diff --git a/aws_advanced_python_wrapper/driver_dialect.py b/aws_advanced_python_wrapper/driver_dialect.py
index 1da64e7b..3683a435 100644
--- a/aws_advanced_python_wrapper/driver_dialect.py
+++ b/aws_advanced_python_wrapper/driver_dialect.py
@@ -51,6 +51,7 @@ class DriverDialect(ABC):
 
     def __init__(self, props: Properties):
         self._props = props
+        self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name)
 
     @property
     def driver_name(self):
@@ -138,7 +139,7 @@ def execute(
 
         if exec_timeout > 0:
             try:
-                execute_with_timeout = timeout(ThreadPoolContainer.get_thread_pool(DriverDialect._executor_name), exec_timeout)(exec_func)
+                execute_with_timeout = timeout(self._thread_pool, exec_timeout)(exec_func)
                 return execute_with_timeout()
             except TimeoutError as e:
                 raise QueryTimeoutError(Messages.get_formatted("DriverDialect.ExecuteTimeout", method_name)) from e
diff --git a/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py b/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py
index b4abd577..5f9f0fdc 100644
--- a/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py
+++ b/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py
@@ -30,8 +30,8 @@
 from aws_advanced_python_wrapper.utils.messages import Messages
 from aws_advanced_python_wrapper.utils.properties import (Properties,
                                                           WrapperProperties)
-from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
-    SlidingExpirationCacheWithCleanupThread
+from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
+    SlidingExpirationCacheContainer
 from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
     TelemetryContext, TelemetryFactory, TelemetryGauge, TelemetryTraceLevel)
 
@@ -59,7 +59,7 @@ def __init__(self, plugin_service: PluginService, props: Properties):
         self._properties = props
         self._host_response_time_service: HostResponseTimeService = \
             HostResponseTimeService(plugin_service, props, WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props))
-        self._cache_expiration_nanos = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props) * 10 ^ 6
+        self._cache_expiration_nanos = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props) * 1_000_000
         self._random_host_selector = RandomHostSelector()
         self._cached_fastest_response_host_by_role: CacheMap[str, HostInfo] = CacheMap()
         self._hosts: Tuple[HostInfo, ...] = ()
@@ -278,13 +278,10 @@ def _open_connection(self):
 
 
 class HostResponseTimeService:
-    _CACHE_EXPIRATION_NS: int = 6 * 10 ^ 11  # 10 minutes
-    _CACHE_CLEANUP_NS: int = 6 * 10 ^ 10  # 1 minute
-    _lock: Lock = Lock()
-    _monitoring_hosts: ClassVar[SlidingExpirationCacheWithCleanupThread[str, HostResponseTimeMonitor]] = \
-        SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NS,
-                                                should_dispose_func=lambda monitor: True,
-                                                item_disposal_func=lambda monitor: HostResponseTimeService._monitor_close(monitor))
+    _CACHE_EXPIRATION_NS: ClassVar[int] = 10 * 60_000_000_000  # 10 minutes
+    _CACHE_CLEANUP_NS: ClassVar[int] = 60_000_000_000  # 1 minute
+    _CACHE_NAME: ClassVar[str] = "host_response_time_monitors"
+    _lock: ClassVar[Lock] = Lock()
 
     def __init__(self, plugin_service: PluginService, props: Properties, interval_ms: int):
         self._plugin_service = plugin_service
@@ -292,7 +289,18 @@ def __init__(self, plugin_service: PluginService, props: Properties, interval_ms
         self._interval_ms = interval_ms
         self._hosts: Tuple[HostInfo, ...] = ()
         self._telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory()
-        self._host_count_gauge: TelemetryGauge | None = self._telemetry_factory.create_gauge("frt.hosts.count", lambda: len(self._monitoring_hosts))
+
+        self._monitoring_hosts = SlidingExpirationCacheContainer.get_or_create_cache(
+            name=HostResponseTimeService._CACHE_NAME,
+            cleanup_interval_ns=HostResponseTimeService._CACHE_CLEANUP_NS,
+            should_dispose_func=lambda monitor: True,
+            item_disposal_func=lambda monitor: HostResponseTimeService._monitor_close(monitor)
+        )
+
+        self._host_count_gauge: TelemetryGauge | None = self._telemetry_factory.create_gauge(
+            "frt.hosts.count",
+            lambda: len(self._monitoring_hosts)
+        )
 
     @property
     def hosts(self) -> Tuple[HostInfo, ...]:
@@ -310,7 +318,7 @@ def _monitor_close(monitor: HostResponseTimeMonitor):
             pass
 
     def get_response_time(self, host_info: HostInfo) -> int:
-        monitor: Optional[HostResponseTimeMonitor] = HostResponseTimeService._monitoring_hosts.get(host_info.url)
+        monitor: Optional[HostResponseTimeMonitor] = self._monitoring_hosts.get(host_info.url)
         if monitor is None:
             return MAX_VALUE
         return monitor.response_time
@@ -327,4 +335,4 @@ def set_hosts(self, new_hosts: Tuple[HostInfo, ...]) -> None:
                     self._plugin_service,
                     host,
                     self._properties,
-                    self._interval_ms), self._CACHE_EXPIRATION_NS)
+                    self._interval_ms), HostResponseTimeService._CACHE_EXPIRATION_NS)
diff --git a/aws_advanced_python_wrapper/host_list_provider.py b/aws_advanced_python_wrapper/host_list_provider.py
index b971270c..d8297501 100644
--- a/aws_advanced_python_wrapper/host_list_provider.py
+++ b/aws_advanced_python_wrapper/host_list_provider.py
@@ -28,8 +28,8 @@
     ClusterTopologyMonitor, ClusterTopologyMonitorImpl)
 from aws_advanced_python_wrapper.utils.decorators import \
     preserve_transaction_status_with_timeout
-from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
-    SlidingExpirationCacheWithCleanupThread
+from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
+    SlidingExpirationCacheContainer
 
 if TYPE_CHECKING:
     from aws_advanced_python_wrapper.driver_dialect import DriverDialect
@@ -476,6 +476,7 @@ def __init__(self, dialect: db_dialect.TopologyAwareDatabaseDialect, props: Prop
         self.instance_template: HostInfo = instance_template
 
         self._max_timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get_int(props)
+        self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name)
 
     def _validate_host_pattern(self, host: str):
         if not self._rds_utils.is_dns_pattern_valid(host):
@@ -507,7 +508,7 @@ def query_for_topology(
             an empty tuple will be returned.
""" query_for_topology_func_with_timeout = preserve_transaction_status_with_timeout( - ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, conn)(self._query_for_topology) + self._thread_pool, self._max_timeout_sec, driver_dialect, conn)(self._query_for_topology) x = query_for_topology_func_with_timeout(conn) return x @@ -570,7 +571,7 @@ def create_host( def get_host_role(self, connection: Connection, driver_dialect: DriverDialect) -> HostRole: try: cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( - ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, connection)(self._get_host_role) + self._thread_pool, self._max_timeout_sec, driver_dialect, connection)(self._get_host_role) result = cursor_execute_func_with_timeout(connection) if result is not None: is_reader = result[0] @@ -593,7 +594,7 @@ def get_host_id(self, connection: Connection, driver_dialect: DriverDialect) -> """ cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( - ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, connection)(self._get_host_id) + self._thread_pool, self._max_timeout_sec, driver_dialect, connection)(self._get_host_id) result = cursor_execute_func_with_timeout(connection) if result: host_id: str = result[0] @@ -608,7 +609,7 @@ def _get_host_id(self, conn: Connection): def get_writer_host_if_connected(self, connection: Connection, driver_dialect: DriverDialect) -> Optional[str]: try: cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( - ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, connection)(self._get_writer_id) + self._thread_pool, self._max_timeout_sec, driver_dialect, connection)(self._get_writer_id) result = cursor_execute_func_with_timeout(connection) if result: host_id: str = result[0] @@ -752,13 +753,9 @@ def _create_multi_az_host(self, record: Tuple, writer_id: str) -> HostInfo: class MonitoringRdsHostListProvider(RdsHostListProvider): - _CACHE_CLEANUP_NANO = 1 * 60 * 1_000_000_000 # 1 minute - _MONITOR_CLEANUP_NANO = 15 * 60 * 1_000_000_000 # 15 minutes - - _monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, ClusterTopologyMonitor]] = \ - SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NANO, - should_dispose_func=lambda monitor: monitor.can_dispose(), - item_disposal_func=lambda monitor: monitor.close()) + _CACHE_CLEANUP_NANO: ClassVar[int] = 1 * 60 * 1_000_000_000 # 1 minute + _MONITOR_CLEANUP_NANO: ClassVar[int] = 15 * 60 * 1_000_000_000 # 15 minutes + _MONITOR_CACHE_NAME: ClassVar[str] = "cluster_topology_monitors" def __init__( self, @@ -772,6 +769,13 @@ def __init__( self._high_refresh_rate_ns = ( WrapperProperties.CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.get_int(self._props) * 1_000_000) + self._monitors = SlidingExpirationCacheContainer.get_or_create_cache( + name=MonitoringRdsHostListProvider._MONITOR_CACHE_NAME, + cleanup_interval_ns=MonitoringRdsHostListProvider._CACHE_CLEANUP_NANO, + should_dispose_func=lambda monitor: monitor.can_dispose(), + item_disposal_func=lambda monitor: monitor.close() + ) + def _get_monitor(self) -> Optional[ClusterTopologyMonitor]: return self._monitors.compute_if_absent_with_disposal(self.get_cluster_id(), lambda k: ClusterTopologyMonitorImpl( @@ -803,7 +807,3 @@ def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) return () return 
monitor.force_refresh(should_verify_writer, timeout_sec) - - @staticmethod - def release_resources(): - MonitoringRdsHostListProvider._monitors.clear() diff --git a/aws_advanced_python_wrapper/host_monitoring_plugin.py b/aws_advanced_python_wrapper/host_monitoring_plugin.py index 69252d19..b8862161 100644 --- a/aws_advanced_python_wrapper/host_monitoring_plugin.py +++ b/aws_advanced_python_wrapper/host_monitoring_plugin.py @@ -578,6 +578,9 @@ class MonitoringThreadContainer: _tasks_map: ConcurrentDict[Monitor, Future] = ConcurrentDict() _executor_name: ClassVar[str] = "MonitoringThreadContainerExecutor" + def __init__(self): + self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name) + # This logic ensures that this class is a Singleton def __new__(cls, *args, **kwargs): if cls._instance is None: @@ -605,8 +608,7 @@ def _get_or_create_monitor(_) -> Monitor: raise AwsWrapperError(Messages.get("MonitoringThreadContainer.SupplierMonitorNone")) self._tasks_map.compute_if_absent( supplied_monitor, - lambda _: ThreadPoolContainer.get_thread_pool(MonitoringThreadContainer._executor_name) - .submit(supplied_monitor.run)) + lambda _: self._thread_pool.submit(supplied_monitor.run)) return supplied_monitor if monitor is None: diff --git a/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py b/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py index e20e9a7e..da1b0e53 100644 --- a/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py +++ b/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py @@ -36,8 +36,8 @@ PropertiesUtils, WrapperProperties) from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils -from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \ - SlidingExpirationCacheWithCleanupThread +from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \ + SlidingExpirationCacheContainer from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( TelemetryCounter, TelemetryFactory, TelemetryTraceLevel) @@ -450,12 +450,8 @@ def close(self) -> None: class MonitorServiceV2: # 1 Minute to Nanoseconds - _CACHE_CLEANUP_NANO = 1 * 60 * 1_000_000_000 - - _monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, HostMonitorV2]] = \ - SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NANO, - should_dispose_func=lambda monitor: monitor.can_dispose(), - item_disposal_func=lambda monitor: monitor.close()) + _CACHE_CLEANUP_NANO: ClassVar[int] = 1 * 60 * 1_000_000_000 + _MONITOR_CACHE_NAME: ClassVar[str] = "host_monitors_v2" def __init__(self, plugin_service: PluginService): self._plugin_service: PluginService = plugin_service @@ -463,6 +459,13 @@ def __init__(self, plugin_service: PluginService): telemetry_factory = self._plugin_service.get_telemetry_factory() self._aborted_connections_counter = telemetry_factory.create_counter("efm2.connections.aborted") + self._monitors = SlidingExpirationCacheContainer.get_or_create_cache( + name=MonitorServiceV2._MONITOR_CACHE_NAME, + cleanup_interval_ns=MonitorServiceV2._CACHE_CLEANUP_NANO, + should_dispose_func=lambda monitor: monitor.can_dispose(), + item_disposal_func=lambda monitor: monitor.close() + ) + def start_monitoring( self, conn: Connection, diff --git a/aws_advanced_python_wrapper/limitless_plugin.py b/aws_advanced_python_wrapper/limitless_plugin.py index 1b8cea4b..22636dfa 100644 --- a/aws_advanced_python_wrapper/limitless_plugin.py +++ b/aws_advanced_python_wrapper/limitless_plugin.py @@ -33,8 +33,8 @@ from aws_advanced_python_wrapper.utils.messages import 
 from aws_advanced_python_wrapper.utils.messages import Messages
 from aws_advanced_python_wrapper.utils.properties import (Properties,
                                                           WrapperProperties)
-from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
-    SlidingExpirationCacheWithCleanupThread
+from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
+    SlidingExpirationCacheContainer
 from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
     TelemetryContext, TelemetryFactory, TelemetryTraceLevel)
 from aws_advanced_python_wrapper.utils.utils import LogUtils, Utils
@@ -112,7 +112,7 @@ class LimitlessRouterMonitor:
     def __init__(self,
                  plugin_service: PluginService,
                  host_info: HostInfo,
-                 limitless_router_cache: SlidingExpirationCacheWithCleanupThread,
+                 limitless_router_cache,  # SlidingExpirationCache from container
                  limitless_router_cache_key: str,
                  props: Properties,
                  interval_ms: int):
@@ -312,21 +312,27 @@ def is_any_router_available(self):
 
 
 class LimitlessRouterService:
-    _CACHE_CLEANUP_NS: int = 6 * 10 ^ 10  # 1 minute
-    _limitless_router_cache: ClassVar[SlidingExpirationCacheWithCleanupThread[str, List[HostInfo]]] = \
-        SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NS)
-
-    _limitless_router_monitor: ClassVar[SlidingExpirationCacheWithCleanupThread[str, LimitlessRouterMonitor]] = \
-        SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NS,
-                                                should_dispose_func=lambda monitor: True,
-                                                item_disposal_func=lambda monitor: monitor.close())
-
+    _CACHE_CLEANUP_NS: ClassVar[int] = 60_000_000_000  # 1 minute
+    _ROUTER_CACHE_NAME: ClassVar[str] = "limitless_router_cache"
+    _MONITOR_CACHE_NAME: ClassVar[str] = "limitless_monitor_cache"
     _force_get_limitless_routers_lock_map: ClassVar[ConcurrentDict[str, RLock]] = ConcurrentDict()
 
     def __init__(self, plugin_service: PluginService, query_helper: LimitlessQueryHelper):
         self._plugin_service = plugin_service
         self._query_helper = query_helper
 
+        self._limitless_router_cache = SlidingExpirationCacheContainer.get_or_create_cache(
+            name=LimitlessRouterService._ROUTER_CACHE_NAME,
+            cleanup_interval_ns=LimitlessRouterService._CACHE_CLEANUP_NS
+        )
+
+        self._limitless_router_monitor = SlidingExpirationCacheContainer.get_or_create_cache(
+            name=LimitlessRouterService._MONITOR_CACHE_NAME,
+            cleanup_interval_ns=LimitlessRouterService._CACHE_CLEANUP_NS,
+            should_dispose_func=lambda monitor: True,
+            item_disposal_func=lambda monitor: monitor.close()
+        )
+
     def establish_connection(self, context: LimitlessContext) -> None:
         context.set_limitless_routers(self._get_limitless_routers(
             self._plugin_service.host_list_provider.get_cluster_id(), context.get_props()))
@@ -385,8 +391,8 @@ def establish_connection(self, context: LimitlessContext) -> None:
     def _get_limitless_routers(self, cluster_id: str, props: Properties) -> List[HostInfo]:
         # Convert milliseconds to nanoseconds
         cache_expiration_nano: int = WrapperProperties.LIMITLESS_MONITOR_DISPOSAL_TIME_MS.get_int(props) * 1_000_000
-        LimitlessRouterService._limitless_router_cache.set_cleanup_interval_ns(cache_expiration_nano)
-        routers = LimitlessRouterService._limitless_router_cache.get(cluster_id)
+        self._limitless_router_cache.set_cleanup_interval_ns(cache_expiration_nano)
+        routers = self._limitless_router_cache.get(cluster_id)
         if routers is None:
             return []
         return routers
@@ -481,7 +487,7 @@ def _synchronously_get_limitless_routers(self, context: LimitlessContext) -> Non
 
         lock.acquire()
         try:
-            limitless_routers = LimitlessRouterService._limitless_router_cache.get(
+            limitless_routers = self._limitless_router_cache.get(
                 self._plugin_service.host_list_provider.get_cluster_id())
             if limitless_routers is not None and len(limitless_routers) != 0:
                 context.set_limitless_routers(limitless_routers)
@@ -495,7 +501,7 @@ def _synchronously_get_limitless_routers(self, context: LimitlessContext) -> Non
 
             if new_limitless_routers is not None and len(new_limitless_routers) != 0:
                 context.set_limitless_routers(new_limitless_routers)
-                LimitlessRouterService._limitless_router_cache.compute_if_absent(
+                self._limitless_router_cache.compute_if_absent(
                     self._plugin_service.host_list_provider.get_cluster_id(),
                     lambda _: new_limitless_routers,
                     cache_expiration_nano
@@ -516,11 +522,11 @@ def start_monitoring(self, host_info: HostInfo,
             cache_expiration_nano: int = WrapperProperties.LIMITLESS_MONITOR_DISPOSAL_TIME_MS.get_int(props) * 1_000_000
             intervals_ms: int = WrapperProperties.LIMITLESS_INTERVAL_MILLIS.get_int(props)
 
-            LimitlessRouterService._limitless_router_monitor.compute_if_absent(
+            self._limitless_router_monitor.compute_if_absent(
                 limitless_router_monitor_key,
                 lambda _: LimitlessRouterMonitor(self._plugin_service,
                                                  host_info,
-                                                 LimitlessRouterService._limitless_router_cache,
+                                                 self._limitless_router_cache,
                                                  limitless_router_monitor_key,
                                                  props,
                                                  intervals_ms),
                 cache_expiration_nano)
@@ -530,4 +536,4 @@ def start_monitoring(self, host_info: HostInfo,
 
     def clear_cache(self) -> None:
         LimitlessRouterService._force_get_limitless_routers_lock_map.clear()
-        LimitlessRouterService._limitless_router_cache.clear()
+        self._limitless_router_cache.clear()
diff --git a/aws_advanced_python_wrapper/mysql_driver_dialect.py b/aws_advanced_python_wrapper/mysql_driver_dialect.py
index c2949ceb..132025f7 100644
--- a/aws_advanced_python_wrapper/mysql_driver_dialect.py
+++ b/aws_advanced_python_wrapper/mysql_driver_dialect.py
@@ -27,8 +27,6 @@
 from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes
 from aws_advanced_python_wrapper.errors import UnsupportedOperationError
 from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
-from aws_advanced_python_wrapper.thread_pool_container import \
-    ThreadPoolContainer
 from aws_advanced_python_wrapper.utils.decorators import timeout
 from aws_advanced_python_wrapper.utils.messages import Messages
 from aws_advanced_python_wrapper.utils.properties import (Properties,
@@ -98,7 +96,7 @@ def is_closed(self, conn: Connection) -> bool:
         socket_timeout = WrapperProperties.SOCKET_TIMEOUT_SEC.get_float(self._props)
         timeout_sec = socket_timeout if socket_timeout > 0 else MySQLDriverDialect.IS_CLOSED_TIMEOUT_SEC
         is_connected_with_timeout = timeout(
-            ThreadPoolContainer.get_thread_pool(MySQLDriverDialect._executor_name), timeout_sec)(conn.is_connected)  # type: ignore
+            self._thread_pool, timeout_sec)(conn.is_connected)  # type: ignore
 
         try:
             return not is_connected_with_timeout()
diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py
index 8096c8b7..bb0fb907 100644
--- a/aws_advanced_python_wrapper/plugin_service.py
+++ b/aws_advanced_python_wrapper/plugin_service.py
@@ -351,6 +351,7 @@ def __init__(
         self._driver_dialect = driver_dialect
         self._database_dialect = self._dialect_provider.get_dialect(driver_dialect.dialect_code, props)
         self._session_state_service = session_state_service if session_state_service is not None else SessionStateServiceImpl(self, props)
+        self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name)
 
     @property
     def all_hosts(self) -> Tuple[HostInfo, ...]:
@@ -631,7 +632,7 @@ def fill_aliases(self, connection: Optional[Connection] = None, host_info: Optio
         try:
             timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get(self._props)
             cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
-                ThreadPoolContainer.get_thread_pool(PluginServiceImpl._executor_name), timeout_sec, driver_dialect, connection)(self._fill_aliases)
+                self._thread_pool, timeout_sec, driver_dialect, connection)(self._fill_aliases)
             cursor_execute_func_with_timeout(connection, host_info)
         except TimeoutError as e:
             raise QueryTimeoutError(Messages.get("PluginServiceImpl.FillAliasesTimeout")) from e
diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties
index ecb5fcab..965ef1e2 100644
--- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties
+++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties
@@ -484,3 +484,6 @@
 XRayTelemetryContext.TraceID="[XRayTelemetryContext] Telemetry '{}' trace ID: {}"
 XRayTelemetryFactory.MetricsNotSupported="[XRayTelemetryFactory] XRay doesn't support metrics."
 XRayTelemetryFactory.WrongParameterType="[XRayTelemetryFactory] Wrong parameter type: {}"
+
+SlidingExpirationCacheContainer.ErrorReleasingCache="[SlidingExpirationCacheContainer] Error releasing cache '{}': {}"
+SlidingExpirationCacheContainer.ErrorDuringCleanup="[SlidingExpirationCacheContainer] Error during cleanup of cache '{}': {}"
diff --git a/aws_advanced_python_wrapper/thread_pool_container.py b/aws_advanced_python_wrapper/thread_pool_container.py
index 6c9cf905..9254dbb2 100644
--- a/aws_advanced_python_wrapper/thread_pool_container.py
+++ b/aws_advanced_python_wrapper/thread_pool_container.py
@@ -14,7 +14,7 @@
 
 import threading
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, List, Optional
+from typing import ClassVar, Dict, List, Optional
 
 from aws_advanced_python_wrapper.utils.log import Logger
 
@@ -27,9 +27,9 @@ class ThreadPoolContainer:
     Provides static methods for getting, creating, and releasing thread pools.
     """
 
-    _pools: Dict[str, ThreadPoolExecutor] = {}
-    _lock: threading.Lock = threading.Lock()
-    _default_max_workers: Optional[int] = None  # Uses Python's default
+    _pools: ClassVar[Dict[str, ThreadPoolExecutor]] = {}
+    _lock: ClassVar[threading.Lock] = threading.Lock()
+    _default_max_workers: ClassVar[Optional[int]] = None  # Uses Python's default
 
     @classmethod
     def get_thread_pool(
diff --git a/aws_advanced_python_wrapper/utils/cache_map.py b/aws_advanced_python_wrapper/utils/cache_map.py
index dada8166..5c0f0282 100644
--- a/aws_advanced_python_wrapper/utils/cache_map.py
+++ b/aws_advanced_python_wrapper/utils/cache_map.py
@@ -30,7 +30,8 @@ def __init__(self):
         self._lock = threading.RLock()
 
     def __len__(self):
-        return len(self._cache)
+        with self._lock:
+            return len(self._cache)
 
     def get(self, key: K) -> Optional[V]:
         with self._lock:
@@ -62,15 +63,18 @@ def get_with_default(self, key: K, value_if_absent: V, item_expiration_ns: int)
         return None
 
     def put(self, key: K, item: V, item_expiration_ns: int):
-        self._cache[key] = CacheItem(item, time.perf_counter_ns() + item_expiration_ns)
-        self._cleanup()
+        with self._lock:
+            self._cache[key] = CacheItem(item, time.perf_counter_ns() + item_expiration_ns)
+            self._cleanup()
 
     def remove(self, key: K):
-        self._cache.pop(key, None)
-        self._cleanup()
+        with self._lock:
+            self._cache.pop(key, None)
+            self._cleanup()
 
     def clear(self):
-        self._cache.clear()
+        with self._lock:
+            self._cache.clear()
 
     def get_dict(self) -> Dict[K, V]:
         with self._lock:
diff --git a/aws_advanced_python_wrapper/utils/concurrent.py b/aws_advanced_python_wrapper/utils/concurrent.py
index 679933a0..a209810e 100644
--- a/aws_advanced_python_wrapper/utils/concurrent.py
+++ b/aws_advanced_python_wrapper/utils/concurrent.py
@@ -28,22 +28,31 @@ def __init__(self):
         self._lock = Lock()
 
     def __len__(self):
-        return len(self._dict)
+        with self._lock:
+            return len(self._dict)
 
     def __contains__(self, key):
-        return key in self._dict
+        with self._lock:
+            return key in self._dict
 
     def __str__(self):
-        return f"ConcurrentDict{str(self._dict)}"
+        with self._lock:
+            return f"ConcurrentDict{str(self._dict)}"
 
     def __repr__(self):
-        return f"ConcurrentDict{str(self._dict)}"
+        with self._lock:
+            return f"ConcurrentDict{str(self._dict)}"
 
     def get(self, key: K, default_value: Optional[V] = None) -> Optional[V]:
-        return self._dict.get(key, default_value)
+        with self._lock:
+            return self._dict.get(key, default_value)
 
-    def clear(self):
-        self._dict.clear()
+    def clear(self, disposal_func: Optional[Callable] = None):
+        with self._lock:
+            if disposal_func is not None:
+                for key, value in self._dict.items():
+                    disposal_func(key, value)
+            self._dict.clear()
 
     def compute_if_present(self, key: K, remapping_func: Callable) -> Optional[V]:
         with self._lock:
diff --git a/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py b/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py
index 4085e43c..8033362e 100644
--- a/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py
+++ b/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py
@@ -53,19 +53,54 @@ def items(self) -> List[Tuple[K, CacheItem[V]]]:
         return self._cdict.items()
 
     def compute_if_absent(self, key: K, mapping_func: Callable, item_expiration_ns: int) -> Optional[V]:
-        self._cleanup()
+        self.cleanup()
+        cache_item = self._cdict.compute_if_absent(
+            key, lambda k: CacheItem(mapping_func(k), perf_counter_ns() + item_expiration_ns))
+        return None if cache_item is None else cache_item.update_expiration(item_expiration_ns).item
+
+    def compute_if_absent_with_disposal(self, key: K, mapping_func: Callable, item_expiration_ns: int) -> Optional[V]:
+        self.cleanup()
+        self._remove_if_disposable(key)
         cache_item = self._cdict.compute_if_absent(
             key, lambda k: CacheItem(mapping_func(k), perf_counter_ns() + item_expiration_ns))
         return None if cache_item is None else cache_item.update_expiration(item_expiration_ns).item
 
     def get(self, key: K) -> Optional[V]:
-        self._cleanup()
+        self.cleanup()
         cache_item = self._cdict.get(key)
         return cache_item.item if cache_item is not None else None
 
     def remove(self, key: K):
         self._remove_and_dispose(key)
-        self._cleanup()
+        self.cleanup()
+
+    def clear(self):
+        if self._item_disposal_func is not None:
+            # Dispose all items atomically
+            self._cdict.clear(lambda k, cache_item: self._item_disposal_func(cache_item.item))
+        else:
+            # Just clear without disposal
+            self._cdict.clear()
+
+    def cleanup(self):
+        current_time = perf_counter_ns()
+        if self._cleanup_time_ns.get() > current_time:
+            return
+
+        self._cleanup_time_ns.set(current_time + self._cleanup_interval_ns)
+        keys = self._cdict.keys()
+        for key in keys:
+            self._remove_if_expired(key)
+
+    def _remove_if_disposable(self, key: K):
+        def _remove_if_disposable_internal(_, cache_item):
+            if self._should_dispose_func is not None and self._should_dispose_func(cache_item.item):
+                if self._item_disposal_func is not None:
+                    self._item_disposal_func(cache_item.item)
+                return None
+            return cache_item
+
+        self._cdict.compute_if_present(key, _remove_if_disposable_internal)
 
     def _remove_and_dispose(self, key: K):
         cache_item = self._cdict.remove(key)
@@ -88,25 +123,6 @@ def _should_cleanup_item(self, cache_item: CacheItem) -> bool:
             return perf_counter_ns() > cache_item.expiration_time and self._should_dispose_func(cache_item.item)
         return perf_counter_ns() > cache_item.expiration_time
 
-    def clear(self):
-        # Dispose all items while holding the lock
-        if self._item_disposal_func is not None:
-            self._cdict.apply_if(
-                lambda k, v: True,  # Apply to all items
-                lambda k, cache_item: self._item_disposal_func(cache_item.item)
-            )
-        self._cdict.clear()
-
-    def _cleanup(self):
-        current_time = perf_counter_ns()
-        if self._cleanup_time_ns.get() > current_time:
-            return
-
-        self._cleanup_time_ns.set(current_time + self._cleanup_interval_ns)
-        keys = self._cdict.keys()
-        for key in keys:
-            self._remove_if_expired(key)
-
 
 class SlidingExpirationCacheWithCleanupThread(SlidingExpirationCache, Generic[K, V]):
     def __init__(
@@ -118,39 +134,16 @@ def __init__(
         self._cleanup_thread = Thread(target=self._cleanup_thread_internal, daemon=True)
         self._cleanup_thread.start()
 
-    def compute_if_absent_with_disposal(self, key: K, mapping_func: Callable, item_expiration_ns: int) -> Optional[V]:
-        self._remove_if_disposable(key)
-        cache_item = self._cdict.compute_if_absent(
-            key, lambda k: CacheItem(mapping_func(k), perf_counter_ns() + item_expiration_ns))
-        return None if cache_item is None else cache_item.update_expiration(item_expiration_ns).item
-
-    def _remove_if_disposable(self, key: K):
-        def _remove_if_disposable_internal(_, cache_item):
-            if self._should_dispose_func is not None and self._should_dispose_func(cache_item.item):
-                if self._item_disposal_func is not None:
-                    self._item_disposal_func(cache_item.item)
-                return None
-            return cache_item
-
-        self._cdict.compute_if_present(key, _remove_if_disposable_internal)
-
     def _cleanup_thread_internal(self):
         while True:
             try:
                 sleep(self._cleanup_interval_ns / 1_000_000_000)
-                self._cleanup_time_ns.set(perf_counter_ns() + self._cleanup_interval_ns)
-                keys = self._cdict.keys()
-                for key in keys:
-                    try:
-                        self._remove_if_expired(key)
-                    except Exception:
-                        pass  # ignore
+                # Force cleanup by resetting the interval timer
+                self._cleanup_time_ns.set(0)
+                self.cleanup()
             except Exception:
                 break
 
-    def _cleanup(self):
-        pass  # cleanup thread handles this
-
 
 class CacheItem(Generic[V]):
     def __init__(self, item: V, expiration_time: int):
diff --git a/aws_advanced_python_wrapper/utils/sliding_expiration_cache_container.py b/aws_advanced_python_wrapper/utils/sliding_expiration_cache_container.py
new file mode 100644
index 00000000..0f9f1e71
--- /dev/null
+++ b/aws_advanced_python_wrapper/utils/sliding_expiration_cache_container.py
@@ -0,0 +1,122 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import threading
+from threading import Event, Thread
+from typing import Callable, ClassVar, Dict, Optional
+
+from aws_advanced_python_wrapper.utils.log import Logger
+from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
+    SlidingExpirationCache
+
+logger = Logger(__name__)
+
+
+class SlidingExpirationCacheContainer:
+    """
+    A container class for managing multiple named sliding expiration caches.
+    Provides static methods for getting, creating, and releasing caches.
+
+    This container manages SlidingExpirationCache instances and provides a single
+    cleanup thread that periodically cleans up all managed caches.
+    """
+
+    _caches: ClassVar[Dict[str, SlidingExpirationCache]] = {}
+    _lock: ClassVar[threading.Lock] = threading.Lock()
+    _cleanup_thread: ClassVar[Optional[Thread]] = None
+    _cleanup_interval_ns: ClassVar[int] = 300_000_000_000  # 5 minutes default
+    _is_stopped: ClassVar[Event] = Event()
+
+    @classmethod
+    def get_or_create_cache(
+            cls,
+            name: str,
+            cleanup_interval_ns: int = 10 * 60_000_000_000,  # 10 minutes
+            should_dispose_func: Optional[Callable] = None,
+            item_disposal_func: Optional[Callable] = None
+    ) -> SlidingExpirationCache:
+        """
+        Get an existing cache, or create a new one if it doesn't exist.
+
+        The cleanup thread is started lazily when the first cache is created.
+
+        Args:
+            name: Unique identifier for the cache
+            cleanup_interval_ns: Cleanup interval in nanoseconds (only used when creating a new cache)
+            should_dispose_func: Optional function to determine if an item should be disposed (only used when creating a new cache)
+            item_disposal_func: Optional function to dispose items (only used when creating a new cache)
+
+        Returns:
+            SlidingExpirationCache instance
+        """
+        with cls._lock:
+            if name not in cls._caches:
+                cls._caches[name] = SlidingExpirationCache(
+                    cleanup_interval_ns=cleanup_interval_ns,
+                    should_dispose_func=should_dispose_func,
+                    item_disposal_func=item_disposal_func
+                )
+
+            # Start cleanup thread if not already running
+            if cls._cleanup_thread is None or not cls._cleanup_thread.is_alive():
+                cls._is_stopped.clear()
+                cls._cleanup_thread = Thread(
+                    target=cls._cleanup_thread_internal,
+                    daemon=True,
+                    name="SlidingExpirationCacheContainer-Cleanup"
+                )
+                cls._cleanup_thread.start()
+
+            return cls._caches[name]
+
+    @classmethod
+    def release_resources(cls) -> None:
+        """
+        Clear all caches and stop the cleanup thread.
+        This will dispose all cached items and release all resources.
+        """
+        with cls._lock:
+            # Stop the cleanup thread
+            cls._is_stopped.set()
+
+            # Clear all caches (will dispose items if a disposal function is set)
+            for name, cache in cls._caches.items():
+                try:
+                    cache.clear()
+                except Exception as e:
+                    logger.warning("SlidingExpirationCacheContainer.ErrorReleasingCache", name, e)
+
+            cls._caches.clear()
+
+        # Wait for the cleanup thread to stop (outside the lock)
+        if cls._cleanup_thread is not None and cls._cleanup_thread.is_alive():
+            cls._cleanup_thread.join(timeout=2.0)
+            cls._cleanup_thread = None
+
+    @classmethod
+    def _cleanup_thread_internal(cls) -> None:
+        while not cls._is_stopped.is_set():
+            # Wait for the cleanup interval or until stopped
+            if cls._is_stopped.wait(timeout=cls._cleanup_interval_ns / 1_000_000_000):
+                break
+
+            # Cleanup all caches
+            with cls._lock:
+                cache_items = list(cls._caches.items())
+
+            for name, cache in cache_items:
+                try:
+                    cache.cleanup()
+                except Exception as e:
+                    logger.debug("SlidingExpirationCacheContainer.ErrorDuringCleanup", name, e)
diff --git a/tests/integration/container/conftest.py b/tests/integration/container/conftest.py
index e3c7d2f9..808ea48a 100644
--- a/tests/integration/container/conftest.py
+++ b/tests/integration/container/conftest.py
@@ -21,14 +21,13 @@
 
 from aws_advanced_python_wrapper.connection_provider import \
     ConnectionProviderManager
-from aws_advanced_python_wrapper.custom_endpoint_plugin import (
-    CustomEndpointMonitor, CustomEndpointPlugin)
+from aws_advanced_python_wrapper.custom_endpoint_plugin import \
+    CustomEndpointMonitor
 from aws_advanced_python_wrapper.database_dialect import DatabaseDialectManager
 from aws_advanced_python_wrapper.driver_dialect_manager import \
     DriverDialectManager
 from aws_advanced_python_wrapper.exception_handling import ExceptionManager
-from aws_advanced_python_wrapper.host_list_provider import (
-    MonitoringRdsHostListProvider, RdsHostListProvider)
+from aws_advanced_python_wrapper.host_list_provider import RdsHostListProvider
 from aws_advanced_python_wrapper.host_monitoring_plugin import \
     MonitoringThreadContainer
 from aws_advanced_python_wrapper.plugin_service import PluginServiceImpl
@@ -36,6 +35,8 @@
     ThreadPoolContainer
 from aws_advanced_python_wrapper.utils.log import Logger
 from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
+from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
+    SlidingExpirationCacheContainer
 
 if TYPE_CHECKING:
     from .utils.test_driver import TestDriver
@@ -145,11 +146,10 @@ def pytest_runtest_setup(item):
     RdsHostListProvider._cluster_ids_to_update.clear()
     PluginServiceImpl._host_availability_expiring_cache.clear()
     DatabaseDialectManager._known_endpoint_dialects.clear()
-    CustomEndpointPlugin._monitors.clear()
     CustomEndpointMonitor._custom_endpoint_info_cache.clear()
     MonitoringThreadContainer.clean_up()
     ThreadPoolContainer.release_resources(wait=True)
-    MonitoringRdsHostListProvider._monitors.clear()
+    SlidingExpirationCacheContainer.release_resources()
 
     ConnectionProviderManager.release_resources()
     ConnectionProviderManager.reset_provider()
diff --git a/tests/integration/container/test_aurora_failover.py b/tests/integration/container/test_aurora_failover.py
index 8fffe062..f77f121f 100644
--- a/tests/integration/container/test_aurora_failover.py
+++ b/tests/integration/container/test_aurora_failover.py
@@ -22,8 +22,6 @@
 
 from aws_advanced_python_wrapper.errors import (
     FailoverSuccessError, TransactionResolutionUnknownError)
-from aws_advanced_python_wrapper.host_list_provider import \
-    MonitoringRdsHostListProvider
 from aws_advanced_python_wrapper.utils.properties import (Properties,
                                                           WrapperProperties)
 
 from .utils.conditions import (disable_on_features, enable_on_deployments,
@@ -61,7 +59,6 @@ def setup_method(self, request):
         yield
         # Clean up global resources created by wrapper
         release_resources()
-        MonitoringRdsHostListProvider.release_resources()
         self.logger.info(f"Ending test: {request.node.name}")
         release_resources()
         gc.collect()
diff --git a/tests/unit/test_limitless_router_service.py b/tests/unit/test_limitless_router_service.py
index 53ba45ab..9838565c 100644
--- a/tests/unit/test_limitless_router_service.py
+++ b/tests/unit/test_limitless_router_service.py
@@ -20,11 +20,18 @@
 from aws_advanced_python_wrapper.utils.messages import Messages
 from aws_advanced_python_wrapper.utils.properties import (Properties,
                                                           WrapperProperties)
+from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
+    SlidingExpirationCacheContainer
 
 CLUSTER_ID: str = "some_cluster_id"
 EXPIRATION_NANO_SECONDS: int = 60 * 60 * 1_000_000_000
 
 
+def get_router_cache():
+    """Helper to get the limitless router cache from the container."""
+    return SlidingExpirationCacheContainer.get_or_create_cache("limitless_router_cache")
+
+
 @pytest.fixture
 def writer_host():
     return HostInfo("instance-0", 5432, HostRole.WRITER, HostAvailability.AVAILABLE)
@@ -136,8 +143,8 @@ def run_before_and_after_tests(mock_limitless_router_service):
     yield
 
     # After
-
-    LimitlessRouterService._limitless_router_cache.clear()
+    # Clear the cache through the container
+    get_router_cache().clear()
 
 
 def test_establish_connection_empty_routers_list_then_wait_for_router_info_then_raises_exception(mocker,
@@ -202,8 +209,8 @@ def test_establish_connection_host_info_in_router_cache_then_call_connection_fun
                                                                                  props,
                                                                                  mock_plugin_service,
                                                                                  limitless_routers):
-    LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
-                                                                     EXPIRATION_NANO_SECONDS)
+    get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
+                                         EXPIRATION_NANO_SECONDS)
 
     mock_connect_func = mocker.MagicMock()
     mock_connect_func.return_value = mock_conn
@@ -251,7 +258,7 @@ def test_establish_connection_fetch_router_list_and_host_info_in_router_list_the
     limitless_router_service.establish_connection(input_context)
 
     assert mock_conn == input_context.get_connection()
-    assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID)
+    assert limitless_routers == get_router_cache().get(CLUSTER_ID)
     mock_limitless_query_helper.query_for_limitless_routers.assert_called_once()
     mock_connect_func.assert_called_once()
@@ -265,8 +272,8 @@ def test_establish_connection_router_cache_then_select_host(mocker,
                                                             plugin,
                                                             limitless_router1,
                                                             limitless_routers):
-    LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
-                                                                     EXPIRATION_NANO_SECONDS)
+    get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
+                                         EXPIRATION_NANO_SECONDS)
 
     mock_plugin_service.get_host_info_by_strategy.return_value = limitless_router1
     mock_plugin_service.connect.return_value = mock_conn
@@ -287,7 +294,7 @@ def test_establish_connection_router_cache_then_select_host(mocker,
     limitless_router_service.establish_connection(input_context)
 
     assert mock_conn == input_context.get_connection()
-    assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID)
+    assert limitless_routers == get_router_cache().get(CLUSTER_ID)
     mock_plugin_service.get_host_info_by_strategy.assert_called_once()
     mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "weighted_random",
                                                                      limitless_routers)
@@ -326,7 +333,7 @@ def test_establish_connection_fetch_router_list_then_select_host(mocker,
     limitless_router_service.establish_connection(input_context)
 
     assert mock_conn == input_context.get_connection()
-    assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID)
+    assert limitless_routers == get_router_cache().get(CLUSTER_ID)
     mock_limitless_query_helper.query_for_limitless_routers.assert_called_once()
     mock_plugin_service.get_host_info_by_strategy.assert_called_once()
     mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "weighted_random",
@@ -345,8 +352,8 @@ def test_establish_connection_host_info_in_router_cache_can_call_connection_func
                                                                                  plugin,
                                                                                  limitless_router1,
                                                                                  limitless_routers):
-    LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
-                                                                     EXPIRATION_NANO_SECONDS)
+    get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
+                                         EXPIRATION_NANO_SECONDS)
 
     mock_plugin_service.get_host_info_by_strategy.return_value = limitless_router1
     mock_plugin_service.connect.return_value = mock_conn
@@ -367,7 +374,7 @@ def test_establish_connection_host_info_in_router_cache_can_call_connection_func
     limitless_router_service.establish_connection(input_context)
 
     assert mock_conn == input_context.get_connection()
-    assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID)
+    assert limitless_routers == get_router_cache().get(CLUSTER_ID)
     mock_plugin_service.get_host_info_by_strategy.assert_called_once()
     mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight",
                                                                      limitless_routers)
@@ -385,8 +392,8 @@ def test_establish_connection_selected_host_raises_exception_and_retries(mocker,
                                                                          plugin,
                                                                          limitless_router1,
                                                                          limitless_routers):
-    LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
-                                                                     EXPIRATION_NANO_SECONDS)
+    get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
+                                         EXPIRATION_NANO_SECONDS)
     mock_plugin_service.get_host_info_by_strategy.side_effect = [
         Exception(),
         limitless_router1
     ]
@@ -410,7 +417,7 @@ def test_establish_connection_selected_host_raises_exception_and_retries(mocker,
     limitless_router_service.establish_connection(input_context)
 
     assert mock_conn == input_context.get_connection()
-    assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID)
+    assert limitless_routers == get_router_cache().get(CLUSTER_ID)
     assert mock_plugin_service.get_host_info_by_strategy.call_count == 2
     mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight",
                                                                      limitless_routers)
@@ -429,8 +436,8 @@ def test_establish_connection_selected_host_none_then_retry(mocker,
                                                             plugin,
                                                             limitless_router1,
                                                             limitless_routers):
-    LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
-                                                                     EXPIRATION_NANO_SECONDS)
+    get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
+                                         EXPIRATION_NANO_SECONDS)
     mock_plugin_service.get_host_info_by_strategy.side_effect = [
         None,
         limitless_router1
     ]
@@ -454,7 +461,7 @@ def test_establish_connection_selected_host_none_then_retry(mocker,
     limitless_router_service.establish_connection(input_context)
 
     assert mock_conn == input_context.get_connection()
-    assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID)
+    assert limitless_routers == get_router_cache().get(CLUSTER_ID)
    assert mock_plugin_service.get_host_info_by_strategy.call_count == 2
     mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight",
                                                                      limitless_routers)
@@ -474,8 +481,8 @@ def test_establish_connection_plugin_service_connect_raises_exception_then_retry
                                                                                  limitless_router1,
                                                                                  limitless_router2,
                                                                                  limitless_routers):
-    LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
-                                                                     EXPIRATION_NANO_SECONDS)
+    get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
+                                         EXPIRATION_NANO_SECONDS)
     mock_plugin_service.get_host_info_by_strategy.side_effect = [
         limitless_router1,
         limitless_router2
     ]
@@ -502,7 +509,7 @@ def test_establish_connection_plugin_service_connect_raises_exception_then_retry
     limitless_router_service.establish_connection(input_context)
 
     assert mock_conn == input_context.get_connection()
-    assert limitless_routers == LimitlessRouterService._limitless_router_cache.get(CLUSTER_ID)
+    assert limitless_routers == get_router_cache().get(CLUSTER_ID)
     assert mock_plugin_service.get_host_info_by_strategy.call_count == 2
     mock_plugin_service.get_host_info_by_strategy.assert_called_with(HostRole.WRITER, "highest_weight",
                                                                      limitless_routers)
@@ -521,8 +528,8 @@ def test_establish_connection_retry_and_max_retries_exceeded_then_raise_exceptio
                                                                                  plugin,
                                                                                  limitless_router1,
                                                                                  limitless_routers):
-    LimitlessRouterService._limitless_router_cache.compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
-                                                                     EXPIRATION_NANO_SECONDS)
+    get_router_cache().compute_if_absent(CLUSTER_ID, lambda _: limitless_routers,
+                                         EXPIRATION_NANO_SECONDS)
     mock_plugin_service.get_host_info_by_strategy.return_value = limitless_router1
     mock_plugin_service.connect.side_effect = Exception()
diff --git a/tests/unit/test_monitor_service.py b/tests/unit/test_monitor_service.py
index 6ce28f2a..b0620966 100644
--- a/tests/unit/test_monitor_service.py
+++ b/tests/unit/test_monitor_service.py
@@ -96,9 +96,9 @@ def test_start_monitoring(
 def test_start_monitoring__multiple_calls(monitor_service_with_container, mock_monitor, mock_conn, mocker):
     aliases = frozenset({"instance-1"})
 
-    # Mock ThreadPoolContainer.get_thread_pool
+    # Mock the _thread_pool directly on the container instance since it's now cached in __init__
     mock_thread_pool = mocker.MagicMock()
-    mocker.patch('aws_advanced_python_wrapper.host_monitoring_plugin.ThreadPoolContainer.get_thread_pool', return_value=mock_thread_pool)
+    monitor_service_with_container._monitor_container._thread_pool = mock_thread_pool
 
     num_calls = 5
     for _ in range(num_calls):
diff --git a/tests/unit/test_monitoring_thread_container.py b/tests/unit/test_monitoring_thread_container.py
index 4e3dac25..1a9469d1 100644
--- a/tests/unit/test_monitoring_thread_container.py
+++ b/tests/unit/test_monitoring_thread_container.py
@@ -69,7 +69,8 @@ def test_get_or_create_monitor__monitor_created(
         container, mock_monitor_supplier, mock_stopped_monitor, mock_monitor1, mock_future, mocker):
     mock_thread_pool = mocker.MagicMock()
     mock_thread_pool.submit.return_value = mock_future
-    mocker.patch('aws_advanced_python_wrapper.host_monitoring_plugin.ThreadPoolContainer.get_thread_pool', return_value=mock_thread_pool)
+    # Mock the _thread_pool directly on the container instance since it's now cached in __init__
+    container._thread_pool = mock_thread_pool
 
     result = container.get_or_create_monitor(frozenset({"alias-1", "alias-2"}), mock_monitor_supplier)
     assert mock_monitor1 == result
diff --git a/tests/unit/test_sliding_expiration_cache_container.py b/tests/unit/test_sliding_expiration_cache_container.py
new file mode 100644
index 00000000..c21c5482
--- /dev/null
+++ b/tests/unit/test_sliding_expiration_cache_container.py
@@ -0,0 +1,231 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+import pytest
+
+from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
+    SlidingExpirationCache
+from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
+    SlidingExpirationCacheContainer
+
+
+@pytest.fixture(autouse=True)
+def cleanup_caches():
+    """Clean up all caches after each test"""
+    yield
+    SlidingExpirationCacheContainer.release_resources()
+
+
+def test_get_or_create_cache_creates_new_cache():
+    cache = SlidingExpirationCacheContainer.get_or_create_cache("test_cache")
+    assert isinstance(cache, SlidingExpirationCache)
+
+
+def test_get_or_create_cache_returns_existing_cache():
+    cache1 = SlidingExpirationCacheContainer.get_or_create_cache("test_cache")
+    cache2 = SlidingExpirationCacheContainer.get_or_create_cache("test_cache")
+    assert cache1 is cache2
+
+
+def test_get_or_create_cache_with_custom_cleanup_interval():
+    cache = SlidingExpirationCacheContainer.get_or_create_cache(
+        "test_cache",
+        cleanup_interval_ns=5_000_000_000  # 5 seconds
+    )
+    assert cache._cleanup_interval_ns == 5_000_000_000
+
+
+def test_get_or_create_cache_with_disposal_functions():
+    disposed_items = []
+
+    def should_dispose(item):
+        return item > 10
+
+    def dispose(item):
+        disposed_items.append(item)
+
+    cache = SlidingExpirationCacheContainer.get_or_create_cache(
+        "test_cache",
+        should_dispose_func=should_dispose,
+        item_disposal_func=dispose
+    )
+
+    assert cache._should_dispose_func is should_dispose
+    assert cache._item_disposal_func is dispose
+
+
+def test_multiple_caches_are_independent():
+    cache1 = SlidingExpirationCacheContainer.get_or_create_cache("cache1")
+    cache2 = SlidingExpirationCacheContainer.get_or_create_cache("cache2")
+
+    cache1.compute_if_absent("key1", lambda k: "value1", 1_000_000_000)
+    cache2.compute_if_absent("key2", lambda k: "value2", 1_000_000_000)
+
+    assert cache1.get("key1") == "value1"
+    assert cache1.get("key2") is None
+    assert cache2.get("key2") == "value2"
+    assert cache2.get("key1") is None
+
+
+def test_cleanup_thread_starts_on_first_cache():
+    # Cleanup thread should start when the first cache is created
+    SlidingExpirationCacheContainer.get_or_create_cache("test_cache")
+
+    # Check that the cleanup thread is running
+    assert SlidingExpirationCacheContainer._cleanup_thread is not None
+    assert SlidingExpirationCacheContainer._cleanup_thread.is_alive()
+
+
+def test_release_resources_clears_all_caches():
+    cache1 = SlidingExpirationCacheContainer.get_or_create_cache("cache1")
+    cache2 = SlidingExpirationCacheContainer.get_or_create_cache("cache2")
+
+    cache1.compute_if_absent("key1", lambda k: "value1", 1_000_000_000)
+    cache2.compute_if_absent("key2", lambda k: "value2", 1_000_000_000)
+
+    SlidingExpirationCacheContainer.release_resources()
+
+    # Caches should be cleared
+    assert len(SlidingExpirationCacheContainer._caches) == 0
+
+
+def test_release_resources_stops_cleanup_thread():
+    SlidingExpirationCacheContainer.get_or_create_cache("test_cache")
+
+    cleanup_thread = SlidingExpirationCacheContainer._cleanup_thread
+    assert cleanup_thread is not None
+    assert cleanup_thread.is_alive()
+
+    SlidingExpirationCacheContainer.release_resources()
+
+    # Give the thread time to stop
+    time.sleep(0.1)
+
+    # Thread should be stopped
+    assert not cleanup_thread.is_alive()
+
+
+def test_release_resources_disposes_items():
+    disposed_items = []
+
+    def dispose(item):
+        disposed_items.append(item)
+
+    cache = SlidingExpirationCacheContainer.get_or_create_cache(
+        "test_cache",
+        item_disposal_func=dispose
+    )
+
+    cache.compute_if_absent("key1", lambda k: "value1", 1_000_000_000)
+    cache.compute_if_absent("key2", lambda k: "value2", 1_000_000_000)
+
+    SlidingExpirationCacheContainer.release_resources()
+
+    # Items should have been disposed
+    assert "value1" in disposed_items
+    assert "value2" in disposed_items
+
+
+def test_cleanup_thread_cleans_expired_items():
+    # Use very short intervals for testing
+    cache = SlidingExpirationCacheContainer.get_or_create_cache(
+        "test_cache",
+        cleanup_interval_ns=100_000_000  # 0.1 seconds
+    )
+
+    # Add an item with a very short expiration
+    cache.compute_if_absent("key1", lambda k: "value1", 50_000_000)  # 0.05 seconds
+
+    assert cache.get("key1") == "value1"
+
+    # Wait for the item to expire. Note that the container's shared cleanup
+    # thread only wakes every 5 minutes; within this test the expired entry is
+    # removed by the cache's inline cleanup, which get() triggers once the
+    # 0.1-second cleanup interval has elapsed.
+    time.sleep(0.3)
+
+    # Item should be cleaned up
+    assert cache.get("key1") is None
+
+
+def test_same_cache_name_returns_same_instance_across_calls():
+    cache1 = SlidingExpirationCacheContainer.get_or_create_cache("shared_cache")
+    cache1.compute_if_absent("key1", lambda k: "value1", 1_000_000_000)
+
+    # Get the same cache again
+    cache2 = SlidingExpirationCacheContainer.get_or_create_cache("shared_cache")
+
+    # Should be the same instance with the same data
+    assert cache1 is cache2
+    assert cache2.get("key1") == "value1"
+
+
+def test_cleanup_thread_handles_multiple_caches():
+    cache1 = SlidingExpirationCacheContainer.get_or_create_cache(
+        "cache1",
+        cleanup_interval_ns=100_000_000  # 0.1 seconds
+    )
+    cache2 = SlidingExpirationCacheContainer.get_or_create_cache(
+        "cache2",
+        cleanup_interval_ns=100_000_000  # 0.1 seconds
+    )
+
+    # Add items with short expiration
+    cache1.compute_if_absent("key1", lambda k: "value1", 50_000_000)
+    cache2.compute_if_absent("key2", lambda k: "value2", 50_000_000)
+
+    assert cache1.get("key1") == "value1"
+    assert cache2.get("key2") == "value2"
+
+    # Wait for expiration; as above, removal happens via each cache's inline
+    # cleanup on get(), not the 5-minute container thread
+    time.sleep(0.3)
+
+    # Both should be cleaned up
+    assert cache1.get("key1") is None
+    assert cache2.get("key2") is None
+
+
+def test_release_resources_handles_disposal_errors():
+    def failing_dispose(item):
+        raise Exception("Disposal failed")
+
+    cache = SlidingExpirationCacheContainer.get_or_create_cache(
+        "test_cache",
+        item_disposal_func=failing_dispose
+    )
+
+    cache.compute_if_absent("key1", lambda k: "value1", 1_000_000_000)
+
+    # Should not raise an exception even if disposal fails
+    SlidingExpirationCacheContainer.release_resources()
+
+    # Cache should still be cleared
+    assert len(SlidingExpirationCacheContainer._caches) == 0
+
+
+def test_cleanup_thread_respects_is_stopped_event():
+    # Clear the stop event first in case it was set by a previous test
+    SlidingExpirationCacheContainer._is_stopped.clear()
+
+    SlidingExpirationCacheContainer.get_or_create_cache("test_cache")
+
+    cleanup_thread = SlidingExpirationCacheContainer._cleanup_thread
+    assert cleanup_thread is not None
+    assert cleanup_thread.is_alive()
+
+    # Set the stop event
+    SlidingExpirationCacheContainer._is_stopped.set()
+
+    # Thread should stop quickly (not wait for the full cleanup interval)
+    cleanup_thread.join(timeout=1.0)
+    assert not cleanup_thread.is_alive()
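
Reviewer note — usage sketch, not part of the diff: a minimal example of the SlidingExpirationCacheContainer API introduced above, assuming only the import path added in this PR. The cache name "example_cache" and the _DemoItem class are hypothetical, for illustration.

from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
    SlidingExpirationCacheContainer


class _DemoItem:
    """Hypothetical cached item exposing the close() hook used by the disposal function."""
    def close(self):
        print("disposed")


# Fetch (or lazily create) a named cache. Repeated calls with the same name
# return the same SlidingExpirationCache instance, which is how the plugins
# and tests above now share state. The first call also starts the container's
# single background cleanup thread.
cache = SlidingExpirationCacheContainer.get_or_create_cache(
    name="example_cache",                # hypothetical name
    cleanup_interval_ns=60_000_000_000,  # throttle for the cache's inline cleanup: 1 minute
    should_dispose_func=lambda item: True,
    item_disposal_func=lambda item: item.close(),
)

# compute_if_absent creates the entry on a miss and refreshes the expiration
# of an existing entry (the "sliding" behavior implemented by CacheItem).
item = cache.compute_if_absent("key", lambda k: _DemoItem(), 600_000_000_000)

# On shutdown, release_resources() disposes every entry in every named cache
# and stops the shared cleanup thread; cleanup.py now calls this in release_resources().
SlidingExpirationCacheContainer.release_resources()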