diff --git a/changelog.d/19640.misc b/changelog.d/19640.misc new file mode 100644 index 00000000000..f13d8f67530 --- /dev/null +++ b/changelog.d/19640.misc @@ -0,0 +1 @@ +Add a `FilteredEvent` class that saves us copying events. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 0614c805dad..67189e91e72 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -34,6 +34,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.events import EventBase +from synapse.events.utils import FilteredEvent from synapse.handlers.admin import ExfiltrationWriter from synapse.server import HomeServer from synapse.storage.database import DatabasePool, LoggingDatabaseConnection @@ -150,14 +151,14 @@ def __init__(self, user_id: str, directory: str | None = None): if list(os.listdir(self.base_directory)): raise Exception("Directory must be empty") - def write_events(self, room_id: str, events: list[EventBase]) -> None: + def write_events(self, room_id: str, filtered_events: list[FilteredEvent]) -> None: room_directory = os.path.join(self.base_directory, "rooms", room_id) os.makedirs(room_directory, exist_ok=True) events_file = os.path.join(room_directory, "events") with open(events_file, "a") as f: - for event in events: - json.dump(event.get_pdu_json(), fp=f) + for filtered_event in filtered_events: + json.dump(filtered_event.event.get_pdu_json(), fp=f) def write_state( self, room_id: str, event_id: str, state: StateMap[EventBase] @@ -175,7 +176,7 @@ def write_state( def write_invite( self, room_id: str, event: EventBase, state: StateMap[EventBase] ) -> None: - self.write_events(room_id, [event]) + self.write_events(room_id, [FilteredEvent.state(event)]) # We write the invite state somewhere else as they aren't full events # and are only a subset of the state at the event. @@ -191,7 +192,7 @@ def write_invite( def write_knock( self, room_id: str, event: EventBase, state: StateMap[EventBase] ) -> None: - self.write_events(room_id, [event]) + self.write_events(room_id, [FilteredEvent.state(event)]) # We write the knock state somewhere else as they aren't full events # and are only a subset of the state at the event. diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index d4e9d50b966..66c962e17db 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -40,7 +40,7 @@ TransactionUnusedFallbackKeys, ) from synapse.events import EventBase -from synapse.events.utils import SerializeEventConfig +from synapse.events.utils import FilteredEvent, SerializeEventConfig from synapse.http.client import SimpleHttpClient, is_unknown_endpoint from synapse.logging import opentracing from synapse.metrics import SERVER_NAME_LABEL @@ -545,7 +545,7 @@ async def _serialize( ) -> list[JsonDict]: time_now = self.clock.time_msec() return await self._event_serializer.serialize_events( - list(events), + [FilteredEvent(event=e, membership=None) for e in events], time_now, config=SerializeEventConfig( as_client_event=True, diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 76ebac8b17c..ff0476f5fbb 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -41,6 +41,7 @@ MAX_PDU_SIZE, EventContentFields, EventTypes, + EventUnsignedContentFields, RelationTypes, ) from synapse.api.errors import Codes, SynapseError @@ -416,6 +417,50 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict: return d +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FilteredEvent: + """An event annotated with per-user data for client serialization. + + Produced by filter_and_transform_events_for_client. Carries the user's + membership at the time of the event so serialization can inject it into + unsigned.membership (MSC4115) without cloning the underlying event. + """ + + event: "EventBase" + """The event to be serialized.""" + + membership: str | None + """The user whose requesting the event's membership at the time of the + event was sent. + + This is None if we didn't compute the membership. In Synapse this happens a) + when returning state events to state endpoints, or b) when the event is + returned to an admin. + + According to the spec we don't have to include the membership for any events + if we don't want to, especially if its expensive to compute. In practice + clients really only care about events in the room timeline so that in + encrypted room they can determine if they should be able to decrypt the + event or not. + """ + + @classmethod + def state(cls, event: "EventBase") -> "FilteredEvent": + """Wrap a state event with no per-user membership annotation. + + The event must be a state event (i.e. have a state_key). + """ + assert event.is_state(), ( + f"FilteredEvent.state() called with non-state event {event.event_id}" + ) + return cls(event=event, membership=None) + + @classmethod + def admin_override(cls, event: "EventBase") -> "FilteredEvent": + """Wrap an event that bypasses visibility filtering due to admin privileges.""" + return cls(event=event, membership=None) + + @attr.s(slots=True, frozen=True, auto_attribs=True) class SerializeEventConfig: as_client_event: bool = True @@ -435,6 +480,9 @@ class SerializeEventConfig: # only server admins can see through other configuration. For example, # whether an event was soft failed by the server. include_admin_metadata: bool = False + # Whether MSC4354 (sticky events) is enabled. When True, the sticky TTL + # will be computed and included in the unsigned section of sticky events. + msc4354_enabled: bool = False @only_event_fields.validator def _validate_only_event_fields( @@ -461,6 +509,7 @@ def _serialize_event( time_now_ms: int, *, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, + membership: str | None = None, ) -> JsonDict: """Serialize event for clients @@ -468,6 +517,8 @@ def _serialize_event( e time_now_ms config: Event serialization config + membership: The requesting user's membership at the time of the event, + to be injected into unsigned.membership (MSC4115). Returns: The serialized event dictionary. @@ -564,6 +615,23 @@ def _serialize_event( if e.internal_metadata.policy_server_spammy: d["unsigned"]["io.element.synapse.policy_server_spammy"] = True + if config.msc4354_enabled: + sticky_duration = e.sticky_duration() + if sticky_duration: + expires_at = ( + # min() ensures that the origin server can't lie about the time and + # send the event 'in the future', as that would allow them to exceed + # the 1 hour limit on stickiness duration. + min(e.origin_server_ts, time_now_ms) + sticky_duration.as_millis() + ) + if expires_at > time_now_ms: + d["unsigned"][EventUnsignedContentFields.STICKY_TTL] = ( + expires_at - time_now_ms + ) + + if membership is not None: + d["unsigned"][EventUnsignedContentFields.MEMBERSHIP] = membership + return d @@ -577,13 +645,15 @@ class EventClientSerializer: def __init__(self, hs: "HomeServer") -> None: self._store = hs.get_datastores().main self._auth = hs.get_auth() + self._config = hs.config + self._clock = hs.get_clock() self._add_extra_fields_to_unsigned_client_event_callbacks: list[ ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK ] = [] async def serialize_event( self, - event: JsonDict | EventBase, + event: JsonDict | FilteredEvent, time_now: int, *, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, @@ -605,7 +675,7 @@ async def serialize_event( The serialized event """ # To handle the case of presence events and the like - if not isinstance(event, EventBase): + if not isinstance(event, FilteredEvent): return event # Force-enable server admin metadata because the only time an event with @@ -617,11 +687,16 @@ async def serialize_event( ): config = make_config_for_admin(config) - serialized_event = _serialize_event(event, time_now, config=config) + if self._config.experimental.msc4354_enabled: + config = attr.evolve(config, msc4354_enabled=True) + + serialized_event = _serialize_event( + event.event, time_now, config=config, membership=event.membership + ) # If the event was redacted, fetch the redaction event from the database # and include it in the serialized event's unsigned section. - redacted_by: str | None = event.internal_metadata.redacted_by + redacted_by: str | None = event.event.internal_metadata.redacted_by if redacted_by is not None: serialized_event.setdefault("unsigned", {})["redacted_by"] = redacted_by if redaction_map is not None: @@ -648,7 +723,7 @@ async def serialize_event( new_unsigned = {} for callback in self._add_extra_fields_to_unsigned_client_event_callbacks: - u = await callback(event) + u = await callback(event.event) new_unsigned.update(u) if new_unsigned: @@ -666,9 +741,9 @@ async def serialize_event( # Check if there are any bundled aggregations to include with the event. if bundle_aggregations: - if event.event_id in bundle_aggregations: + if event.event.event_id in bundle_aggregations: await self._inject_bundled_aggregations( - event, + event.event, time_now, config, bundle_aggregations, @@ -720,7 +795,7 @@ async def _inject_bundled_aggregations( # `sender` of the edit; however MSC3925 proposes extending it to the whole # of the edit, which is what we do here. serialized_aggregations[RelationTypes.REPLACE] = await self.serialize_event( - event_aggregations.replace, + FilteredEvent(event=event_aggregations.replace, membership=None), time_now, config=config, ) @@ -730,7 +805,7 @@ async def _inject_bundled_aggregations( thread = event_aggregations.thread serialized_latest_event = await self.serialize_event( - thread.latest_event, + FilteredEvent(event=thread.latest_event, membership=None), time_now, config=config, bundle_aggregations=bundled_aggregations, @@ -755,7 +830,7 @@ async def _inject_bundled_aggregations( @trace async def serialize_events( self, - events: Collection[JsonDict | EventBase], + events: Collection[JsonDict | FilteredEvent], time_now: int, *, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, @@ -780,11 +855,13 @@ async def serialize_events( ) # Batch-fetch all redaction events in one go rather than one per event. - redaction_ids = { - e.internal_metadata.redacted_by - for e in events - if isinstance(e, EventBase) and e.internal_metadata.redacted_by is not None - } + redaction_ids: set[str] = set() + for e in events: + base = e.event if isinstance(e, FilteredEvent) else e + if isinstance(base, EventBase): + redacted_by = base.internal_metadata.redacted_by + if redacted_by is not None: + redaction_ids.add(redacted_by) redaction_map = ( await self._store.get_events(redaction_ids) if redaction_ids else {} ) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 2fb0e5814f2..d2c1f98d7c4 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -33,6 +33,7 @@ from synapse.api.constants import Direction, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events import EventBase +from synapse.events.utils import FilteredEvent from synapse.types import ( JsonMapping, Requester, @@ -251,32 +252,40 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> topological=last_event.depth, ) - events = await filter_and_transform_events_for_client( + filtered_events = await filter_and_transform_events_for_client( self._storage_controllers, user_id, events, ) - writer.write_events(room_id, events) + writer.write_events(room_id, filtered_events) # Update the extremity tracking dicts - for event in events: + for filtered_event in filtered_events: # Check if we have any prev events that haven't been # processed yet, and add those to the appropriate dicts. - unseen_events = set(event.prev_event_ids()) - written_events + unseen_events = ( + set(filtered_event.event.prev_event_ids()) - written_events + ) if unseen_events: - event_to_unseen_prevs[event.event_id] = unseen_events + event_to_unseen_prevs[filtered_event.event.event_id] = ( + unseen_events + ) for unseen in unseen_events: unseen_to_child_events.setdefault(unseen, set()).add( - event.event_id + filtered_event.event.event_id ) # Now check if this event is an unseen prev event, if so # then we remove this event from the appropriate dicts. - for child_id in unseen_to_child_events.pop(event.event_id, []): - event_to_unseen_prevs[child_id].discard(event.event_id) + for child_id in unseen_to_child_events.pop( + filtered_event.event.event_id, [] + ): + event_to_unseen_prevs[child_id].discard( + filtered_event.event.event_id + ) - written_events.add(event.event_id) + written_events.add(filtered_event.event.event_id) logger.info( "Written %d events in room %s", len(written_events), room_id @@ -511,7 +520,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): """Interface used to specify how to write exported data.""" @abc.abstractmethod - def write_events(self, room_id: str, events: list[EventBase]) -> None: + def write_events(self, room_id: str, events: list[FilteredEvent]) -> None: """Write a batch of events for a room.""" raise NotImplementedError() diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index f6517def9c9..2518716bc70 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -25,8 +25,7 @@ from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState from synapse.api.errors import AuthError, SynapseError -from synapse.events import EventBase -from synapse.events.utils import SerializeEventConfig +from synapse.events.utils import FilteredEvent, SerializeEventConfig from synapse.handlers.presence import format_user_presence_state from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.streams.config import PaginationConfig @@ -102,19 +101,19 @@ async def get_stream( # joined room, we need to send down presence for those users. to_add: list[JsonDict] = [] for event in events: - if not isinstance(event, EventBase): + if not isinstance(event, FilteredEvent): continue - if event.type == EventTypes.Member: - if event.membership != Membership.JOIN: + if event.event.type == EventTypes.Member: + if event.event.membership != Membership.JOIN: continue # Send down presence. - if event.state_key == requester.user.to_string(): + if event.event.state_key == requester.user.to_string(): # Send down presence for everyone in the room. users: Iterable[str] = await self.store.get_users_in_room( - event.room_id + event.event.room_id ) else: - users = [event.state_key] + users = [event.event.state_key] states = await presence_handler.get_states(users) to_add.extend( @@ -155,7 +154,7 @@ async def get_event( room_id: str | None, event_id: str, show_redacted: bool = False, - ) -> EventBase | None: + ) -> FilteredEvent | None: """Retrieve a single specified event on behalf of a user. The event will be transformed in a user-specific and time-specific way, e.g. having unsigned metadata added or being erased depending on who is accessing. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 1e5e98a59bd..9bcc047467b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -30,7 +30,7 @@ Membership, ) from synapse.api.errors import SynapseError -from synapse.events.utils import SerializeEventConfig +from synapse.events.utils import FilteredEvent, SerializeEventConfig from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.handlers.receipts import ReceiptEventSource @@ -186,7 +186,7 @@ async def handle_room(event: RoomsForUser) -> None: invite_event = await self.store.get_event(event.event_id) d["invite"] = await self._event_serializer.serialize_event( - invite_event, + FilteredEvent.state(event=invite_event), time_now, config=serializer_options, ) @@ -225,7 +225,7 @@ async def handle_room(event: RoomsForUser) -> None: ) ).addErrback(unwrapFirstError) - messages = await filter_and_transform_events_for_client( + filtered_messages = await filter_and_transform_events_for_client( self._storage_controllers, user_id, messages, @@ -240,7 +240,7 @@ async def handle_room(event: RoomsForUser) -> None: d["messages"] = { "chunk": ( await self._event_serializer.serialize_events( - messages, + filtered_messages, time_now=time_now, config=serializer_options, ) @@ -250,7 +250,7 @@ async def handle_room(event: RoomsForUser) -> None: } d["state"] = await self._event_serializer.serialize_events( - current_state.values(), + [FilteredEvent.state(e) for e in current_state.values()], time_now=time_now, config=serializer_options, ) @@ -382,7 +382,9 @@ async def _room_initial_sync_parted( room_id, limit=pagin_config.limit, end_token=stream_token ) - messages = await filter_and_transform_events_for_client( + filtered_messages: list[ + FilteredEvent + ] = await filter_and_transform_events_for_client( self._storage_controllers, requester.user.to_string(), messages, @@ -402,7 +404,7 @@ async def _room_initial_sync_parted( "chunk": ( # Don't bundle aggregations as this is a deprecated API. await self._event_serializer.serialize_events( - messages, time_now, config=serialize_options + filtered_messages, time_now, config=serialize_options ) ), "start": await start_token.to_string(self.store), @@ -411,7 +413,9 @@ async def _room_initial_sync_parted( "state": ( # Don't bundle aggregations as this is a deprecated API. await self._event_serializer.serialize_events( - room_state.values(), time_now, config=serialize_options + [FilteredEvent.state(e) for e in room_state.values()], + time_now, + config=serialize_options, ) ), "presence": [], @@ -435,7 +439,7 @@ async def _room_initial_sync_joined( serialize_options = SerializeEventConfig(requester=requester) # Don't bundle aggregations as this is a deprecated API. state = await self._event_serializer.serialize_events( - current_state.values(), + [FilteredEvent.state(e) for e in current_state.values()], time_now, config=serialize_options, ) @@ -496,7 +500,9 @@ async def get_receipts() -> list[JsonMapping]: ).addErrback(unwrapFirstError) ) - messages = await filter_and_transform_events_for_client( + filtered_messages: list[ + FilteredEvent + ] = await filter_and_transform_events_for_client( self._storage_controllers, requester.user.to_string(), messages, @@ -512,7 +518,7 @@ async def get_receipts() -> list[JsonMapping]: "chunk": ( # Don't bundle aggregations as this is a deprecated API. await self._event_serializer.serialize_events( - messages, time_now, config=serialize_options + filtered_messages, time_now, config=serialize_options ) ), "start": await start_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index eb016225159..319d70cbe8d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -61,7 +61,11 @@ UnpersistedEventContext, UnpersistedEventContextBase, ) -from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field +from synapse.events.utils import ( + FilteredEvent, + SerializeEventConfig, + maybe_upsert_event_field, +) from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME @@ -261,7 +265,7 @@ async def get_state_events( room_state = room_state_events[membership_event_id] events = await self._event_serializer.serialize_events( - room_state.values(), + [FilteredEvent.state(e) for e in room_state.values()], self.clock.time_msec(), config=SerializeEventConfig(requester=requester), ) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 7b9c8290564..8cbe4b63c88 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -29,6 +29,7 @@ from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase +from synapse.events.utils import FilteredEvent from synapse.handlers.relations import BundledAggregations from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging.opentracing import trace @@ -79,7 +80,7 @@ class GetMessagesResult: Everything needed to serialize a `/messages` response. """ - messages_chunk: list[EventBase] + messages_chunk: list[FilteredEvent] """ A list of room events. @@ -684,16 +685,18 @@ async def get_messages( events = await event_filter.filter(events) if not use_admin_priviledge: - events = await filter_and_transform_events_for_client( + filtered_events = await filter_and_transform_events_for_client( self._storage_controllers, user_id, events, is_peeking=(member_event_id is None), ) + else: + filtered_events = [FilteredEvent.admin_override(e) for e in events] # if after the filter applied there are no more events # return immediately - but there might be more in next_token batch - if not events: + if not filtered_events: return GetMessagesResult( messages_chunk=[], bundled_aggregations={}, @@ -703,16 +706,16 @@ async def get_messages( ) state = None - if event_filter and event_filter.lazy_load_members and len(events) > 0: + if event_filter and event_filter.lazy_load_members and len(filtered_events) > 0: # TODO: remove redundant members # FIXME: we also care about invite targets etc. state_filter = StateFilter.from_types( - (EventTypes.Member, event.sender) for event in events + (EventTypes.Member, event.event.sender) for event in filtered_events ) state_ids = await self._state_storage_controller.get_state_ids_for_event( - events[0].event_id, state_filter=state_filter + filtered_events[0].event.event_id, state_filter=state_filter ) if state_ids: @@ -720,11 +723,11 @@ async def get_messages( state = list(state_dict.values()) aggregations = await self._relations_handler.get_bundled_aggregations( - events, user_id + filtered_events, user_id ) return GetMessagesResult( - messages_chunk=events, + messages_chunk=filtered_events, bundled_aggregations=aggregations, state=state, start_token=from_token, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index d7d3002fbe3..ee4f8d672ee 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -33,7 +33,7 @@ from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event -from synapse.events.utils import SerializeEventConfig +from synapse.events.utils import FilteredEvent, SerializeEventConfig from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import trace from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent @@ -139,7 +139,7 @@ async def get_relations( # not passing them in here we should get a better cache hit rate). related_events, next_token = await self._main_store.get_relations_for_event( event_id=event_id, - event=event, + event=event.event, room_id=room_id, relation_type=relation_type, event_type=event_type, @@ -154,7 +154,9 @@ async def get_relations( [e.event_id for e in related_events] ) - events = await filter_and_transform_events_for_client( + filtered_events: list[ + FilteredEvent + ] = await filter_and_transform_events_for_client( self._storage_controllers, user_id, events, @@ -164,14 +166,14 @@ async def get_relations( # The relations returned for the requested event do include their # bundled aggregations. aggregations = await self.get_bundled_aggregations( - events, requester.user.to_string() + filtered_events, requester.user.to_string() ) now = self._clock.time_msec() serialize_options = SerializeEventConfig(requester=requester) return_value: JsonDict = { "chunk": await self._event_serializer.serialize_events( - events, + filtered_events, now, bundle_aggregations=aggregations, config=serialize_options, @@ -389,7 +391,7 @@ async def _get_threads_for_events( potential_events, _ = await self._main_store.get_relations_for_event( room_id, event_id, - event, + event.event, RelationTypes.THREAD, direction=Direction.FORWARDS, ) @@ -417,7 +419,7 @@ async def _get_threads_for_events( potential_events[-1].event_id, ) continue - latest_thread_event = event + latest_thread_event = event.event results[event_id] = _ThreadAggregation( latest_event=latest_thread_event, @@ -432,12 +434,12 @@ async def _get_threads_for_events( @trace async def get_bundled_aggregations( - self, events: Iterable[EventBase], user_id: str + self, filtered_events: Iterable[FilteredEvent], user_id: str ) -> dict[str, BundledAggregations]: """Generate bundled aggregations for events. Args: - events: The iterable of events to calculate bundled aggregations for. + filtered_events: The iterable of filtered events to calculate bundled aggregations for. user_id: The user requesting the bundled aggregations. Returns: @@ -453,7 +455,9 @@ async def get_bundled_aggregations( events_by_id = {} # A map of event ID to the relation in that event, if there is one. relations_by_id: dict[str, str] = {} - for event in events: + for filtered_event in filtered_events: + event = filtered_event.event + # State events do not get bundled aggregations. if event.is_state(): continue @@ -599,7 +603,9 @@ async def get_threads( # Limit the returned threads to those the user has participated in. events = [event for event in events if participated[event.event_id]] - events = await filter_and_transform_events_for_client( + filtered_events: list[ + FilteredEvent + ] = await filter_and_transform_events_for_client( self._storage_controllers, user_id, events, @@ -607,12 +613,12 @@ async def get_threads( ) aggregations = await self.get_bundled_aggregations( - events, requester.user.to_string() + filtered_events, requester.user.to_string() ) now = self._clock.time_msec() serialized_events = await self._event_serializer.serialize_events( - events, now, bundle_aggregations=aggregations + filtered_events, now, bundle_aggregations=aggregations ) return_value: JsonDict = {"chunk": serialized_events} diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 1c3489a00e3..9074d7916b6 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -67,7 +67,7 @@ from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.snapshot import UnpersistedEventContext -from synapse.events.utils import copy_and_fixup_power_levels_contents +from synapse.events.utils import FilteredEvent, copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.rest.admin._base import assert_user_is_admin from synapse.streams import EventSource @@ -109,9 +109,9 @@ @attr.s(slots=True, frozen=True, auto_attribs=True) class EventContext: - events_before: list[EventBase] - event: EventBase - events_after: list[EventBase] + events_before: list[FilteredEvent] + event: FilteredEvent + events_after: list[FilteredEvent] state: list[EventBase] aggregations: dict[str, BundledAggregations] start: str @@ -1916,9 +1916,9 @@ async def get_event_context( # The user is peeking if they aren't in the room already is_peeking = not is_user_in_room - async def filter_evts(events: list[EventBase]) -> list[EventBase]: + async def filter_evts(events: list[EventBase]) -> list[FilteredEvent]: if use_admin_priviledge: - return events + return [FilteredEvent.admin_override(e) for e in events] return await filter_and_transform_events_for_client( self._storage_controllers, user.to_string(), @@ -1946,31 +1946,33 @@ async def filter_evts(events: list[EventBase]) -> list[EventBase]: events_before = await event_filter.filter(events_before) events_after = await event_filter.filter(events_after) - events_before = await filter_evts(events_before) - events_after = await filter_evts(events_after) + filtered_events_before = await filter_evts(events_before) + filtered_events_after = await filter_evts(events_after) # filter_evts can return a pruned event in case the user is allowed to see that # there's something there but not see the content, so use the event that's in # `filtered` rather than the event we retrieved from the datastore. - event = filtered[0] + filtered_event = filtered[0] # Fetch the aggregations. aggregations = await self._relations_handler.get_bundled_aggregations( - itertools.chain(events_before, (event,), events_after), + itertools.chain( + filtered_events_before, (filtered_event,), filtered_events_after + ), user.to_string(), ) - if events_after: - last_event_id = events_after[-1].event_id + if filtered_events_after: + last_event_id = filtered_events_after[-1].event.event_id else: last_event_id = event_id if event_filter and event_filter.lazy_load_members: state_filter = StateFilter.from_lazy_load_member_list( - ev.sender + ev.event.sender for ev in itertools.chain( - events_before, - (event,), - events_after, + filtered_events_before, + (filtered_event,), + filtered_events_after, ) ) else: @@ -1993,9 +1995,9 @@ async def filter_evts(events: list[EventBase]) -> list[EventBase]: token = StreamToken.START return EventContext( - events_before=events_before, - event=event, - events_after=events_after, + events_before=filtered_events_before, + event=filtered_event, + events_after=filtered_events_after, state=state_events, aggregations=aggregations, start=await token.copy_and_replace( diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 56c047b0e89..30e072d011e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -29,8 +29,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter -from synapse.events import EventBase -from synapse.events.utils import SerializeEventConfig +from synapse.events.utils import FilteredEvent, SerializeEventConfig from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID from synapse.types.state import StateFilter from synapse.visibility import filter_and_transform_events_for_client @@ -48,7 +47,7 @@ class _SearchResult: # A mapping of event ID to the rank of that event. rank_map: dict[str, int] # A list of the resulting events. - allowed_events: list[EventBase] + allowed_events: list[FilteredEvent] # A map of room ID to results. room_groups: dict[str, JsonDict] # A set of event IDs to highlight. @@ -355,12 +354,12 @@ async def _search( state_results = {} if include_state: - for room_id in {e.room_id for e in search_result.allowed_events}: + for room_id in {e.event.room_id for e in search_result.allowed_events}: state = await self._storage_controllers.state.get_current_state(room_id) state_results[room_id] = list(state.values()) aggregations = await self._relations_handler.get_bundled_aggregations( - # Generate an iterable of EventBase for all the events that will be + # Generate an iterable of FilteredEvent for all the events that will be # returned, including contextual events. itertools.chain( # The events_before and events_after for each context. @@ -396,14 +395,14 @@ async def _search( results = [ { - "rank": search_result.rank_map[e.event_id], + "rank": search_result.rank_map[e.event.event_id], "result": await self._event_serializer.serialize_event( e, time_now, bundle_aggregations=aggregations, config=serialize_options, ), - "context": contexts.get(e.event_id, {}), + "context": contexts.get(e.event.event_id, {}), } for e in search_result.allowed_events ] @@ -417,7 +416,9 @@ async def _search( if state_results: rooms_cat_res["state"] = { room_id: await self._event_serializer.serialize_events( - state_events, time_now, config=serialize_options + [FilteredEvent.state(e) for e in state_events], + time_now, + config=serialize_options, ) for room_id, state_events in state_results.items() } @@ -485,19 +486,19 @@ async def _search_by_rank( filtered_events, ) - events.sort(key=lambda e: -rank_map[e.event_id]) + events.sort(key=lambda e: -rank_map[e.event.event_id]) allowed_events = events[: search_filter.limit] for e in allowed_events: rm = room_groups.setdefault( - e.room_id, {"results": [], "order": rank_map[e.event_id]} + e.event.room_id, {"results": [], "order": rank_map[e.event.event_id]} ) - rm["results"].append(e.event_id) + rm["results"].append(e.event.event_id) s = sender_group.setdefault( - e.sender, {"results": [], "order": rank_map[e.event_id]} + e.event.sender, {"results": [], "order": rank_map[e.event.event_id]} ) - s["results"].append(e.event_id) + s["results"].append(e.event.event_id) return ( _SearchResult( @@ -549,7 +550,7 @@ async def _search_by_recent( highlights = set() - room_events: list[EventBase] = [] + room_events: list[FilteredEvent] = [] i = 0 pagination_token = batch_token @@ -595,11 +596,11 @@ async def _search_by_recent( pagination_token = results[-1]["pagination_token"] for event in room_events: - group = room_groups.setdefault(event.room_id, {"results": []}) - group["results"].append(event.event_id) + group = room_groups.setdefault(event.event.room_id, {"results": []}) + group["results"].append(event.event.event_id) if room_events and len(room_events) >= search_filter.limit: - last_event_id = room_events[-1].event_id + last_event_id = room_events[-1].event.event_id pagination_token = results_map[last_event_id]["pagination_token"] # We want to respect the given batch group and group keys so @@ -632,7 +633,7 @@ async def _search_by_recent( async def _calculate_event_contexts( self, user: UserID, - allowed_events: list[EventBase], + allowed_events: list[FilteredEvent], before_limit: int, after_limit: int, include_profile: bool, @@ -658,7 +659,7 @@ async def _calculate_event_contexts( contexts = {} for event in allowed_events: res = await self.store.get_events_around( - event.room_id, event.event_id, before_limit, after_limit + event.event.room_id, event.event.event_id, before_limit, after_limit ) logger.info( @@ -692,14 +693,14 @@ async def _calculate_event_contexts( if include_profile: senders = { - ev.sender + ev.event.sender for ev in itertools.chain(events_before, [event], events_after) } if events_after: - last_event_id = events_after[-1].event_id + last_event_id = events_after[-1].event.event_id else: - last_event_id = event.event_id + last_event_id = event.event.event_id state_filter = StateFilter.from_types( [(EventTypes.Member, sender) for sender in senders] @@ -718,6 +719,6 @@ async def _calculate_event_contexts( if s.type == EventTypes.Member and s.state_key in senders } - contexts[event.event_id] = context + contexts[event.event.event_id] = context return contexts diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py index 6feb6c292e9..1cc587d4a7b 100644 --- a/synapse/handlers/sliding_sync/__init__.py +++ b/synapse/handlers/sliding_sync/__init__.py @@ -23,7 +23,7 @@ from synapse.api.constants import Direction, EventTypes, Membership from synapse.events import EventBase -from synapse.events.utils import strip_event +from synapse.events.utils import FilteredEvent, strip_event from synapse.handlers.relations import BundledAggregations from synapse.handlers.sliding_sync.extensions import SlidingSyncExtensionHandler from synapse.handlers.sliding_sync.room_lists import ( @@ -679,7 +679,7 @@ async def get_room_sync_data( # membership. Currently, we have to make all of these optional because # `invite`/`knock` rooms only have `stripped_state`. See # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 - timeline_events: list[EventBase] = [] + timeline_events: list[FilteredEvent] = [] bundled_aggregations: dict[str, BundledAggregations] | None = None limited: bool | None = None prev_batch_token: StreamToken | None = None @@ -739,7 +739,7 @@ async def get_room_sync_data( # Use `stream_ordering` for updates else paginate_room_events_by_stream_ordering ) - timeline_events, new_room_key, limited = await pagination_method( + raw_timeline_events, new_room_key, limited = await pagination_method( room_id=room_id, # The bounds are reversed so we can paginate backwards # (from newer to older events) starting at to_bound. @@ -752,13 +752,13 @@ async def get_room_sync_data( # We want to return the events in ascending order (the last event is the # most recent). - timeline_events.reverse() + raw_timeline_events.reverse() # Make sure we don't expose any events that the client shouldn't see timeline_events = await filter_and_transform_events_for_client( self.storage_controllers, user.to_string(), - timeline_events, + raw_timeline_events, is_peeking=room_membership_for_user_at_to_token.membership != Membership.JOIN, filter_send_to_client=True, @@ -778,12 +778,17 @@ async def get_room_sync_data( if from_token is not None: for timeline_event in reversed(timeline_events): # This fields should be present for all persisted events - assert timeline_event.internal_metadata.stream_ordering is not None - assert timeline_event.internal_metadata.instance_name is not None + assert ( + timeline_event.event.internal_metadata.stream_ordering + is not None + ) + assert ( + timeline_event.event.internal_metadata.instance_name is not None + ) persisted_position = PersistedEventPosition( - instance_name=timeline_event.internal_metadata.instance_name, - stream=timeline_event.internal_metadata.stream_ordering, + instance_name=timeline_event.event.internal_metadata.instance_name, + stream=timeline_event.event.internal_metadata.stream_ordering, ) if persisted_position.persisted_after( from_token.stream_token.room_key @@ -1061,13 +1066,13 @@ async def get_room_sync_data( if timeline_events is not None: for timeline_event in timeline_events: # Anyone who sent a message is relevant - timeline_membership.add(timeline_event.sender) + timeline_membership.add(timeline_event.event.sender) # We also care about invite, ban, kick, targets, # etc. - if timeline_event.type == EventTypes.Member: + if timeline_event.event.type == EventTypes.Member: timeline_membership.add( - timeline_event.state_key + timeline_event.event.state_key ) # The client needs to know the membership of everyone in @@ -1480,7 +1485,7 @@ async def _get_bump_stamp( self, room_id: str, to_token: StreamToken, - timeline: list[EventBase], + timeline: list[FilteredEvent], check_outside_timeline: bool, ) -> int | None: """Get a bump stamp for the room, if we have a bump event and it has @@ -1500,8 +1505,8 @@ async def _get_bump_stamp( # those matches. We iterate backwards and take the stream ordering # of the first event that matches the bump event types. for timeline_event in reversed(timeline): - if timeline_event.type in SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES: - new_bump_stamp = timeline_event.internal_metadata.stream_ordering + if timeline_event.event.type in SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES: + new_bump_stamp = timeline_event.event.internal_metadata.stream_ordering # All persisted events have a stream ordering assert new_bump_stamp is not None diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index 9b7a01df142..4a324e9661c 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -761,7 +761,7 @@ async def handle_previously_room(room_id: str) -> None: # in the timeline to avoid bloating and blowing up the sync response # as the number of users in the room increases. (this behavior is part of the spec) initial_rooms_and_event_ids = [ - (room_id, event.event_id) + (room_id, event.event.event_id) for room_id in initial_rooms if room_id in actual_room_response_map for event in actual_room_response_map[room_id].timeline_events diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c8ef5e2aa6c..c88f703ae98 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -43,6 +43,7 @@ from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.events.utils import FilteredEvent from synapse.handlers.relations import BundledAggregations from synapse.logging import issue9533_logger from synapse.logging.context import current_context @@ -123,7 +124,7 @@ class SyncConfig: @attr.s(slots=True, frozen=True, auto_attribs=True) class TimelineBatch: prev_batch: StreamToken - events: Sequence[EventBase] + events: Sequence[FilteredEvent] limited: bool # A mapping of event ID to the bundled aggregations for the above events. # This is only calculated if limited is true. @@ -148,7 +149,7 @@ class JoinedSyncResult: state: StateMap[EventBase] ephemeral: list[JsonDict] account_data: list[JsonDict] - sticky: list[EventBase] + sticky: list[FilteredEvent] unread_notifications: JsonDict unread_thread_notifications: JsonDict summary: JsonDict | None @@ -699,6 +700,7 @@ async def _load_filtered_recents( log_kv({"limited": limited}) + filtered_recents: list[FilteredEvent] if potential_recents: recents = await sync_config.filter_collection.filter_room_timeline( potential_recents @@ -725,29 +727,32 @@ async def _load_filtered_recents( ) ) - recents = await filter_and_transform_events_for_client( + filtered_recents = await filter_and_transform_events_for_client( self._storage_controllers, sync_config.user.to_string(), recents, always_include_ids=current_state_ids, ) - log_kv({"recents_after_visibility_filtering": len(recents)}) + log_kv({"recents_after_visibility_filtering": len(filtered_recents)}) else: - recents = [] + filtered_recents = [] if not limited or block_all_timeline: prev_batch_token = upto_token - if recents: - assert recents[0].internal_metadata.stream_ordering + if filtered_recents: + assert filtered_recents[0].event.internal_metadata.stream_ordering room_key = RoomStreamToken( - stream=recents[0].internal_metadata.stream_ordering - 1 + stream=filtered_recents[ + 0 + ].event.internal_metadata.stream_ordering + - 1 ) prev_batch_token = upto_token.copy_and_replace( StreamKeyType.ROOM, room_key ) return TimelineBatch( - events=recents, prev_batch=prev_batch_token, limited=False + events=filtered_recents, prev_batch=prev_batch_token, limited=False ) filtering_factor = 2 @@ -764,7 +769,7 @@ async def _load_filtered_recents( elif since_token and not newly_joined_room: since_key = since_token.room_key - while limited and len(recents) < timeline_limit and max_repeat: + while limited and len(filtered_recents) < timeline_limit and max_repeat: # For initial `/sync`, we want to view a historical section of the # timeline; to fetch events by `topological_ordering` (best # representation of the room DAG as others were seeing it at the time). @@ -835,26 +840,35 @@ async def _load_filtered_recents( ) ) - loaded_recents = await filter_and_transform_events_for_client( + loaded_filtered_recents: list[ + FilteredEvent + ] = await filter_and_transform_events_for_client( self._storage_controllers, sync_config.user.to_string(), loaded_recents, always_include_ids=current_state_ids, ) - log_kv({"loaded_recents_after_client_filtering": len(loaded_recents)}) + log_kv( + { + "loaded_recents_after_client_filtering": len( + loaded_filtered_recents + ) + } + ) - loaded_recents.extend(recents) - recents = loaded_recents + loaded_filtered_recents.extend(filtered_recents) + filtered_recents = loaded_filtered_recents max_repeat -= 1 - if len(recents) > timeline_limit: + if len(filtered_recents) > timeline_limit: limited = True - recents = recents[-timeline_limit:] - assert recents[0].internal_metadata.stream_ordering + filtered_recents = filtered_recents[-timeline_limit:] + assert filtered_recents[0].event.internal_metadata.stream_ordering room_key = RoomStreamToken( - stream=recents[0].internal_metadata.stream_ordering - 1 + stream=filtered_recents[0].event.internal_metadata.stream_ordering + - 1 ) prev_batch_token = upto_token.copy_and_replace(StreamKeyType.ROOM, room_key) @@ -865,12 +879,12 @@ async def _load_filtered_recents( if limited or newly_joined_room: bundled_aggregations = ( await self._relations_handler.get_bundled_aggregations( - recents, sync_config.user.to_string() + filtered_recents, sync_config.user.to_string() ) ) return TimelineBatch( - events=recents, + events=filtered_recents, prev_batch=prev_batch_token, # Also mark as limited if this is a new room or there has been a gap # (to force client to paginate the gap). @@ -976,8 +990,8 @@ async def compute_summary( # ...or ones which are in the timeline... for ev in batch.events: - if ev.type == EventTypes.Member: - existing_members.add(ev.state_key) + if ev.event.type == EventTypes.Member: + existing_members.add(ev.event.state_key) # ...and then ensure any missing ones get included in state. missing_hero_event_ids = [ @@ -1084,32 +1098,34 @@ async def compute_state_delta( first_event_by_sender_map = {} for event in batch.events: # Build the map from user IDs to the first timeline event they sent. - if event.sender not in first_event_by_sender_map: - first_event_by_sender_map[event.sender] = event + if event.event.sender not in first_event_by_sender_map: + first_event_by_sender_map[event.event.sender] = event.event # When using `state_after`, there is no special treatment with # regards to state also being in the `timeline`. Always fetch # relevant membership regardless of whether the state event is in # the `timeline`. if sync_config.use_state_after: - members_to_fetch.add(event.sender) + members_to_fetch.add(event.event.sender) # For `state`, the client is supposed to do a flawed re-construction # of state over time by starting with the given `state` and layering # on state from the `timeline` as you go (flawed because state # resolution). In this case, we only need their membership in # `state` when their membership isn't already in the `timeline`. - elif (EventTypes.Member, event.sender) not in timeline_state: - members_to_fetch.add(event.sender) + elif (EventTypes.Member, event.event.sender) not in timeline_state: + members_to_fetch.add(event.event.sender) # FIXME: we also care about invite targets etc. - if event.is_state(): - timeline_state[(event.type, event.state_key)] = event.event_id + if event.event.is_state(): + timeline_state[(event.event.type, event.event.state_key)] = ( + event.event.event_id + ) else: timeline_state = { - (event.type, event.state_key): event.event_id + (event.event.type, event.event.state_key): event.event.event_id for event in batch.events - if event.is_state() + if event.event.is_state() } # Now calculate the state to return in the sync response for the room. @@ -1340,7 +1356,7 @@ async def _compute_state_delta_for_full_sync( # timeline, but that is good enough here. state_at_timeline_start = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, + batch.events[0].event.event_id, state_filter=state_filter, await_full_state=await_full_state, ) @@ -1470,10 +1486,10 @@ async def _compute_state_delta_for_incremental_sync( prev_event_id = last_event_id_prev_batch for e in batch.events: - if e.prev_event_ids() != [prev_event_id]: + if e.event.prev_event_ids() != [prev_event_id]: is_linear_timeline = False break - prev_event_id = e.event_id + prev_event_id = e.event.event_id if is_linear_timeline and not batch.limited: state_ids: StateMap[str] = {} @@ -1487,7 +1503,7 @@ async def _compute_state_delta_for_incremental_sync( state_ids = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, + batch.events[0].event.event_id, # we only want members! state_filter=StateFilter.from_types( (EventTypes.Member, member) @@ -1501,7 +1517,7 @@ async def _compute_state_delta_for_incremental_sync( if batch: state_at_timeline_start = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, + batch.events[0].event.event_id, state_filter=state_filter, await_full_state=await_full_state, ) @@ -2854,7 +2870,7 @@ async def _generate_room_entry( # if there are membership changes in the timeline, or # if membership has changed during a gappy sync, or # if this is an initial sync. - any(ev.type == EventTypes.Member for ev in batch.events) + any(ev.event.type == EventTypes.Member for ev in batch.events) or ( # XXX: this may include false positives in the form of LL # members which have snuck into state @@ -2870,7 +2886,7 @@ async def _generate_room_entry( if room_builder.rtype == "joined": unread_notifications: dict[str, int] = {} - sticky_events: list[EventBase] = [] + sticky_events: list[FilteredEvent] = [] if sticky_event_ids: # As per MSC4354: # Remove sticky events that are already in the timeline, else we will needlessly duplicate @@ -2880,7 +2896,7 @@ async def _generate_room_entry( # This is particularly important given the risk of sticky events spam since # anyone can send sticky events, so halving the bandwidth on average for each sticky # event is helpful. - timeline_event_id_set = {ev.event_id for ev in batch.events} + timeline_event_id_set = {ev.event.event_id for ev in batch.events} # Must preserve sticky event stream order sticky_event_ids = [ e for e in sticky_event_ids if e not in timeline_event_id_set @@ -3144,7 +3160,8 @@ def calculate_user_changes(self) -> tuple[AbstractSet[str], AbstractSet[str]]: if self.since_token: for joined_sync in self.joined: it = itertools.chain( - joined_sync.state.values(), joined_sync.timeline.events + joined_sync.state.values(), + (e.event for e in joined_sync.timeline.events), ) for event in it: if event.type == EventTypes.Member: diff --git a/synapse/handlers/thread_subscriptions.py b/synapse/handlers/thread_subscriptions.py index 539672c7fe1..29cb045d001 100644 --- a/synapse/handlers/thread_subscriptions.py +++ b/synapse/handlers/thread_subscriptions.py @@ -53,7 +53,7 @@ async def get_thread_subscription_settings( raise NotFoundError("No such thread root") return await self.store.get_subscription_for_thread( - user_id.to_string(), event.room_id, thread_root_event_id + user_id.to_string(), event.event.room_id, thread_root_event_id ) async def subscribe_user_to_thread( @@ -103,7 +103,7 @@ async def subscribe_user_to_thread( ) if autosub_cause_event is None: raise NotFoundError("Automatic subscription event not found") - relation = relation_from_event(autosub_cause_event) + relation = relation_from_event(autosub_cause_event.event) if ( relation is None or relation.rel_type != RelationTypes.THREAD @@ -115,7 +115,9 @@ async def subscribe_user_to_thread( errcode=Codes.MSC4306_NOT_IN_THREAD, ) - automatic_event_orderings = EventOrderings.from_event(autosub_cause_event) + automatic_event_orderings = EventOrderings.from_event( + autosub_cause_event.event + ) else: automatic_event_orderings = None @@ -174,7 +176,7 @@ async def unsubscribe_user_from_thread( outcome = await self.store.unsubscribe_user_from_thread( user_id.to_string(), - event.room_id, + event.event.room_id, thread_root_event_id, ) diff --git a/synapse/notifier.py b/synapse/notifier.py index 93d438def71..f1cec74462f 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -41,6 +41,7 @@ from synapse.api.constants import EduTypes, EventTypes, HistoryVisibility, Membership from synapse.api.errors import AuthError from synapse.events import EventBase +from synapse.events.utils import FilteredEvent from synapse.handlers.presence import format_user_presence_state from synapse.logging import issue9533_logger from synapse.logging.context import PreserveLoggingContext @@ -210,7 +211,7 @@ def new_listener(self, token: StreamToken) -> "Deferred[StreamToken]": @attr.s(slots=True, frozen=True, auto_attribs=True) class EventStreamResult: - events: list[JsonDict | EventBase] + events: list[JsonDict | FilteredEvent] start_token: StreamToken end_token: StreamToken @@ -765,7 +766,7 @@ async def check_for_updates( # The events fetched from each source are a JsonDict, EventBase, or # UserPresenceState, but see below for UserPresenceState being # converted to JsonDict. - events: list[JsonDict | EventBase] = [] + events: list[JsonDict | FilteredEvent] = [] end_token = from_token for keyname, source in self.event_sources.sources.get_sources(): diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index d18630e80ba..1ebbc6d4f39 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -543,8 +543,10 @@ async def _get_notif_vars( results.events_before + [notif_event], ) - for event in the_events: - messagevars = await self._get_message_vars(notif, event, room_state_ids) + for filtered_event in the_events: + messagevars = await self._get_message_vars( + notif, filtered_event.event, room_state_ids + ) if messagevars is not None: ret["messages"].append(messagevars) diff --git a/synapse/rest/admin/events.py b/synapse/rest/admin/events.py index 8da7a67820a..1c311b04713 100644 --- a/synapse/rest/admin/events.py +++ b/synapse/rest/admin/events.py @@ -3,6 +3,7 @@ from synapse.api.errors import NotFoundError from synapse.events.utils import ( + FilteredEvent, SerializeEventConfig, format_event_raw, ) @@ -66,7 +67,9 @@ async def on_GET( ) res = { "event": await self._event_serializer.serialize_event( - event, self._clock.time_msec(), config=config + FilteredEvent.admin_override(event), + self._clock.time_msec(), + config=config, ) } diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index a886859ffab..61511b93601 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -29,6 +29,7 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events.utils import ( + FilteredEvent, SerializeEventConfig, ) from synapse.handlers.pagination import ( @@ -529,7 +530,9 @@ async def on_GET( ) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() - room_state = await self._event_serializer.serialize_events(events.values(), now) + room_state = await self._event_serializer.serialize_events( + [FilteredEvent.state(e) for e in events.values()], now + ) ret = {"state": room_state} return HTTPStatus.OK, ret @@ -897,7 +900,8 @@ async def on_GET( bundle_aggregations=event_context.aggregations, ), "state": await self._event_serializer.serialize_events( - event_context.state, time_now + [FilteredEvent.state(e) for e in event_context.state], + time_now, ), "start": event_context.start, "end": event_context.end, diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 2420e9fffbe..f80a43b2978 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -24,6 +24,7 @@ from synapse.api.constants import ReceiptTypes from synapse.events.utils import ( + FilteredEvent, SerializeEventConfig, format_event_for_client_v2_without_room_id, ) @@ -111,7 +112,7 @@ async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]: "ts": pa.received_ts, "event": ( await self._event_serializer.serialize_event( - notif_events[pa.event_id], + FilteredEvent(event=notif_events[pa.event_id], membership=None), now, config=serialize_options, ) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 65d9c130efc..83664814a6c 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -53,6 +53,7 @@ from synapse.api.filtering import Filter from synapse.events.utils import ( EventClientSerializer, + FilteredEvent, SerializeEventConfig, format_event_for_client_v2, ) @@ -286,7 +287,7 @@ async def on_GET( if format == "event": event = await self._event_serializer.serialize_event( - data, + FilteredEvent.state(data), self.clock.time_msec(), config=SerializeEventConfig( event_format=format_event_for_client_v2, @@ -866,7 +867,9 @@ async def encode_messages_response( serialized_result[ "state" ] = await serialize_deps.event_serializer.serialize_events( - get_messages_result.state, time_now, config=serialize_options + [FilteredEvent.state(e) for e in get_messages_result.state], + time_now, + config=serialize_options, ) return serialized_result @@ -1172,7 +1175,7 @@ async def on_GET( config=serializer_options, ), "state": await self._event_serializer.serialize_events( - event_context.state, + [FilteredEvent.state(e) for e in event_context.state], time_now, config=serializer_options, ), diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 710d097eab0..c3cf0dc3c4d 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -18,7 +18,6 @@ # [This file includes modifications made by New Vector Limited] # # -import itertools import logging from collections import defaultdict from typing import TYPE_CHECKING, Any, Mapping @@ -31,6 +30,7 @@ from synapse.api.presence import UserPresenceState from synapse.api.ratelimiting import Ratelimiter from synapse.events.utils import ( + FilteredEvent, SerializeEventConfig, format_event_for_client_v2_without_room_id, format_event_raw, @@ -448,7 +448,9 @@ async def encode_invited( invited = {} for room in rooms: invite = await self._event_serializer.serialize_event( - room.invite, time_now, config=serialize_options + FilteredEvent.state(event=room.invite), + time_now, + config=serialize_options, ) unsigned = dict(invite.get("unsigned", {})) invite["unsigned"] = unsigned @@ -484,7 +486,9 @@ async def encode_knocked( knocked = {} for room in rooms: knock = await self._event_serializer.serialize_event( - room.knock, time_now, config=serialize_options + FilteredEvent.state(event=room.knock), + time_now, + config=serialize_options, ) # Extract the `unsigned` key from the knock event. @@ -574,7 +578,7 @@ async def encode_room( state_events = state_dict.values() - for event in itertools.chain(state_events, timeline_events): + for event in state_events: # We've had bug reports that events were coming down under the # wrong room. if event.room_id != room.room_id: @@ -584,9 +588,21 @@ async def encode_room( room.room_id, event.room_id, ) + for filtered_event in timeline_events: + # We've had bug reports that events were coming down under the + # wrong room. + if filtered_event.event.room_id != room.room_id: + logger.warning( + "Event %r is under room %r instead of %r", + filtered_event.event.event_id, + room.room_id, + filtered_event.event.room_id, + ) serialized_state = await self._event_serializer.serialize_events( - state_events, time_now, config=serialize_options + [FilteredEvent.state(e) for e in state_events], + time_now, + config=serialize_options, ) serialized_timeline = await self._event_serializer.serialize_events( timeline_events, @@ -974,7 +990,7 @@ async def encode_rooms( ): serialized_required_state = ( await self.event_serializer.serialize_events( - room_result.required_state, + [FilteredEvent.state(e) for e in room_result.required_state], time_now, config=serialize_options, ) diff --git a/synapse/types/handlers/sliding_sync.py b/synapse/types/handlers/sliding_sync.py index 694b3e1645e..1a84bf1ff8f 100644 --- a/synapse/types/handlers/sliding_sync.py +++ b/synapse/types/handlers/sliding_sync.py @@ -34,6 +34,7 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase +from synapse.events.utils import FilteredEvent from synapse.types import ( DeviceListUpdates, JsonDict, @@ -185,7 +186,7 @@ class StrippedHero: # Should be empty for invite/knock rooms with `stripped_state` required_state: list[EventBase] # Should be empty for invite/knock rooms with `stripped_state` - timeline_events: list[EventBase] + timeline_events: list[FilteredEvent] bundled_aggregations: dict[str, "BundledAggregations"] | None # Optional because it's only relevant to invite/knock rooms stripped_state: list[JsonDict] diff --git a/synapse/visibility.py b/synapse/visibility.py index 5ba2a14a24a..fc3e9dfa498 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -31,14 +31,13 @@ from synapse.api.constants import ( EventTypes, - EventUnsignedContentFields, HistoryVisibility, JoinRules, Membership, ) from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.events.utils import clone_event, prune_event +from synapse.events.utils import FilteredEvent, prune_event from synapse.logging.opentracing import trace from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore @@ -82,7 +81,7 @@ async def filter_and_transform_events_for_client( is_peeking: bool = False, always_include_ids: frozenset[str] = frozenset(), filter_send_to_client: bool = True, -) -> list[EventBase]: +) -> list[FilteredEvent]: """ Check which events a user is allowed to see. If the user can see the event but its sender asked for their data to be erased, prune the content of the event. @@ -102,8 +101,8 @@ async def filter_and_transform_events_for_client( also be called to check whether a user can see the state at a given point. Returns: - The filtered events. The `unsigned` data is annotated with the membership state - of `user_id` at each event. + The filtered events, wrapped in FilteredEvent with the requesting user's + membership at each event annotated for use during serialization (MSC4115). """ # Filter out events that have been soft failed so that we don't relay them # to clients, unless they're a server admin and want that to happen. @@ -176,7 +175,7 @@ async def filter_and_transform_events_for_client( room_id ] = await storage.main.get_retention_policy_for_room(room_id) - def allowed(event: EventBase) -> EventBase | None: + def allowed(event: EventBase) -> FilteredEvent | None: state_after_event = event_id_to_state.get(event.event_id) filtered = _check_client_allowed_to_see_event( user_id=user_id, @@ -233,28 +232,9 @@ def allowed(event: EventBase) -> EventBase | None: else Membership.LEAVE ) - # Copy the event before updating the unsigned data: this shouldn't be persisted - # to the cache! - cloned = clone_event(filtered) - cloned.unsigned[EventUnsignedContentFields.MEMBERSHIP] = user_membership - if storage.main.config.experimental.msc4354_enabled: - sticky_duration = cloned.sticky_duration() - if sticky_duration: - now_ms = storage.main.clock.time_msec() - expires_at = ( - # min() ensures that the origin server can't lie about the time and - # send the event 'in the future', as that would allow them to exceed - # the 1 hour limit on stickiness duration. - min(cloned.origin_server_ts, now_ms) + sticky_duration.as_millis() - ) - if expires_at > now_ms: - cloned.unsigned[EventUnsignedContentFields.STICKY_TTL] = ( - expires_at - now_ms - ) - - return cloned + return FilteredEvent(event=filtered, membership=user_membership) - # Check each event: gives an iterable of None or (a modified) EventBase. + # Check each event: gives an iterable of None or a FilteredEvent. filtered_events = map(allowed, events) # Turn it into a list and remove None entries before returning. diff --git a/tests/events/test_auto_accept_invites.py b/tests/events/test_auto_accept_invites.py index 72ade457589..e0ebdf0bcac 100644 --- a/tests/events/test_auto_accept_invites.py +++ b/tests/events/test_auto_accept_invites.py @@ -380,7 +380,7 @@ async def test_ignore_invite_for_missing_user(self) -> None: join_updates, _ = sync_join(self, inviting_user_id) # Assert that the last event in the room was not a member event for the target user. self.assertEqual( - join_updates[0].timeline.events[-1].content["membership"], "invite" + join_updates[0].timeline.events[-1].event.content["membership"], "invite" ) @override_config( @@ -423,7 +423,7 @@ async def test_ignore_invite_for_deactivated_user(self) -> None: join_updates, b = sync_join(self, inviting_user_id) # Assert that the last event in the room was not a member event for the target user. self.assertEqual( - join_updates[0].timeline.events[-1].content["membership"], "invite" + join_updates[0].timeline.events[-1].event.content["membership"], "invite" ) @override_config( @@ -466,7 +466,7 @@ async def test_ignore_invite_for_suspended_user(self) -> None: join_updates, b = sync_join(self, inviting_user_id) # Assert that the last event in the room was not a member event for the target user. self.assertEqual( - join_updates[0].timeline.events[-1].content["membership"], "invite" + join_updates[0].timeline.events[-1].event.content["membership"], "invite" ) @override_config( @@ -509,7 +509,7 @@ async def test_ignore_invite_for_locked_user(self) -> None: join_updates, b = sync_join(self, inviting_user_id) # Assert that the last event in the room was not a member event for the target user. self.assertEqual( - join_updates[0].timeline.events[-1].content["membership"], "invite" + join_updates[0].timeline.events[-1].event.content["membership"], "invite" ) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index af44b5dec1b..12ef42866d2 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -28,6 +28,7 @@ from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import ( + FilteredEvent, PowerLevelsContent, SerializeEventConfig, _split_field, @@ -655,7 +656,7 @@ def serialize( ) -> JsonDict: return self.get_success( self._event_serializer.serialize_event( - ev, + FilteredEvent(event=ev, membership=None), 1479807801915, config=SerializeEventConfig( only_event_fields=fields, diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index 49bd3ba3f4c..a368363d7e3 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -81,7 +81,8 @@ def test_single_public_joined_room(self) -> None: # Check that the right number of events were written counter = Counter( - (event.type, getattr(event, "state_key", None)) for event in written_events + (event.event.type, getattr(event.event, "state_key", None)) + for event in written_events ) self.assertEqual(counter[(EventTypes.Message, None)], 2) self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) @@ -119,7 +120,8 @@ def test_single_private_joined_room(self) -> None: # Check that the right number of events were written counter = Counter( - (event.type, getattr(event, "state_key", None)) for event in written_events + (event.event.type, getattr(event.event, "state_key", None)) + for event in written_events ) self.assertEqual(counter[(EventTypes.Message, None)], 1) self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) @@ -151,7 +153,8 @@ def test_single_left_room(self) -> None: # Check that the right number of events were written counter = Counter( - (event.type, getattr(event, "state_key", None)) for event in written_events + (event.event.type, getattr(event.event, "state_key", None)) + for event in written_events ) self.assertEqual(counter[(EventTypes.Message, None)], 2) self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) @@ -192,7 +195,8 @@ def test_single_left_rejoined_private_room(self) -> None: # Check that the right number of events were written counter = Counter( - (event.type, getattr(event, "state_key", None)) for event in written_events + (event.event.type, getattr(event.event, "state_key", None)) + for event in written_events ) self.assertEqual(counter[(EventTypes.Message, None)], 2) self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 18ec2ca6b6d..b9dee1c9547 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -307,7 +307,7 @@ def test_ban_wins_race_with_join(self) -> None: self.assertEqual(len(alice_sync_result.joined), 1) self.assertEqual(alice_sync_result.joined[0].room_id, room_id) last_room_creation_event_id = ( - alice_sync_result.joined[0].timeline.events[-1].event_id + alice_sync_result.joined[0].timeline.events[-1].event.event_id ) # Eve, a ne'er-do-well, registers. @@ -402,7 +402,7 @@ def test_state_includes_changes_on_forks(self) -> None: ) ) last_room_creation_event_id = ( - initial_sync_result.joined[0].timeline.events[-1].event_id + initial_sync_result.joined[0].timeline.events[-1].event.event_id ) # Send a state event, and a regular event, both using the same prev ID @@ -437,7 +437,7 @@ def test_state_includes_changes_on_forks(self) -> None: self.assertEqual(room_sync.room_id, room_id) self.assertTrue(room_sync.timeline.limited) self.assertEqual( - [e.event_id for e in room_sync.timeline.events], + [e.event.event_id for e in room_sync.timeline.events], [e3_event, e4_event], ) self.assertEqual( @@ -476,7 +476,7 @@ def test_state_includes_changes_on_forks_when_events_excluded(self) -> None: ) ) last_room_creation_event_id = ( - initial_sync_result.joined[0].timeline.events[-1].event_id + initial_sync_result.joined[0].timeline.events[-1].event.event_id ) # Send a state event, and a regular event, both using the same prev ID @@ -521,7 +521,7 @@ def test_state_includes_changes_on_forks_when_events_excluded(self) -> None: self.assertEqual(room_sync.room_id, room_id) self.assertTrue(room_sync.timeline.limited) self.assertEqual( - [e.event_id for e in room_sync.timeline.events], + [e.event.event_id for e in room_sync.timeline.events], [e3_event], ) self.assertEqual( @@ -563,7 +563,7 @@ def test_state_includes_changes_on_long_lived_forks(self) -> None: ) ) last_room_creation_event_id = ( - initial_sync_result.joined[0].timeline.events[-1].event_id + initial_sync_result.joined[0].timeline.events[-1].event.event_id ) # Send a state event, and a regular event, both using the same prev ID @@ -593,7 +593,7 @@ def test_state_includes_changes_on_long_lived_forks(self) -> None: self.assertEqual(room_sync.room_id, room_id) self.assertTrue(room_sync.timeline.limited) self.assertEqual( - [e.event_id for e in room_sync.timeline.events], + [e.event.event_id for e in room_sync.timeline.events], [e3_event], ) @@ -632,7 +632,7 @@ def test_state_includes_changes_on_long_lived_forks(self) -> None: self.assertEqual(room_sync.room_id, room_id) self.assertFalse(room_sync.timeline.limited) self.assertEqual( - [e.event_id for e in room_sync.timeline.events], + [e.event.event_id for e in room_sync.timeline.events], [e4_event], ) @@ -701,7 +701,7 @@ def test_state_includes_changes_on_ungappy_syncs(self) -> None: ) ) last_room_creation_event_id = ( - initial_sync_result.joined[0].timeline.events[-1].event_id + initial_sync_result.joined[0].timeline.events[-1].event.event_id ) # Send a state event, and a regular event, both using the same prev ID @@ -728,7 +728,7 @@ def test_state_includes_changes_on_ungappy_syncs(self) -> None: room_sync = initial_sync_result.joined[0] self.assertEqual(room_sync.room_id, room_id) self.assertEqual( - [e.event_id for e in room_sync.timeline.events], + [e.event.event_id for e in room_sync.timeline.events], [e3_event], ) if self.use_state_after: @@ -757,7 +757,7 @@ def test_state_includes_changes_on_ungappy_syncs(self) -> None: self.assertEqual(room_sync.room_id, room_id) self.assertFalse(room_sync.timeline.limited) self.assertEqual( - [e.event_id for e in room_sync.timeline.events], + [e.event.event_id for e in room_sync.timeline.events], [e4_event, e5_event], ) @@ -855,7 +855,7 @@ def test_archived_rooms_do_not_include_state_after_leave( # The last three events in the timeline should be those leading up to the # leave self.assertEqual( - [e.event_id for e in sync_room_result.timeline.events[-3:]], + [e.event.event_id for e in sync_room_result.timeline.events[-3:]], [before_message_event, before_state_event, leave_event], ) @@ -947,7 +947,7 @@ async def _check_sigs_and_hash_for_pulled_events_and_fetch( ) event_ids = [] for event in sync_result.joined[0].timeline.events: - event_ids.append(event.event_id) + event_ids.append(event.event.event_id) self.assertNotIn(call_event.event_id, event_ids) # it will come down in a private room, though @@ -995,7 +995,7 @@ async def _check_sigs_and_hash_for_pulled_events_and_fetch( ) priv_event_ids = [] for event in private_sync_result.joined[0].timeline.events: - priv_event_ids.append(event.event_id) + priv_event_ids.append(event.event.event_id) self.assertIn(private_call_event.event_id, priv_event_ids) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 82a3b5b3378..1f8d9154ca6 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -23,6 +23,7 @@ from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes +from synapse.events.utils import FilteredEvent from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer @@ -173,7 +174,9 @@ def test_visibility(self) -> None: # We should only get one event back. self.assertEqual(len(filtered_events), 1, filtered_events) # That event should be the second, not outdated event. - self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) + self.assertEqual( + filtered_events[0].event.event_id, valid_event_id, filtered_events + ) def _test_retention_event_purged(self, room_id: str, increment: float) -> None: """Run the following test scenario to test the message retention policy support: @@ -253,7 +256,11 @@ def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict: assert event is not None time_now = self.clock.time_msec() - serialized = self.get_success(self.serializer.serialize_event(event, time_now)) + serialized = self.get_success( + self.serializer.serialize_event( + FilteredEvent(event=event, membership=None), time_now + ) + ) return serialized diff --git a/tests/test_visibility.py b/tests/test_visibility.py index b50faa2a499..9a5efbdd399 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -22,7 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor -from synapse.api.constants import AccountDataTypes, EventUnsignedContentFields +from synapse.api.constants import AccountDataTypes from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext @@ -341,7 +341,7 @@ def test_normal_operation_as_admin(self) -> None: ) self.assertEqual( [e.event_id for e in [self.regular_event]], - [e.event_id for e in filtered_events], + [e.event.event_id for e in filtered_events], ) def test_see_soft_failed_events(self) -> None: @@ -380,7 +380,7 @@ def test_see_soft_failed_events(self) -> None: ) self.assertEqual( [e.event_id for e in [self.regular_event, self.soft_failed_event]], - [e.event_id for e in filtered_events], + [e.event.event_id for e in filtered_events], ) def test_see_policy_server_spammy_events(self) -> None: @@ -427,7 +427,7 @@ def test_see_policy_server_spammy_events(self) -> None: ) self.assertEqual( [e.event_id for e in [self.regular_event, self.spammy_event]], - [e.event_id for e in filtered_events], + [e.event.event_id for e in filtered_events], ) def test_see_soft_failed_and_policy_server_spammy_events(self) -> None: @@ -477,7 +477,7 @@ def test_see_soft_failed_and_policy_server_spammy_events(self) -> None: e.event_id for e in [self.regular_event, self.soft_failed_event, self.spammy_event] ], - [e.event_id for e in filtered_events], + [e.event.event_id for e in filtered_events], ) @@ -559,14 +559,11 @@ def test_joined_history_visibility(self) -> None: # and messages sent between the two, but not before or after. self.assertEqual( [e.event_id for e in [join_event, during_event, leave_event]], - [e.event_id for e in joiner_filtered_events], + [e.event.event_id for e in joiner_filtered_events], ) self.assertEqual( ["join", "join", "leave"], - [ - e.unsigned[EventUnsignedContentFields.MEMBERSHIP] - for e in joiner_filtered_events - ], + [e.membership for e in joiner_filtered_events], ) # The resident user should see all the events. @@ -581,14 +578,11 @@ def test_joined_history_visibility(self) -> None: after_event, ] ], - [e.event_id for e in resident_filtered_events], + [e.event.event_id for e in resident_filtered_events], ) self.assertEqual( ["join", "join", "join", "join", "join"], - [ - e.unsigned[EventUnsignedContentFields.MEMBERSHIP] - for e in resident_filtered_events - ], + [e.membership for e in resident_filtered_events], ) @@ -651,15 +645,12 @@ def test_out_of_band_invite_rejection(self) -> None: ) ) self.assertEqual( - [e.event_id for e in filtered_events], + [e.event.event_id for e in filtered_events], [e.event_id for e in [invite_event, reject_event]], ) self.assertEqual( ["invite", "leave"], - [ - e.unsigned[EventUnsignedContentFields.MEMBERSHIP] - for e in filtered_events - ], + [e.membership for e in filtered_events], ) # other users should see neither