Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions backend/src/controllers/sse.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { z } from 'zod';

const subscribeSchema = z.object({
streams: z.array(z.string()).optional().default([]),
users: z.array(z.string()).optional().default([]),
all: z.boolean().optional().default(false),
});

Expand Down Expand Up @@ -41,30 +42,38 @@ export const subscribe = async (req: Request, res: Response) => {
}

const { publicKey } = (req as AuthenticatedRequest).user;
const { streams, all } = subscribeSchema.parse(req.query);
const { streams, users, all } = subscribeSchema.parse(req.query);

// Scope: only streams where the authenticated user is sender or recipient
const ownedStreams = await prisma.stream.findMany({
where: { OR: [{ sender: publicKey }, { recipient: publicKey }] },
select: { streamId: true },
select: { streamId: true, sender: true, recipient: true },
});
const ownedIds = new Set(ownedStreams.map((s: any) => String(s.streamId)));
const ownedIds = new Set(ownedStreams.map((s) => String(s.streamId)));
const allowedUserKeys = new Set<string>([publicKey]);
for (const stream of ownedStreams) {
allowedUserKeys.add(stream.sender);
allowedUserKeys.add(stream.recipient);
}

let subscriptions: string[];
if (all) {
// "all" still scoped to the user's own streams
subscriptions = [...ownedIds] as string[];
subscriptions = [...ownedIds];
} else if (streams.length > 0) {
// Only allow subscribing to streams the user owns
subscriptions = streams.filter((id) => ownedIds.has(id));
} else {
subscriptions = [...ownedIds] as string[];
subscriptions = [...ownedIds];
}

// Always add user-scoped subscription key
subscriptions.push(`user:${publicKey}`);
const userSubscriptions = new Set<string>([`user:${publicKey}`]);
for (const key of users.filter((k) => allowedUserKeys.has(k))) {
userSubscriptions.add(`user:${key}`);
}
subscriptions.push(...userSubscriptions);

const clientId = `${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
const clientId = `${Date.now()}-${Math.random().toString(36).slice(2, 11)}`;

res.writeHead(200, {
'Content-Type': 'text/event-stream',
Expand All @@ -77,11 +86,11 @@ export const subscribe = async (req: Request, res: Response) => {
res.write(`data: ${JSON.stringify({ type: 'connected', clientId, requestId })}\n\n`);

sseService.addClient(clientId, res, subscriptions, sourceIp);
} catch (error: any) {
if (error.name === 'ZodError') {
} catch (error: unknown) {
if (error instanceof z.ZodError) {
return res.status(400).json({
message: 'Invalid subscription parameters',
errors: error.errors,
errors: error.issues,
});
}
return res.status(500).json({ message: 'Internal server error' });
Expand Down
23 changes: 23 additions & 0 deletions backend/tests/sse.controller.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,28 @@ describe('SSE Controller', () => {
await subscribe(req as Request, res as Response);

expect(res.status).toHaveBeenCalledWith(400);
expect(res.json).toHaveBeenCalledWith(
expect.objectContaining({
message: 'Invalid subscription parameters',
errors: expect.arrayContaining([expect.objectContaining({ code: expect.any(String) })]),
}),
);
});

it('should include allowed users query subscriptions', async () => {
(sseService.isShuttingDown as any).mockReturnValue(false);
(sseService.checkCapacity as any).mockReturnValue({ allowed: true });
(req as any).user = { publicKey: 'GUSER1' };
req.query = { users: ['GCOUNTER', 'GOTHER'] };
(prisma.stream.findMany as any).mockResolvedValue([
{ streamId: 1, sender: 'GUSER1', recipient: 'GCOUNTER' },
]);

await subscribe(req as Request, res as Response);

const subscriptions = (sseService.addClient as any).mock.calls[0][2] as string[];
expect(subscriptions).toContain('user:GUSER1');
expect(subscriptions).toContain('user:GCOUNTER');
expect(subscriptions).not.toContain('user:GOTHER');
});
});
Loading