diff --git a/app/ioc.py b/app/ioc.py index 2c733e6a8..27c1e3a2a 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -65,6 +65,7 @@ from ldap_protocol.kerberos.service import KerberosService from ldap_protocol.kerberos.template_render import KRBTemplateRenderer from ldap_protocol.ldap_requests.contexts import ( + AsyncSessionSearchRequest, LDAPAddRequestContext, LDAPBindRequestContext, LDAPDeleteRequestContext, @@ -205,7 +206,7 @@ async def get_kadmin_http( yield KadminHTTPClient(client) @provide(scope=Scope.REQUEST) - async def get_kadmin( + def get_kadmin( self, client: KadminHTTPClient, kadmin_class: type[AbstractKadmin], @@ -260,14 +261,14 @@ async def get_dns_http_client( yield DNSManagerHTTPClient(client) @provide(scope=Scope.REQUEST) - async def get_dns_mngr( + def get_dns_mngr( self, settings: DNSManagerSettings, dns_manager_class: type[AbstractDNSManager], http_client: DNSManagerHTTPClient, - ) -> AsyncIterator[AbstractDNSManager]: + ) -> AbstractDNSManager: """Get DNSManager class.""" - yield dns_manager_class(settings=settings, http_client=http_client) + return dns_manager_class(settings=settings, http_client=http_client) @provide(scope=Scope.APP) async def get_redis_for_sessions( @@ -284,7 +285,7 @@ async def get_redis_for_sessions( await client.aclose() @provide(scope=Scope.APP) - async def get_session_storage( + def get_session_storage( self, client: SessionStorageClient, settings: Settings, @@ -297,7 +298,7 @@ async def get_session_storage( ) @provide() - async def get_normalized_audit_event( + def get_normalized_audit_event( self, ) -> type[NormalizedAuditEvent]: """Get normalized audit event class.""" @@ -318,13 +319,13 @@ async def get_audit_redis_client( await client.aclose() @provide(scope=Scope.APP) - async def get_raw_audit_manager( + def get_raw_audit_manager( self, client: AuditRedisClient, settings: Settings, - ) -> AsyncIterator[RawAuditManager]: + ) -> RawAuditManager: """Get raw audit manager.""" - yield RawAuditManager( + return RawAuditManager( client, settings.RAW_EVENT_STREAM_NAME, settings.EVENT_HANDLER_GROUP, @@ -333,13 +334,13 @@ async def get_raw_audit_manager( ) @provide(scope=Scope.APP) - async def get_normalized_audit_manager( + def get_normalized_audit_manager( self, client: AuditRedisClient, settings: Settings, - ) -> AsyncIterator[NormalizedAuditManager]: + ) -> NormalizedAuditManager: """Get raw audit manager.""" - yield NormalizedAuditManager( + return NormalizedAuditManager( client, settings.NORMALIZED_EVENT_STREAM_NAME, settings.EVENT_SENDER_GROUP, @@ -352,7 +353,7 @@ async def get_normalized_audit_manager( audit_destination_dao = provide(AuditDestinationDAO, scope=Scope.REQUEST) @provide(scope=Scope.REQUEST) - async def get_dhcp_manager_repository( + def get_dhcp_manager_repository( self, session: AsyncSession, ) -> DHCPManagerRepository: @@ -368,20 +369,20 @@ async def get_dhcp_manager_state( return await dhcp_manager_repository.ensure_state() @provide(scope=Scope.REQUEST) - async def get_dhcp_mngr_class( + def get_dhcp_mngr_class( self, dhcp_state: DHCPManagerState, ) -> type[AbstractDHCPManager]: """Get DHCP manager type.""" - return await get_dhcp_manager_class(dhcp_state) + return get_dhcp_manager_class(dhcp_state) @provide(scope=Scope.REQUEST) - async def get_dhcp_api_repository_class( + def get_dhcp_api_repository_class( self, dhcp_state: DHCPManagerState, ) -> type[DHCPAPIRepository]: """Get DHCP API repository type.""" - return await get_dhcp_api_repository_class(dhcp_state) + return get_dhcp_api_repository_class(dhcp_state) @provide(scope=Scope.APP) async def get_dhcp_http_client( @@ -395,7 +396,7 @@ async def get_dhcp_http_client( yield DHCPManagerHTTPClient(http_client) @provide(scope=Scope.REQUEST) - async def get_dhcp_api_repository( + def get_dhcp_api_repository( self, http_client: DHCPManagerHTTPClient, dhcp_api_repository_class: type[DHCPAPIRepository], @@ -404,7 +405,7 @@ async def get_dhcp_api_repository( return dhcp_api_repository_class(http_client) @provide(scope=Scope.REQUEST) - async def get_dhcp_mngr( + def get_dhcp_mngr( self, dhcp_manager_class: type[AbstractDHCPManager], dhcp_api_repository: DHCPAPIRepository, @@ -445,7 +446,7 @@ async def get_dhcp_mngr( ) password_utils = provide(PasswordUtils, scope=Scope.RUNTIME) - access_manager = provide(AccessManager, scope=Scope.REQUEST) + access_manager = provide(AccessManager, scope=Scope.RUNTIME) role_dao = provide(RoleDAO, scope=Scope.REQUEST) ace_dao = provide(AccessControlEntryDAO, scope=Scope.REQUEST) role_use_case = provide(RoleUseCase, scope=Scope.REQUEST) @@ -490,15 +491,37 @@ class LDAPContextProvider(Provider): LDAPModifyDNRequestContext, scope=Scope.REQUEST, ) - search_request_context = provide( - LDAPSearchRequestContext, - scope=Scope.REQUEST, - ) unbind_request_context = provide( LDAPUnbindRequestContext, scope=Scope.REQUEST, ) + @provide(scope=Scope.SESSION) + async def create_search_session( + self, + async_session: async_sessionmaker[AsyncSession], + ) -> AsyncIterator[AsyncSessionSearchRequest]: + """Create session for request.""" + async with async_session() as session: + yield session # type: ignore + + @provide(scope=Scope.SESSION, provides=LDAPSearchRequestContext) + def get_search_request_context( + self, + session: AsyncSessionSearchRequest, + ldap_session: LDAPSession, + settings: Settings, + access_manager: AccessManager, + ) -> LDAPSearchRequestContext: + """Get search request context.""" + return LDAPSearchRequestContext( + session=session, + ldap_session=ldap_session, + settings=settings, + access_manager=access_manager, + rootdse_rd=RootDSEReader(settings, SADomainGateway(session)), + ) + class HTTPProvider(LDAPContextProvider): """HTTP LDAP session.""" @@ -508,7 +531,7 @@ class HTTPProvider(LDAPContextProvider): monitor_use_case = provide(AuditMonitorUseCase, scope=Scope.REQUEST) @provide() - async def get_audit_monitor( + def get_audit_monitor( self, session: AsyncSession, audit_use_case: "AuditUseCase", @@ -568,7 +591,7 @@ def get_permissions_provider( return auth_provider @provide() - async def get_identity_provider( + def get_identity_provider( self, request: Request, session_storage: SessionStorage, @@ -649,6 +672,23 @@ def get_krb_template_render( ) network_policy_gateway = provide(NetworkPolicyGateway, scope=Scope.REQUEST) + @provide(scope=Scope.REQUEST, provides=LDAPSearchRequestContext) + async def get_search_request_context( + self, + session: AsyncSession, + ldap_session: LDAPSession, + settings: Settings, + access_manager: AccessManager, + ) -> LDAPSearchRequestContext: + """Get search request context.""" + return LDAPSearchRequestContext( + session=session, # type: ignore + ldap_session=ldap_session, + settings=settings, + access_manager=access_manager, + rootdse_rd=RootDSEReader(settings, SADomainGateway(session)), + ) + class LDAPServerProvider(LDAPContextProvider): """Provider with session scope.""" @@ -739,7 +779,7 @@ async def get_client( yield MFAHTTPClient(client) @provide(provides=MultifactorAPI) - async def get_http_mfa( + def get_http_mfa( self, credentials: MFA_HTTP_Creds, client: MFAHTTPClient, @@ -761,7 +801,7 @@ async def get_http_mfa( ) @provide(provides=LDAPMultiFactorAPI) - async def get_ldap_mfa( + def get_ldap_mfa( self, credentials: MFA_LDAP_Creds, client: MFAHTTPClient, diff --git a/app/ldap_protocol/dhcp/__init__.py b/app/ldap_protocol/dhcp/__init__.py index 27df7d0c0..cf26f1903 100644 --- a/app/ldap_protocol/dhcp/__init__.py +++ b/app/ldap_protocol/dhcp/__init__.py @@ -26,7 +26,7 @@ from .stub import StubDHCPAPIRepository, StubDHCPManager -async def get_dhcp_manager_class( +def get_dhcp_manager_class( dhcp_state: DHCPManagerState, ) -> type[AbstractDHCPManager]: """Get an instance of the DHCP manager.""" @@ -35,7 +35,7 @@ async def get_dhcp_manager_class( return StubDHCPManager -async def get_dhcp_api_repository_class( +def get_dhcp_api_repository_class( dhcp_state: DHCPManagerState, ) -> type[DHCPAPIRepository]: """Get an instance of the DHCP API repository.""" diff --git a/app/ldap_protocol/ldap_requests/abandon.py b/app/ldap_protocol/ldap_requests/abandon.py index 3facb0562..fd7c6a601 100644 --- a/app/ldap_protocol/ldap_requests/abandon.py +++ b/app/ldap_protocol/ldap_requests/abandon.py @@ -27,7 +27,7 @@ def from_data( """Create structure from ASN1Row dataclass list.""" return cls(message_id=1) - async def handle(self) -> AsyncGenerator: + async def handle(self, ctx: None) -> AsyncGenerator: # noqa: ARG002 """Handle message with current user.""" await asyncio.sleep(0) return diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 4e15bc75d..6940c99c6 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -65,6 +65,7 @@ class AddRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.ADD + CONTEXT_TYPE: ClassVar[type] = LDAPAddRequestContext entry: str = Field(..., description="Any `DistinguishedName`") attributes: list[PartialAttribute] diff --git a/app/ldap_protocol/ldap_requests/base.py b/app/ldap_protocol/ldap_requests/base.py index 3123e6247..b4a480aa0 100644 --- a/app/ldap_protocol/ldap_requests/base.py +++ b/app/ldap_protocol/ldap_requests/base.py @@ -24,6 +24,7 @@ from ldap_protocol.dependency import resolve_deps from ldap_protocol.dialogue import LDAPSession from ldap_protocol.ldap_responses import BaseResponse, LDAPResult +from ldap_protocol.objects import ProtocolRequests from ldap_protocol.policies.audit.audit_use_case import AuditUseCase from ldap_protocol.policies.audit.events.factory import ( RawAuditEventBuilderRedis, @@ -62,6 +63,7 @@ class _APIProtocol: ... class BaseRequest(ABC, _APIProtocol, BaseModel): """Base request builder.""" + CONTEXT_TYPE: ClassVar[type] handle: ClassVar[handler] from_data: ClassVar[serializer] __event_data: dict = {} @@ -113,38 +115,42 @@ async def handle_tcp( container: AsyncContainer, ) -> AsyncIterator[BaseResponse]: """Hanlde response with tcp.""" - kwargs = await resolve_deps(func=self.handle, container=container) - responses = [] + if self.PROTOCOL_OP != ProtocolRequests.ABANDON: + ctx = await container.get(self.CONTEXT_TYPE) # type: ignore + else: + ctx = None - async for response in self.handle(**kwargs): + responses = [] + async for response in self.handle(ctx=ctx): responses.append(response) yield response - ldap_session = await container.get(LDAPSession) - settings = await container.get(Settings) - audit_use_case = await container.get(AuditUseCase) - - if await audit_use_case.check_event_processing_enabled( - self.PROTOCOL_OP, - ): - username = getattr( - ldap_session.user, - "user_principal_name", - "ANONYMOUS", - ) - event = RawAuditEventBuilderRedis.from_ldap_request( - self, - responses=responses, - username=username, - ip=ldap_session.ip, - protocol="TCP_LDAP", - settings=settings, - context=self.get_event_data(), - ) + if self.PROTOCOL_OP != ProtocolRequests.SEARCH: + ldap_session = await container.get(LDAPSession) + settings = await container.get(Settings) + audit_use_case = await container.get(AuditUseCase) + + if await audit_use_case.check_event_processing_enabled( + self.PROTOCOL_OP, + ): + username = getattr( + ldap_session.user, + "user_principal_name", + "ANONYMOUS", + ) + event = RawAuditEventBuilderRedis.from_ldap_request( + self, + responses=responses, + username=username, + ip=ldap_session.ip, + protocol="TCP_LDAP", + settings=settings, + context=self.get_event_data(), + ) - ldap_session.event_task_group.create_task( - audit_use_case.manager.send_event(event), - ) + ldap_session.event_task_group.create_task( + audit_use_case.manager.send_event(event), + ) async def _handle_api( self, @@ -156,7 +162,11 @@ async def _handle_api( :param AsyncSession session: db session :return list[BaseResponse]: list of handled responses """ - kwargs = await resolve_deps(func=self.handle, container=container) + if self.PROTOCOL_OP != ProtocolRequests.ABANDON: + ctx = await container.get(self.CONTEXT_TYPE) # type: ignore + else: + ctx = None + ldap_session = await container.get(LDAPSession) settings = await container.get(Settings) audit_use_case = await container.get(AuditUseCase) @@ -168,7 +178,7 @@ async def _handle_api( else: log_api.info(f"{get_class_name(self)}[{un}]") - responses = [response async for response in self.handle(**kwargs)] + responses = [response async for response in self.handle(ctx=ctx)] if settings.DEBUG: for response in responses: diff --git a/app/ldap_protocol/ldap_requests/bind.py b/app/ldap_protocol/ldap_requests/bind.py index d64294e72..1637da5c9 100644 --- a/app/ldap_protocol/ldap_requests/bind.py +++ b/app/ldap_protocol/ldap_requests/bind.py @@ -49,6 +49,7 @@ class BindRequest(BaseRequest): """Bind request fields mapping.""" PROTOCOL_OP: ClassVar[int] = ProtocolRequests.BIND + CONTEXT_TYPE: ClassVar[type] = LDAPBindRequestContext version: int name: str @@ -240,6 +241,7 @@ class UnbindRequest(BaseRequest): """Remove user from ldap_session.""" PROTOCOL_OP: ClassVar[int] = ProtocolRequests.UNBIND + CONTEXT_TYPE: ClassVar[type] = LDAPUnbindRequestContext @classmethod def from_data( diff --git a/app/ldap_protocol/ldap_requests/contexts.py b/app/ldap_protocol/ldap_requests/contexts.py index df4918f8d..041657c9a 100644 --- a/app/ldap_protocol/ldap_requests/contexts.py +++ b/app/ldap_protocol/ldap_requests/contexts.py @@ -5,6 +5,7 @@ """ from dataclasses import dataclass +from typing import NewType from sqlalchemy.ext.asyncio import AsyncSession @@ -20,6 +21,8 @@ from ldap_protocol.session_storage import SessionStorage from password_utils import PasswordUtils +AsyncSessionSearchRequest = NewType("AsyncSessionSearchRequest", AsyncSession) + @dataclass class LDAPAddRequestContext: @@ -67,7 +70,7 @@ class LDAPBindRequestContext: class LDAPSearchRequestContext: """Context for LDAP search request.""" - session: AsyncSession + session: AsyncSessionSearchRequest ldap_session: LDAPSession settings: Settings access_manager: AccessManager diff --git a/app/ldap_protocol/ldap_requests/delete.py b/app/ldap_protocol/ldap_requests/delete.py index 5c731e9b7..334df621a 100644 --- a/app/ldap_protocol/ldap_requests/delete.py +++ b/app/ldap_protocol/ldap_requests/delete.py @@ -43,6 +43,7 @@ class DeleteRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.DELETE + CONTEXT_TYPE: ClassVar[type] = LDAPDeleteRequestContext entry: str diff --git a/app/ldap_protocol/ldap_requests/extended.py b/app/ldap_protocol/ldap_requests/extended.py index 85ca1f31b..c3967889e 100644 --- a/app/ldap_protocol/ldap_requests/extended.py +++ b/app/ldap_protocol/ldap_requests/extended.py @@ -308,6 +308,7 @@ class ExtendedRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.EXTENDED + CONTEXT_TYPE: ClassVar[type] = LDAPExtendedRequestContext request_name: LDAPOID request_value: SerializeAsAny[BaseExtendedValue] diff --git a/app/ldap_protocol/ldap_requests/modify.py b/app/ldap_protocol/ldap_requests/modify.py index 5161ae754..1603b3740 100644 --- a/app/ldap_protocol/ldap_requests/modify.py +++ b/app/ldap_protocol/ldap_requests/modify.py @@ -103,6 +103,7 @@ class ModifyRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.MODIFY + CONTEXT_TYPE: ClassVar[type] = LDAPModifyRequestContext object: str changes: list[Changes] diff --git a/app/ldap_protocol/ldap_requests/modify_dn.py b/app/ldap_protocol/ldap_requests/modify_dn.py index d17120540..b251e160d 100644 --- a/app/ldap_protocol/ldap_requests/modify_dn.py +++ b/app/ldap_protocol/ldap_requests/modify_dn.py @@ -68,6 +68,7 @@ class ModifyDNRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.MODIFY_DN + CONTEXT_TYPE: ClassVar[type] = LDAPModifyDNRequestContext entry: str newrdn: str diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index 01ec77169..1f9579dc2 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -14,7 +14,12 @@ from pydantic import Field, PrivateAttr, field_serializer from sqlalchemy import func, or_, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload, selectinload, with_loader_criteria +from sqlalchemy.orm import ( + contains_eager, + joinedload, + selectinload, + with_loader_criteria, +) from sqlalchemy.sql.elements import ColumnElement, UnaryExpression from sqlalchemy.sql.expression import Select @@ -100,6 +105,7 @@ class SearchRequest(BaseRequest): """ PROTOCOL_OP: ClassVar[int] = ProtocolRequests.SEARCH + CONTEXT_TYPE: ClassVar[type] = LDAPSearchRequestContext base_object: str = Field("", description="Any `DistinguishedName`") scope: Scope @@ -339,7 +345,7 @@ def _mutate_query_with_attributes_to_load( if self.entity_type_name: query = ( query.join(qa(Directory.entity_type)) - .options(selectinload(qa(Directory.entity_type))) + .options(contains_eager(qa(Directory.entity_type))) ) # fmt: skip if self.all_attrs: @@ -369,8 +375,8 @@ def _build_query( query = ( select(Directory) .join(qa(Directory.user), isouter=True) - .options(joinedload(qa(Directory.user))) - .options(selectinload(qa(Directory.group))) + .options(contains_eager(qa(Directory.user))) + .options(joinedload(qa(Directory.group))) ) query = self._mutate_query_with_attributes_to_load(query) @@ -423,7 +429,7 @@ def _build_query( if self.member: query = query.options( - selectinload(qa(Directory.group)).selectinload( + joinedload(qa(Directory.group)).selectinload( qa(Group.members), ), ) @@ -501,7 +507,6 @@ async def _fill_attrs( ) if self.member_of: - logger.debug(f"Member of group: {directory.groups}") for group in directory.groups: attrs["memberOf"].append(group.directory.path_dn) @@ -541,9 +546,9 @@ async def tree_view( # noqa: C901 access_manager: AccessManager, ) -> AsyncGenerator[SearchResultEntry, None]: """Yield all resulted directories.""" - directories = await session.stream_scalars(query) + directories = await session.scalars(query) - async for directory in directories: + for directory in directories: attrs = defaultdict(list) obj_classes = [] diff --git a/tests/conftest.py b/tests/conftest.py index 58937dbd3..b27450f86 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -646,10 +646,25 @@ async def get_audit_monitor( LDAPModifyDNRequestContext, scope=Scope.REQUEST, ) - search_request_context = provide( - LDAPSearchRequestContext, - scope=Scope.REQUEST, - ) + + @provide(scope=Scope.REQUEST, provides=LDAPSearchRequestContext) + def get_search_request_context( + self, + session: AsyncSession, + ldap_session: LDAPSession, + settings: Settings, + access_manager: AccessManager, + rootdse_reader: RootDSEReader, + ) -> LDAPSearchRequestContext: + """Get search request context.""" + return LDAPSearchRequestContext( + session=session, # type: ignore + ldap_session=ldap_session, + settings=settings, + access_manager=access_manager, + rootdse_rd=rootdse_reader, + ) + unbind_request_context = provide( LDAPUnbindRequestContext, scope=Scope.REQUEST,