Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions apps/web/src/lib/ai-gateway/llm-proxy-helpers.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { describe, it, expect, beforeEach } from '@jest/globals';
import { spawnSync } from 'node:child_process';
import { join } from 'node:path';
import type { MicrodollarUsageContext, MicrodollarUsageStats } from './processUsage.types';

// `countAndStoreEditUsage` schedules the usage write through `next/server`'s
Expand Down Expand Up @@ -32,8 +34,125 @@ import {
parseEmbeddingUsageFromResponse,
parseEditUsageFromResponse,
parseTranscriptionUsageFromResponse,
wrapInSafeNextResponse,
} from './llm-proxy-helpers';

// Behavioural tests for `wrapInSafeNextResponse`: byte fidelity, error and
// cancellation propagation, source lock handling, bodyless responses, and a
// garbage-collection scenario that runs in a child process.
describe('wrapInSafeNextResponse', () => {
// Runs the `wrap-safe-next-response.gc.ts` fixture in a separate Node process
// through the tsx CLI. The child gets `--expose-gc` via NODE_OPTIONS
// (presumably because the fixture asserts `global.gc` is available).
it('keeps a fetched response alive until its returned body is consumed', () => {
const tsx = require.resolve('tsx/cli');
const fixture = join(__dirname, 'wrap-safe-next-response.gc.ts');
const result = spawnSync(process.execPath, ['--conditions=react-server', tsx, fixture], {
encoding: 'utf8',
env: {
...process.env,
NODE_ENV: 'test',
// Append to any inherited NODE_OPTIONS rather than replacing them.
NODE_OPTIONS:
`${process.env.NODE_OPTIONS ?? ''} --expose-gc --conditions=react-server`.trim(),
// Placeholder values are used only when the real variables are unset.
UPSTASH_REDIS_REST_TOKEN: process.env.UPSTASH_REDIS_REST_TOKEN ?? 'test',
UPSTASH_REDIS_REST_URL: process.env.UPSTASH_REDIS_REST_URL ?? 'http://127.0.0.1',
},
timeout: 30_000,
});

// Compare every outcome field in one assertion: the child must exit with
// status 0, no signal, no spawn error, and nothing on stdout or stderr.
expect({
status: result.status,
signal: result.signal,
error: result.error?.message,
stdout: result.stdout,
stderr: result.stderr,
}).toEqual({
status: 0,
signal: null,
error: undefined,
stdout: '',
stderr: '',
});
});

// The payload mixes 0x00, 0xFF and CR/LF bytes and is delivered as two chunks.
it('preserves exact bytes without locking the source eagerly', async () => {
const bytes = Uint8Array.from([0, 255, 1, 128, 13, 10]);
const source = new ReadableStream<Uint8Array>({
start(controller) {
controller.enqueue(bytes.subarray(0, 2));
controller.enqueue(bytes.subarray(2));
controller.close();
},
});
const wrapped = wrapInSafeNextResponse(new Response(source));

// Wrapping alone must not take a reader on the source.
expect(source.locked).toBe(false);
expect(new Uint8Array(await wrapped.arrayBuffer())).toEqual(bytes);
// After the body is fully consumed the source lock must be released again.
expect(source.locked).toBe(false);
});

// The source errors on its first pull; the wrapped body must reject with the
// very same error object (`toBe` checks identity).
it('propagates source errors', async () => {
const error = new Error('source failed');
const source = new ReadableStream<Uint8Array>({
pull(controller) {
controller.error(error);
},
});
const wrapped = wrapInSafeNextResponse(new Response(source));

await expect(wrapped.text()).rejects.toBe(error);
expect(source.locked).toBe(false);
});

// Cancelling after one successful read must forward the reason to the
// source's `cancel` callback and leave the source unlocked.
it('propagates cancellation to the source', async () => {
const cancelled = jest.fn();
const reason = new Error('consumer stopped');
const source = new ReadableStream<Uint8Array>({
pull(controller) {
controller.enqueue(Uint8Array.of(1));
},
cancel: cancelled,
});
const wrapped = wrapInSafeNextResponse(new Response(source));
const reader = wrapped.body?.getReader();

await reader?.read();
await reader?.cancel(reason);

expect(cancelled).toHaveBeenCalledWith(reason);
expect(source.locked).toBe(false);
});

// Same as above, but the wrapped body is cancelled without ever being read.
it('propagates cancellation before the first pull', async () => {
const cancelled = jest.fn();
const reason = new Error('consumer stopped early');
const source = new ReadableStream<Uint8Array>({ cancel: cancelled });
const wrapped = wrapInSafeNextResponse(new Response(source));

await wrapped.body?.cancel(reason);

expect(cancelled).toHaveBeenCalledWith(reason);
expect(source.locked).toBe(false);
});

// The source never enqueues or closes, so the first read stays pending until
// the reader is cancelled.
it('does not close an already cancelled stream when a pending pull settles', async () => {
const cancelled = jest.fn();
const reason = new Error('consumer stopped during pull');
const source = new ReadableStream<Uint8Array>({ cancel: cancelled });
const wrapped = wrapInSafeNextResponse(new Response(source));
const reader = wrapped.body?.getReader();
const pending = reader?.read();

// Yield one microtask before cancelling while the read is still outstanding.
await Promise.resolve();
await reader?.cancel(reason);

// The outstanding read must settle as "done" rather than reject.
await expect(pending).resolves.toEqual({ done: true, value: undefined });
expect(cancelled).toHaveBeenCalledWith(reason);
expect(source.locked).toBe(false);
});

// A response without a body keeps its status and a `null` body.
it('preserves bodyless responses', () => {
const wrapped = wrapInSafeNextResponse(new Response(null, { status: 204 }));

expect(wrapped.status).toBe(204);
expect(wrapped.body).toBeNull();
});
});

