// Patch overview. Everything ahead of the first `diff --git` header is commentary only.
//
// apps/web/src/lib/ai-gateway/llm-proxy-helpers.ts
//   `wrapInSafeNextResponse` no longer hands `response.body` straight to `NextResponse`.
//   When the response has a body it builds a new `ReadableStream` instead:
//     - `pull` lazily takes a reader from the original body on first use
//       (`reader ??= owner?.body?.getReader()`), performs one `read()` and enqueues that chunk.
//     - The closure holds the original `Response` in `owner` until `release()` runs. That happens
//       when the source reports `done`, when a read rejects, or after `cancel`.
//     - `release()` clears `owner` and `reader` and calls `releaseLock()` on the reader; it is a
//       no-op once `owner` has been cleared.
//     - `cancel` first sets the `cancelled` flag, so a read that settles afterwards neither closes
//       nor errors the controller. It then forwards the reason to the reader when one exists,
//       otherwise to the source stream, and releases in `finally`.
//   A response without a body produces a `null` body.
//   NOTE(review): keeping `owner` reachable presumably stops the fetched Response from being
//   garbage-collected while its body is still being streamed (see the first test's title) —
//   confirm against the runtime's fetch implementation.
//
// apps/web/src/lib/ai-gateway/llm-proxy-helpers.test.ts
//   Adds a `describe('wrapInSafeNextResponse')` suite:
//     - A subprocess test runs the fixture below through `tsx/cli` with `process.execPath`,
//       appending `--expose-gc --conditions=react-server` to NODE_OPTIONS, with a 30 s timeout.
//       It expects exit status 0, no signal, no spawn error, and empty stdout and stderr.
//       NOTE(review): whether `--expose-gc` is accepted inside NODE_OPTIONS depends on the Node
//       version — confirm against the minimum Node version this repo supports.
//     - In-process tests cover exact byte preservation (source unlocked before and after
//       consumption), error propagation, cancellation after a read, cancellation before the
//       first pull, cancellation while a pull is pending, and a bodyless 204 response.
//
// apps/web/src/lib/ai-gateway/wrap-safe-next-response.gc.ts (new file)
//   Fixture script:
//     - Asserts `global.gc` is available and dynamically imports `./llm-proxy-helpers`.
//     - Starts an HTTP server on 127.0.0.1 (port 0) that writes 'first' immediately, 'second'
//       after 25 ms, and ends with 'third' after 50 ms.
//     - Inside an inner async function it fetches the URL, clones the response twice, wraps the
//       original and starts `text()` on both clones, so only the wrapped response and the two
//       pending promises escape that scope.
//     - Calls `global.gc()` ten times, yielding via `setImmediate` between calls.
//     - Asserts the wrapped body and both clones equal the three chunks joined.
//     - Closes the server in `finally`.
//
diff --git a/apps/web/src/lib/ai-gateway/llm-proxy-helpers.test.ts b/apps/web/src/lib/ai-gateway/llm-proxy-helpers.test.ts index 96697d8acc..da9630c7ad 100644 --- a/apps/web/src/lib/ai-gateway/llm-proxy-helpers.test.ts +++ b/apps/web/src/lib/ai-gateway/llm-proxy-helpers.test.ts @@ -1,4 +1,6 @@ import { describe, it, expect, beforeEach } from '@jest/globals'; +import { spawnSync } from 'node:child_process'; +import { join } from 'node:path'; import type { MicrodollarUsageContext, MicrodollarUsageStats } from './processUsage.types'; // `countAndStoreEditUsage` schedules the usage write through `next/server`'s @@ -32,8 +34,125 @@ import { parseEmbeddingUsageFromResponse, parseEditUsageFromResponse, parseTranscriptionUsageFromResponse, + wrapInSafeNextResponse, } from './llm-proxy-helpers'; +describe('wrapInSafeNextResponse', () => { + it('keeps a fetched response alive until its returned body is consumed', () => { + const tsx = require.resolve('tsx/cli'); + const fixture = join(__dirname, 'wrap-safe-next-response.gc.ts'); + const result = spawnSync(process.execPath, ['--conditions=react-server', tsx, fixture], { + encoding: 'utf8', + env: { + ...process.env, + NODE_ENV: 'test', + NODE_OPTIONS: + `${process.env.NODE_OPTIONS ?? ''} --expose-gc --conditions=react-server`.trim(), + UPSTASH_REDIS_REST_TOKEN: process.env.UPSTASH_REDIS_REST_TOKEN ?? 'test', + UPSTASH_REDIS_REST_URL: process.env.UPSTASH_REDIS_REST_URL ?? 
'http://127.0.0.1', + }, + timeout: 30_000, + }); + + expect({ + status: result.status, + signal: result.signal, + error: result.error?.message, + stdout: result.stdout, + stderr: result.stderr, + }).toEqual({ + status: 0, + signal: null, + error: undefined, + stdout: '', + stderr: '', + }); + }); + + it('preserves exact bytes without locking the source eagerly', async () => { + const bytes = Uint8Array.from([0, 255, 1, 128, 13, 10]); + const source = new ReadableStream({ + start(controller) { + controller.enqueue(bytes.subarray(0, 2)); + controller.enqueue(bytes.subarray(2)); + controller.close(); + }, + }); + const wrapped = wrapInSafeNextResponse(new Response(source)); + + expect(source.locked).toBe(false); + expect(new Uint8Array(await wrapped.arrayBuffer())).toEqual(bytes); + expect(source.locked).toBe(false); + }); + + it('propagates source errors', async () => { + const error = new Error('source failed'); + const source = new ReadableStream({ + pull(controller) { + controller.error(error); + }, + }); + const wrapped = wrapInSafeNextResponse(new Response(source)); + + await expect(wrapped.text()).rejects.toBe(error); + expect(source.locked).toBe(false); + }); + + it('propagates cancellation to the source', async () => { + const cancelled = jest.fn(); + const reason = new Error('consumer stopped'); + const source = new ReadableStream({ + pull(controller) { + controller.enqueue(Uint8Array.of(1)); + }, + cancel: cancelled, + }); + const wrapped = wrapInSafeNextResponse(new Response(source)); + const reader = wrapped.body?.getReader(); + + await reader?.read(); + await reader?.cancel(reason); + + expect(cancelled).toHaveBeenCalledWith(reason); + expect(source.locked).toBe(false); + }); + + it('propagates cancellation before the first pull', async () => { + const cancelled = jest.fn(); + const reason = new Error('consumer stopped early'); + const source = new ReadableStream({ cancel: cancelled }); + const wrapped = wrapInSafeNextResponse(new Response(source)); + 
await wrapped.body?.cancel(reason); + + expect(cancelled).toHaveBeenCalledWith(reason); + expect(source.locked).toBe(false); + }); + + it('does not close an already cancelled stream when a pending pull settles', async () => { + const cancelled = jest.fn(); + const reason = new Error('consumer stopped during pull'); + const source = new ReadableStream({ cancel: cancelled }); + const wrapped = wrapInSafeNextResponse(new Response(source)); + const reader = wrapped.body?.getReader(); + const pending = reader?.read(); + + await Promise.resolve(); + await reader?.cancel(reason); + + await expect(pending).resolves.toEqual({ done: true, value: undefined }); + expect(cancelled).toHaveBeenCalledWith(reason); + expect(source.locked).toBe(false); + }); + + it('preserves bodyless responses', () => { + const wrapped = wrapInSafeNextResponse(new Response(null, { status: 204 })); + + expect(wrapped.status).toBe(204); + expect(wrapped.body).toBeNull(); + }); +}); + describe('checkOrganizationModelRestrictions', () => { describe('enterprise plan - model deny list restrictions', () => { it('should allow model when it is not in the deny list on enterprise plan', () => { diff --git a/apps/web/src/lib/ai-gateway/llm-proxy-helpers.ts b/apps/web/src/lib/ai-gateway/llm-proxy-helpers.ts index 14ae12088d..f1c0ea024f 100644 --- a/apps/web/src/lib/ai-gateway/llm-proxy-helpers.ts +++ b/apps/web/src/lib/ai-gateway/llm-proxy-helpers.ts @@ -295,7 +295,51 @@ export function getOutputHeaders(response: Response) { } export function wrapInSafeNextResponse(response: Response) { - return new NextResponse(response.body, { + const source = response.body; + let owner: Response | undefined = response; + let reader: ReadableStreamDefaultReader | undefined; + let cancelled = false; + const release = () => { + if (!owner) return; + const active = reader; + reader = undefined; + owner = undefined; + active?.releaseLock(); + }; + const body = source + ? 
new ReadableStream({ + async pull(controller) { + reader ??= owner?.body?.getReader(); + const active = reader; + if (!active) return; + try { + const result = await active.read(); + if (cancelled) return; + if (result.done) { + controller.close(); + release(); + return; + } + controller.enqueue(result.value); + } catch (error) { + if (cancelled) return; + release(); + controller.error(error); + } + }, + async cancel(reason) { + cancelled = true; + try { + if (reader) await reader.cancel(reason); + else await source.cancel(reason); + } finally { + release(); + } + }, + }) + : null; + + return new NextResponse(body, { status: response.status, statusText: response.statusText, headers: getOutputHeaders(response), diff --git a/apps/web/src/lib/ai-gateway/wrap-safe-next-response.gc.ts b/apps/web/src/lib/ai-gateway/wrap-safe-next-response.gc.ts new file mode 100644 index 0000000000..69622a4c13 --- /dev/null +++ b/apps/web/src/lib/ai-gateway/wrap-safe-next-response.gc.ts @@ -0,0 +1,45 @@ +import assert from 'node:assert/strict'; +import { createServer } from 'node:http'; +import type { AddressInfo } from 'node:net'; + +async function main() { + assert.ok(global.gc); + + const { wrapInSafeNextResponse } = await import('./llm-proxy-helpers'); + const chunks = ['first', 'second', 'third']; + const server = createServer((_request, response) => { + response.writeHead(200, { 'content-type': 'text/plain' }); + response.write(chunks[0]); + setTimeout(() => response.write(chunks[1]), 25); + setTimeout(() => response.end(chunks[2]), 50); + }); + await new Promise(resolve => server.listen(0, '127.0.0.1', resolve)); + + try { + const address = server.address() as AddressInfo; + const wrapped = await (async () => { + const response = await fetch(`http://127.0.0.1:${address.port}`); + const siblingA = response.clone(); + const siblingB = response.clone(); + return { + response: wrapInSafeNextResponse(response), + siblings: [siblingA.text(), siblingB.text()], + }; + })(); + + for (let 
index = 0; index < 10; index++) { + global.gc(); + await new Promise(resolve => setImmediate(resolve)); + } + + const [body, ...siblings] = await Promise.all([wrapped.response.text(), ...wrapped.siblings]); + assert.deepEqual(siblings, [chunks.join(''), chunks.join('')]); + assert.equal(body, chunks.join('')); + } finally { + await new Promise((resolve, reject) => + server.close(error => (error ? reject(error) : resolve())) + ); + } +} + +void main();