Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions packages/ai-engine/src/oauth/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
export {
type LoopbackCallback,
type LoopbackCallbackHandle,
type StartLoopbackCallbackServerOptions,
startLoopbackCallbackServer,
} from './loopback-callback-server';
export {
type BuildAuthorizationUrlInput,
buildAuthorizationUrl,
type ExchangeCodeInput,
exchangeCodeForToken,
generateOAuthState,
generatePkcePair,
type PkcePair,
type RefreshTokenInput,
refreshAccessToken,
type TokenExchangeResult,
} from './oauth-client';
134 changes: 134 additions & 0 deletions packages/ai-engine/src/oauth/loopback-callback-server.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import { describe, expect, it } from 'vitest';

import { startLoopbackCallbackServer } from './loopback-callback-server';

describe('startLoopbackCallbackServer', () => {
it('redirectUri は http://127.0.0.1:<port>/callback 形式 (port は OS 採番)', async () => {
const handle = await startLoopbackCallbackServer();
try {
expect(handle.redirectUri).toMatch(/^http:\/\/127\.0\.0\.1:\d+\/callback$/);
} finally {
await handle.close();
}
});

it('callback URL を叩くと awaitCallback が code/state で resolve する', async () => {
const handle = await startLoopbackCallbackServer();
try {
const promise = handle.awaitCallback();
// ブラウザの redirect 相当を fetch で再現
const callbackRes = await fetch(`${handle.redirectUri}?code=AAA&state=xyz`);
expect(callbackRes.status).toBe(200);
const got = await promise;
expect(got).toEqual({ code: 'AAA', state: 'xyz' });
} finally {
await handle.close();
}
});

it('error= 付き callback は awaitCallback を reject + 400 を返す', async () => {
const handle = await startLoopbackCallbackServer();
try {
const promise = handle.awaitCallback();
// unhandled rejection 抑止: rejection を先に観測予約しておかないと、
// Vitest が fetch との race で reject を unhandled として記録する。
promise.catch(() => {});
const res = await fetch(
`${handle.redirectUri}?error=access_denied&error_description=user%20canceled`,
);
expect(res.status).toBe(400);
await expect(promise).rejects.toThrow(/access_denied/);
} finally {
await handle.close();
}
});

it('code/state が無い callback は reject + 400', async () => {
const handle = await startLoopbackCallbackServer();
try {
const promise = handle.awaitCallback();
promise.catch(() => {});
const res = await fetch(`${handle.redirectUri}?code=onlyCode`);
expect(res.status).toBe(400);
await expect(promise).rejects.toThrow(/missing code or state/);
} finally {
await handle.close();
}
});

it('callback path 以外のリクエストは 404 (favicon 等のノイズ対策)', async () => {
const handle = await startLoopbackCallbackServer();
try {
const res = await fetch(`http://127.0.0.1:${new URL(handle.redirectUri).port}/favicon.ico`);
expect(res.status).toBe(404);
} finally {
await handle.close();
}
});

it('timeout で awaitCallback が reject する', async () => {
const handle = await startLoopbackCallbackServer();
try {
const start = Date.now();
await expect(handle.awaitCallback(50)).rejects.toThrow(/timeout/);
// 50ms ちょうどで止まる保証はないが、明らかに長すぎないことだけ確認
expect(Date.now() - start).toBeLessThan(2000);
} finally {
await handle.close();
}
});

it('close は冪等 (二度呼んでも throw しない)', async () => {
const handle = await startLoopbackCallbackServer();
await handle.close();
await expect(handle.close()).resolves.toBeUndefined();
});

it('close() で pending な awaitCallback が reject される (永久 pending リーク防止)', async () => {
const handle = await startLoopbackCallbackServer();
const promise = handle.awaitCallback();
promise.catch(() => {});
await handle.close();
await expect(promise).rejects.toThrow(/server closed/);
});

it('awaitCallback の多重呼び出しは throw (Promise 上書きでリーク防止)', async () => {
const handle = await startLoopbackCallbackServer();
try {
const first = handle.awaitCallback();
first.catch(() => {});
await expect(handle.awaitCallback()).rejects.toThrow(/can only be called once/);
} finally {
await handle.close();
}
});

it('close 後の awaitCallback は throw (callback は永遠に届かないため即時失敗)', async () => {
const handle = await startLoopbackCallbackServer();
await handle.close();
await expect(handle.awaitCallback()).rejects.toThrow(/already closed/);
});

it('preferredPort 指定時はその port で listen (port=0 は OS 採番)', async () => {
// 0 を指定したときと省略時は同じ挙動 (OS 採番)。
const a = await startLoopbackCallbackServer({ preferredPort: 0 });
try {
expect(Number(new URL(a.redirectUri).port)).toBeGreaterThan(0);
} finally {
await a.close();
}
});

it('path カスタマイズで redirect_uri が変わる', async () => {
const handle = await startLoopbackCallbackServer({ path: '/oauth/callback' });
try {
expect(handle.redirectUri).toMatch(/\/oauth\/callback$/);
const promise = handle.awaitCallback();
await fetch(`${handle.redirectUri}?code=A&state=S`);
const got = await promise;
expect(got.code).toBe('A');
} finally {
await handle.close();
}
});
});
169 changes: 169 additions & 0 deletions packages/ai-engine/src/oauth/loopback-callback-server.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// ADR-0011 PR-E2: OAuth callback URL を loopback IP (127.0.0.1) で受ける一時 HTTP server。
//
// 設計判断:
// - port は OS 採番 (0 を渡す)。固定 port にすると複数フローや他プロセスとの衝突が発生する。
// - host は 127.0.0.1 固定。`localhost` だと IPv4/IPv6 どちらに bind するか実装依存で
// redirect_uri と一致しないことがある。
// - state 検証は本モジュールで行わない (orchestrator 側の責務)。受領した code/state を
// そのまま callback API で返す。
// - レスポンスは「タブを閉じてください」の最小 HTML。CSRF / XSS 対策で content type を
// text/plain でも良いが、UX を考えて最小 HTML にする。
// - timeout は呼び出し側で指定 (デフォルト 5 分)。timeout で reject されたあとも close()
// を呼べばリソース解放される。