describe('checkOrganizationModelRestrictions', () => {
describe('enterprise plan - model deny list restrictions', () => {
it('should allow model when it is not in the deny list on enterprise plan', () => {
Expand Down
46 changes: 45 additions & 1 deletion apps/web/src/lib/ai-gateway/llm-proxy-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,51 @@ export function getOutputHeaders(response: Response) {
}

export function wrapInSafeNextResponse(response: Response) {
return new NextResponse(response.body, {
const source = response.body;
let owner: Response | undefined = response;
let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
let cancelled = false;
const release = () => {
if (!owner) return;
const active = reader;
reader = undefined;
owner = undefined;
active?.releaseLock();
};
const body = source
? new ReadableStream<Uint8Array>({
async pull(controller) {
reader ??= owner?.body?.getReader();
const active = reader;
if (!active) return;
try {
const result = await active.read();
if (cancelled) return;
if (result.done) {
controller.close();
Comment thread
marius-kilocode marked this conversation as resolved.
release();
return;
}
controller.enqueue(result.value);
} catch (error) {
if (cancelled) return;
release();
controller.error(error);
}
},
async cancel(reason) {
cancelled = true;
try {
if (reader) await reader.cancel(reason);
else await source.cancel(reason);
} finally {
release();
}
},
})
: null;

return new NextResponse(body, {
status: response.status,
statusText: response.statusText,
headers: getOutputHeaders(response),
Expand Down
45 changes: 45 additions & 0 deletions apps/web/src/lib/ai-gateway/wrap-safe-next-response.gc.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import assert from 'node:assert/strict';
import { createServer } from 'node:http';
import type { AddressInfo } from 'node:net';

/**
 * Garbage-collection fixture for `wrapInSafeNextResponse`.
 *
 * Starts a local HTTP server that streams a three-chunk text body, fetches it,
 * clones the response twice and wraps the original. It then forces ten GC
 * rounds before reading anything, and finally asserts that the wrapped body
 * and both clone bodies still yield the complete payload.
 *
 * Fails immediately when `global.gc` is not defined. The server is always
 * closed, even when an assertion throws.
 */
async function main() {
  assert.ok(global.gc);

  const { wrapInSafeNextResponse } = await import('./llm-proxy-helpers');
  const chunks = ['first', 'second', 'third'];
  const expectedBody = chunks.join('');

  // First chunk goes out immediately, the rest follow after 25 ms and 50 ms.
  const server = createServer((_req, res) => {
    res.writeHead(200, { 'content-type': 'text/plain' });
    res.write(chunks[0]);
    setTimeout(() => res.write(chunks[1]), 25);
    setTimeout(() => res.end(chunks[2]), 50);
  });
  await new Promise<void>(ready => {
    server.listen(0, '127.0.0.1', ready);
  });

  try {
    const bound = server.address() as AddressInfo;

    // The fetched Response and its clones stay local to this call, so `main`
    // itself only holds the wrapped response and the pending clone reads.
    const fetchAndWrap = async () => {
      const upstream = await fetch(`http://127.0.0.1:${bound.port}`);
      const cloneA = upstream.clone();
      const cloneB = upstream.clone();
      return {
        response: wrapInSafeNextResponse(upstream),
        siblings: [cloneA.text(), cloneB.text()],
      };
    };
    const wrapped = await fetchAndWrap();

    // Force a collection, then yield to the event loop, ten times over.
    for (let round = 0; round < 10; round += 1) {
      global.gc();
      await new Promise(resume => {
        setImmediate(resume);
      });
    }

    const [wrappedText, ...cloneTexts] = await Promise.all([
      wrapped.response.text(),
      ...wrapped.siblings,
    ]);
    assert.deepEqual(cloneTexts, [expectedBody, expectedBody]);
    assert.equal(wrappedText, expectedBody);
  } finally {
    await new Promise<void>((resolve, reject) => {
      server.close(closeError => {
        if (closeError) {
          reject(closeError);
        } else {
          resolve();
        }
      });
    });
  }
}

// Entry point: run the fixture. `void` marks the returned promise as intentionally not awaited.
void main();