Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion genkit-tools/cli/src/commands/mcp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,28 @@
* limitations under the License.
*/

import { findProjectRoot, forceStderr } from '@genkit-ai/tools-common/utils';
import {
debugToFile,
findProjectRoot,
forceStderr,
} from '@genkit-ai/tools-common/utils';
import { Command } from 'commander';
import { startMcpServer } from '../mcp/server';

interface McpOptions {
projectRoot?: string;
debug?: boolean;
}

/** Command to run MCP server. */
export const mcp = new Command('mcp')
.option('--project-root [projectRoot]', 'Project root')
.option('-d, --debug', 'debug to file', false)
.description('run MCP stdio server (EXPERIMENTAL, subject to change)')
.action(async (options: McpOptions) => {
forceStderr();
if (options.debug) {
debugToFile();
}
await startMcpServer(options.projectRoot ?? (await findProjectRoot()));
});
8 changes: 6 additions & 2 deletions genkit-tools/cli/src/mcp/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ export function defineRuntimeTools(
{command: 'go', args: ['run', 'main.go']}
{command: 'npm', args: ['run', 'dev']}`,
inputSchema: {
command: z.string(),
args: z.array(z.string()),
command: z.string().describe('The command to run'),
args: z
.array(z.string())
.describe(
'List of command line arguments. IMPORTANT: This must be a JSON array of strings, not a single string.'
),
},
},
async ({ command, args }) => {
Expand Down
1 change: 1 addition & 0 deletions genkit-tools/cli/src/mcp/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import { defineUsageGuideTool } from './usage';
import { McpRuntimeManager } from './util';

export async function startMcpServer(projectRoot: string) {
logger.info(`Starting MCP server in: ${projectRoot}`);
const server = new McpServer({
name: 'Genkit MCP',
version: '0.0.2',
Expand Down
3 changes: 2 additions & 1 deletion genkit-tools/cli/src/mcp/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ export class McpRuntimeManager {
const devManager = await startDevProcessManager(
this.projectRoot,
command,
args
args,
{ nonInteractive: true }
);
this.manager = devManager.manager;
return this.manager;
Expand Down
56 changes: 55 additions & 1 deletion genkit-tools/cli/src/utils/manager-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ import {
import type { Status } from '@genkit-ai/tools-common';
import {
ProcessManager,
RuntimeEvent,
RuntimeManager,
type GenkitToolsError,
} from '@genkit-ai/tools-common/manager';
import { logger } from '@genkit-ai/tools-common/utils';
import * as crypto from 'crypto';
import getPort, { makeRange } from 'get-port';

/**
Expand Down Expand Up @@ -68,6 +70,7 @@ export async function startManager(

export interface DevProcessManagerOptions {
disableRealtimeTelemetry?: boolean;
nonInteractive?: boolean;
}

export async function startDevProcessManager(
Expand All @@ -78,9 +81,11 @@ export async function startDevProcessManager(
): Promise<{ manager: RuntimeManager; processPromise: Promise<void> }> {
const telemetryServerUrl = await resolveTelemetryServer(projectRoot);
const disableRealtimeTelemetry = options?.disableRealtimeTelemetry ?? false;
const runtimeId = crypto.randomUUID().substring(0, 8);
const envVars: Record<string, string> = {
GENKIT_TELEMETRY_SERVER: telemetryServerUrl,
GENKIT_ENV: 'dev',
GENKIT_RUNTIME_ID: runtimeId,
};
if (!disableRealtimeTelemetry) {
envVars.GENKIT_ENABLE_REALTIME_TELEMETRY = 'true';
Expand All @@ -93,10 +98,59 @@ export async function startDevProcessManager(
processManager,
disableRealtimeTelemetry,
});
const processPromise = processManager.start();
const processPromise = processManager.start({ ...options, cwd: projectRoot });

await waitForRuntime(manager, runtimeId, processPromise);

return { manager, processPromise };
}

/**
* Waits for the runtime with the given ID to register itself.
* Rejects if the process exits or if the timeout is reached.
*/
export async function waitForRuntime(
manager: RuntimeManager,
runtimeId: string,
processPromise: Promise<void>
): Promise<void> {
if (manager.getRuntimeById(runtimeId)) {
return;
}

await new Promise<void>((resolve, reject) => {
let timeoutId: NodeJS.Timeout;
let unsubscribe: () => void;

const cleanup = () => {
if (timeoutId) clearTimeout(timeoutId);
if (unsubscribe) unsubscribe();
};

timeoutId = setTimeout(() => {
cleanup();
reject(new Error('Timeout waiting for runtime to be ready'));
}, 30000);

unsubscribe = manager.onRuntimeEvent((event, runtime) => {
if (event === RuntimeEvent.ADD && runtime.id === runtimeId) {
cleanup();
resolve();
}
});

processPromise
.then(() => {
cleanup();
reject(new Error('Process exited before runtime was ready'));
})
.catch((err) => {
cleanup();
reject(err);
});
});
Comment on lines +121 to +151
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of waitForRuntime correctly handles waiting for the runtime, timeouts, and process exits. However, the manual management of state with the cleanup function can be simplified and made more robust by using Promise.race and a try...finally block for cleanup. This approach is more declarative and less prone to race conditions in the cleanup logic, improving the code's readability and maintainability.

  const TIMEOUT_MS = 30000;
  let unsubscribe: () => void;
  let timeoutId: NodeJS.Timeout;

  try {
    const runtimeAddedPromise = new Promise<void>((resolve) => {
      unsubscribe = manager.onRuntimeEvent((event, runtime) => {
        if (event === RuntimeEvent.ADD && runtime.id === runtimeId) {
          resolve();
        }
      });
    });

    const timeoutPromise = new Promise<void>((_, reject) => {
      timeoutId = setTimeout(
        () => reject(new Error('Timeout waiting for runtime to be ready')),
        TIMEOUT_MS
      );
    });

    const processExitedPromise = processPromise.then(
      () => Promise.reject(new Error('Process exited before runtime was ready')),
      (err) => Promise.reject(err)
    );

    await Promise.race([
      runtimeAddedPromise,
      timeoutPromise,
      processExitedPromise,
    ]);
  } finally {
    if (unsubscribe) unsubscribe();
    if (timeoutId) clearTimeout(timeoutId);
  }

}

/**
* Runs the given function with a runtime manager.
*/
Expand Down
117 changes: 117 additions & 0 deletions genkit-tools/cli/tests/commands/start_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { startServer } from '@genkit-ai/tools-common/server';
import {
afterEach,
beforeEach,
describe,
expect,
it,
jest,
} from '@jest/globals';
import { start } from '../../src/commands/start';
import * as managerUtils from '../../src/utils/manager-utils';

jest.mock('@genkit-ai/tools-common/server');
jest.mock('@genkit-ai/tools-common/utils', () => ({
findProjectRoot: jest.fn(() => Promise.resolve('/mock/root')),
logger: {
warn: jest.fn(),
error: jest.fn(),
},
}));
jest.mock('get-port', () => ({
__esModule: true,
default: jest.fn(() => Promise.resolve(4000)),
makeRange: jest.fn(),
}));
jest.mock('open');

describe('start command', () => {
let startDevProcessManagerSpy: any;
let startManagerSpy: any;
let startServerSpy: any;

beforeEach(() => {
startDevProcessManagerSpy = jest
.spyOn(managerUtils, 'startDevProcessManager')
.mockResolvedValue({
manager: {} as any,
processPromise: Promise.resolve(),
});
startManagerSpy = jest
.spyOn(managerUtils, 'startManager')
.mockResolvedValue({} as any);
startServerSpy = startServer as unknown as jest.Mock;

// Reset args
start.args = [];
});

afterEach(() => {
jest.clearAllMocks();
});

it('should start dev process manager when args are provided', async () => {
await start.parseAsync(['node', 'genkit', 'run', 'app']);

expect(startDevProcessManagerSpy).toHaveBeenCalledWith(
'/mock/root',
'run',
['app'],
expect.objectContaining({ disableRealtimeTelemetry: undefined })
);
expect(startServerSpy).toHaveBeenCalled();
});

it('should start manager only when no args are provided', async () => {
start.parseAsync(['node', 'genkit']);

// Wait a tick for async operations
await new Promise((resolve) => setTimeout(resolve, 10));

expect(startManagerSpy).toHaveBeenCalledWith('/mock/root', true);
expect(startDevProcessManagerSpy).not.toHaveBeenCalled();
expect(startServerSpy).toHaveBeenCalled();
});

it('should not start server if --noui is provided', async () => {
// Cannot await, same reason as above
start.parseAsync(['node', 'genkit', '--noui']);

await new Promise((resolve) => setTimeout(resolve, 10));

expect(startServerSpy).not.toHaveBeenCalled();
});

it('should pass disableRealtimeTelemetry option', async () => {
await start.parseAsync([
'node',
'genkit',
'run',
'app',
'--disable-realtime-telemetry',
]);

expect(startDevProcessManagerSpy).toHaveBeenCalledWith(
expect.anything(),
expect.anything(),
expect.anything(),
expect.objectContaining({ disableRealtimeTelemetry: true })
);
});
});
100 changes: 100 additions & 0 deletions genkit-tools/cli/tests/utils/manager-utils_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { RuntimeEvent } from '@genkit-ai/tools-common/manager';
import { beforeEach, describe, expect, it, jest } from '@jest/globals';
import { waitForRuntime } from '../../src/utils/manager-utils';

describe('waitForRuntime', () => {
let mockManager: any;
let mockProcessPromise: Promise<void>;
let processReject: (reason?: any) => void;

beforeEach(() => {
mockManager = {
getRuntimeById: jest.fn(),
onRuntimeEvent: jest.fn(),
};
mockProcessPromise = new Promise((_, reject) => {
processReject = reject;
});
});

it('should resolve immediately if runtime is already present', async () => {
mockManager.getRuntimeById.mockReturnValue({});
await expect(
waitForRuntime(mockManager, 'test-id', mockProcessPromise)
).resolves.toBeUndefined();
});

it('should wait for runtime event and resolve', async () => {
mockManager.getRuntimeById.mockReturnValue(undefined);
let eventCallback: (event: RuntimeEvent, runtime: any) => void;

mockManager.onRuntimeEvent.mockImplementation((cb: any) => {
eventCallback = cb;
return jest.fn(); // unsubscribe
});

const waitPromise = waitForRuntime(
mockManager,
'test-id',
mockProcessPromise
);

// Simulate event
setTimeout(() => {
eventCallback(RuntimeEvent.ADD, { id: 'test-id' });
}, 10);

await expect(waitPromise).resolves.toBeUndefined();
});

it('should reject if process exits early', async () => {
mockManager.getRuntimeById.mockReturnValue(undefined);
mockManager.onRuntimeEvent.mockReturnValue(jest.fn());

const waitPromise = waitForRuntime(
mockManager,
'test-id',
mockProcessPromise
);

// Simulate process exit
processReject(new Error('Process exited'));

await expect(waitPromise).rejects.toThrow('Process exited');
});

it('should timeout if runtime never appears', async () => {
jest.useFakeTimers();
mockManager.getRuntimeById.mockReturnValue(undefined);
mockManager.onRuntimeEvent.mockReturnValue(jest.fn());

const waitPromise = waitForRuntime(
mockManager,
'test-id',
mockProcessPromise
);

jest.advanceTimersByTime(30000);

await expect(waitPromise).rejects.toThrow(
'Timeout waiting for runtime to be ready'
);
jest.useRealTimers();
});
});
Loading
Loading