import { createServer, type Server } from 'node:http';
import type { AddressInfo } from 'node:net';

export interface LoopbackCallback {
code: string;
state: string;
}

export interface LoopbackCallbackHandle {
// ブラウザに渡す redirect_uri (例: http://127.0.0.1:54321/callback)。
redirectUri: string;
// callback の到達を待つ。1 ハンドル 1 回だけ resolve する設計。
awaitCallback(timeoutMs?: number): Promise<LoopbackCallback>;
// server を閉じる。再呼び出しは no-op。
close(): Promise<void>;
}

export interface StartLoopbackCallbackServerOptions {
// callback path (default: '/callback')
path?: string;
// 希望 port (default: 0 = OS 採番)。固定 port が必要な provider 設定の場合のみ指定する。
preferredPort?: number;
}

export async function startLoopbackCallbackServer(
opts: StartLoopbackCallbackServerOptions = {},
): Promise<LoopbackCallbackHandle> {
const callbackPath = opts.path ?? '/callback';
const port = opts.preferredPort ?? 0;

let resolveCallback: ((cb: LoopbackCallback) => void) | null = null;
let rejectCallback: ((err: Error) => void) | null = null;
let timeoutHandle: NodeJS.Timeout | null = null;

const cleanup = () => {
if (timeoutHandle) {
clearTimeout(timeoutHandle);
timeoutHandle = null;
}
resolveCallback = null;
rejectCallback = null;
};

const server: Server = createServer((req, res) => {
// path が違うものは 404。`favicon.ico` 等のノイズ対策。
const url = new URL(req.url ?? '/', 'http://127.0.0.1');
if (url.pathname !== callbackPath) {
res.writeHead(404, { 'Content-Type': 'text/plain; charset=utf-8' });
res.end('Not Found');
return;
}

const code = url.searchParams.get('code');
const state = url.searchParams.get('state');
const error = url.searchParams.get('error');
const errorDescription = url.searchParams.get('error_description');

if (error) {
// Provider が OAuth error を返したケース (access_denied 等)。
res.writeHead(400, { 'Content-Type': 'text/html; charset=utf-8' });
res.end(
`<!doctype html><meta charset="utf-8"><h1>認証エラー</h1><p>${escapeHtml(error)}: ${escapeHtml(errorDescription ?? '')}</p>`,
);
const reject = rejectCallback;
cleanup();
reject?.(new Error(`OAuth callback error: ${error} ${errorDescription ?? ''}`));
return;
}

if (!code || !state) {
res.writeHead(400, { 'Content-Type': 'text/html; charset=utf-8' });
res.end(
'<!doctype html><meta charset="utf-8"><h1>認証エラー</h1><p>code または state が見つかりません。</p>',
);
const reject = rejectCallback;
cleanup();
reject?.(new Error('OAuth callback missing code or state'));
return;
}

// 成功レスポンス。ブラウザに「タブを閉じて Tally に戻ってください」と促す。
res.writeHead(200, { 'Content-Type': 'text/html; charset=utf-8' });
res.end(
'<!doctype html><meta charset="utf-8"><h1>認証完了</h1><p>このタブを閉じて Tally に戻ってください。</p>',
);
const resolve = resolveCallback;
cleanup();
resolve?.({ code, state });
});

await new Promise<void>((resolve, reject) => {
server.once('error', reject);
server.listen(port, '127.0.0.1', () => {
server.removeListener('error', reject);
resolve();
});
});

const addr = server.address() as AddressInfo;
const redirectUri = `http://127.0.0.1:${addr.port}${callbackPath}`;

let closed = false;
// 1 ハンドル 1 回だけ awaitCallback を許す。多重呼び出しは先行 Promise が
// 未解決のまま resolveCallback/rejectCallback を上書きされてリークするため
// 明示的に弾く (CR Major)。close 後の呼び出しも server が閉じている以上
// callback は届かないので即時失敗にする。
let awaitStarted = false;

return {
redirectUri,
async awaitCallback(timeoutMs = 5 * 60 * 1000): Promise<LoopbackCallback> {
if (closed) {
throw new Error('OAuth callback server is already closed');
}
if (awaitStarted) {
throw new Error('awaitCallback can only be called once per handle');
}
awaitStarted = true;
return new Promise<LoopbackCallback>((resolve, reject) => {
resolveCallback = resolve;
rejectCallback = reject;
if (timeoutMs > 0) {
timeoutHandle = setTimeout(() => {
const r = rejectCallback;
cleanup();
r?.(new Error(`OAuth callback timeout after ${timeoutMs}ms`));
}, timeoutMs);
}
});
Comment thread
coderabbitai[bot] marked this conversation as resolved.
},
async close(): Promise<void> {
if (closed) return;
closed = true;
// cleanup() より前に pending な awaitCallback を reject 発火させる
// (cleanup() は rejectCallback を null 化するだけで reject を呼ばない)。
// これを先にしないと、ユーザーキャンセル等で close() された際に
// 既存の awaitCallback Promise が永遠に settle せずリークする (CR Major)。
const reject = rejectCallback;
cleanup();
reject?.(new Error('OAuth callback server closed before callback received'));
await new Promise<void>((resolve, reject) => {
server.close((err) => (err ? reject(err) : resolve()));
});
},
};
}

// 最小 HTML エスケープ (provider が返す error 文言を表示する用)。
function escapeHtml(s: string): string {
return s
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&#39;');
}
Loading