From fd7377f634c1985e14f6fad6f42933aafdd3b1f9 Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 21:49:52 -0500 Subject: [PATCH 01/15] feat(ai): expose ThinkingLevel and ThinkingConfig.thinkingLevel --- .github/scripts/compare-types/configs/ai.ts | 10 ---- packages/ai/__tests__/exported-types.test.ts | 15 ++++++ .../ai/__tests__/generative-model.test.ts | 53 ++++++++++++++++++- packages/ai/lib/methods/generate-content.ts | 22 ++++++++ packages/ai/lib/types/enums.ts | 23 ++++++++ packages/ai/lib/types/requests.ts | 16 ++++++ 6 files changed, 128 insertions(+), 11 deletions(-) diff --git a/.github/scripts/compare-types/configs/ai.ts b/.github/scripts/compare-types/configs/ai.ts index 32fc0b1ae4..cff0b8dd88 100644 --- a/.github/scripts/compare-types/configs/ai.ts +++ b/.github/scripts/compare-types/configs/ai.ts @@ -151,11 +151,6 @@ const config: PackageConfig = { reason: 'Template tool unions are part of firebase-js-sdk template tooling that RN Firebase does not currently expose.', }, - { - name: 'ThinkingLevel', - reason: - 'RN Firebase supports thinking budgets but does not currently expose the JS SDK `ThinkingLevel` preset constants/type.', - }, ], extraInRN: [ { @@ -285,11 +280,6 @@ const config: PackageConfig = { reason: 'RN Firebase template Imagen model methods do not currently accept per-call `SingleRequestOptions`, so request overrides are limited to model-level `RequestOptions`.', }, - { - name: 'ThinkingConfig', - reason: - 'RN Firebase thinking config supports `thinkingBudget` and `includeThoughts`, but does not currently expose the JS SDK `thinkingLevel` preset field.', - }, { name: 'TypedSchema', reason: diff --git a/packages/ai/__tests__/exported-types.test.ts b/packages/ai/__tests__/exported-types.test.ts index 51f7381bdc..c974ffc79c 100644 --- a/packages/ai/__tests__/exported-types.test.ts +++ b/packages/ai/__tests__/exported-types.test.ts @@ -80,6 +80,8 @@ import { Segment, StartChatParams, TextPart, + ThinkingConfig, + ThinkingLevel, ToolConfig, URLContext, URLContextMetadata, @@ -116,6 +118,14 @@ describe('AI', function () { expect(POSSIBLE_ROLES).toBeDefined(); }); + it('`ThinkingLevel` constant is properly exposed to end user', function () { + expect(ThinkingLevel).toBeDefined(); + expect(ThinkingLevel.MINIMAL).toBe('MINIMAL'); + expect(ThinkingLevel.LOW).toBe('LOW'); + expect(ThinkingLevel.MEDIUM).toBe('MEDIUM'); + expect(ThinkingLevel.HIGH).toBe('HIGH'); + }); + it('`AIError` class is properly exposed to end user', function () { expect(AIError).toBeDefined(); }); @@ -169,6 +179,11 @@ describe('AI', function () { expect(typeof _typeCheck).toBeDefined(); }); + it('`ThinkingConfig` type is properly exposed to end user', function () { + const _typeCheck: ThinkingConfig = { thinkingLevel: ThinkingLevel.LOW }; + expect(typeof _typeCheck).toBe('object'); + }); + it('`TypedSchema` type is properly exposed to end user', function () { const _typeCheck: TypedSchema = {} as TypedSchema; expect(typeof _typeCheck).toBeDefined(); diff --git a/packages/ai/__tests__/generative-model.test.ts b/packages/ai/__tests__/generative-model.test.ts index 28946fff9f..9754657f51 100644 --- a/packages/ai/__tests__/generative-model.test.ts +++ b/packages/ai/__tests__/generative-model.test.ts @@ -17,7 +17,7 @@ import { describe, expect, it, jest } from '@jest/globals'; import { type ReactNativeFirebase } from '@react-native-firebase/app'; import { GenerativeModel } from '../lib/models/generative-model'; -import { AI, FunctionCallingMode } from '../lib/public-types'; +import { AI, FunctionCallingMode, ThinkingLevel } from '../lib/public-types'; import * as request from '../lib/requests/request'; import { BackendName, getMockResponse } from './test-utils/mock-response'; import { VertexAIBackend } from '../lib/backend'; @@ -124,6 +124,57 @@ describe('GenerativeModel', () => { makeRequestStub.mockRestore(); }); + it('passes thinkingLevel through to generateContent', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + generationConfig: { + thinkingConfig: { + thinkingLevel: ThinkingLevel.LOW, + }, + }, + }); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-basic-reply-short.json', + ); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + + await genModel.generateContent('hello'); + + expect(makeRequestStub).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: expect.anything(), + stream: false, + requestOptions: {}, + }), + expect.stringContaining(`"thinkingLevel":"${ThinkingLevel.LOW}"`), + ); + makeRequestStub.mockRestore(); + }); + + it('throws when thinkingBudget and thinkingLevel are both set', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + generationConfig: { + thinkingConfig: { + thinkingBudget: 100, + thinkingLevel: ThinkingLevel.HIGH, + }, + }, + }); + const makeRequestStub = jest.spyOn(request, 'makeRequest'); + + await expect(genModel.generateContent('hello')).rejects.toThrow( + 'Cannot set both thinkingBudget and thinkingLevel in a config.', + ); + expect(makeRequestStub).not.toHaveBeenCalled(); + makeRequestStub.mockRestore(); + }); + it('passes text-only systemInstruction through to generateContent', async () => { const genModel = new GenerativeModel(fakeAI, { model: 'my-model', diff --git a/packages/ai/lib/methods/generate-content.ts b/packages/ai/lib/methods/generate-content.ts index 9d1f5c8a2f..4139e8c423 100644 --- a/packages/ai/lib/methods/generate-content.ts +++ b/packages/ai/lib/methods/generate-content.ts @@ -16,10 +16,12 @@ */ import { + AIErrorCode, GenerateContentRequest, GenerateContentResponse, GenerateContentResult, GenerateContentStreamResult, + GenerationConfig, RequestOptions, } from '../types'; import { Task, makeRequest, ServerPromptTemplateTask } from '../requests/request'; @@ -28,6 +30,24 @@ import { processStream } from '../requests/stream-reader'; import { ApiSettings } from '../types/internal'; import { BackendType } from '../public-types'; import * as GoogleAIMapper from '../googleai-mappers'; +import { AIError } from '../errors'; + +/** + * Client-side validation of common `GenerationConfig` pitfalls, in order + * to save the developer a wasted request. + */ +function validateGenerationConfig(generationConfig?: GenerationConfig): void { + if ( + // != allows for null and undefined. 0 is considered "set" by the model. + generationConfig?.thinkingConfig?.thinkingBudget != null && + generationConfig.thinkingConfig?.thinkingLevel + ) { + throw new AIError( + AIErrorCode.UNSUPPORTED, + 'Cannot set both thinkingBudget and thinkingLevel in a config.', + ); + } +} /** * Generates a content stream from a request body. @@ -44,6 +64,7 @@ export async function generateContentStream( params: GenerateContentRequest, requestOptions?: RequestOptions, ): Promise { + validateGenerationConfig(params.generationConfig); if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { params = GoogleAIMapper.mapGenerateContentRequest(params); } @@ -76,6 +97,7 @@ export async function generateContent( params: GenerateContentRequest, requestOptions?: RequestOptions, ): Promise { + validateGenerationConfig(params.generationConfig); if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { params = GoogleAIMapper.mapGenerateContentRequest(params); } diff --git a/packages/ai/lib/types/enums.ts b/packages/ai/lib/types/enums.ts index 92d41925d1..0175620f5f 100644 --- a/packages/ai/lib/types/enums.ts +++ b/packages/ai/lib/types/enums.ts @@ -340,3 +340,26 @@ export const Language = { * @beta */ export type Language = (typeof Language)[keyof typeof Language]; + +/** + * A preset that controls the model's "thinking" process. Use + * `ThinkingLevel.LOW` for faster responses on less complex tasks, and + * `ThinkingLevel.HIGH` for better reasoning on more complex tasks. + * + * @public + */ +export const ThinkingLevel = { + MINIMAL: 'MINIMAL', + LOW: 'LOW', + MEDIUM: 'MEDIUM', + HIGH: 'HIGH', +}; + +/** + * A preset that controls the model's "thinking" process. Use + * `ThinkingLevel.LOW` for faster responses on less complex tasks, and + * `ThinkingLevel.HIGH` for better reasoning on more complex tasks. + * + * @public + */ +export type ThinkingLevel = (typeof ThinkingLevel)[keyof typeof ThinkingLevel]; diff --git a/packages/ai/lib/types/requests.ts b/packages/ai/lib/types/requests.ts index f720bcb343..41b7c625e9 100644 --- a/packages/ai/lib/types/requests.ts +++ b/packages/ai/lib/types/requests.ts @@ -23,6 +23,7 @@ import { HarmBlockThreshold, HarmCategory, ResponseModality, + ThinkingLevel, } from './enums'; import { ObjectSchemaInterface, SchemaRequest } from './schema'; @@ -404,11 +405,26 @@ export interface ThinkingConfig { * If you don't specify a budget, the model will determine the appropriate amount * of thinking based on the complexity of the prompt. * + * The model will also error if `thinkingLevel` and `thinkingBudget` are + * both set. + * * An error will be thrown if you set a thinking budget for a model that does not support this * feature or if the specified budget is not within the model's supported range. */ thinkingBudget?: number; + /** + * If not specified, Gemini will use the model's default dynamic thinking level. + * + * @remarks + * Note: The model will error if `thinkingLevel` and `thinkingBudget` are + * both set. + * + * Important: Gemini 2.5 series models do not support thinking levels; use + * `thinkingBudget` to set a thinking budget instead. + */ + thinkingLevel?: ThinkingLevel; + /** * Whether to include "thought summaries" in the model's response. * From 62299d6575220e6022dc6446e4637d6a878a19d8 Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 22:06:10 -0500 Subject: [PATCH 02/15] feat(ai): add AnyOfSchema support --- .github/scripts/compare-types/configs/ai.ts | 30 -------- packages/ai/__tests__/exported-types.test.ts | 5 ++ packages/ai/__tests__/schema-builder.test.ts | 76 +++++++++++++++++++- packages/ai/lib/requests/schema-builder.ts | 48 +++++++++++-- packages/ai/lib/types/schema.ts | 18 +++-- 5 files changed, 134 insertions(+), 43 deletions(-) diff --git a/.github/scripts/compare-types/configs/ai.ts b/.github/scripts/compare-types/configs/ai.ts index cff0b8dd88..1f97c9c860 100644 --- a/.github/scripts/compare-types/configs/ai.ts +++ b/.github/scripts/compare-types/configs/ai.ts @@ -101,11 +101,6 @@ const config: PackageConfig = { reason: 'Chrome Prompt API prompt options type used by browser-only on-device language model integration.', }, - { - name: 'AnyOfSchema', - reason: - 'RN Firebase schema-builder does not currently expose the `anyOf` helper class, so union-schema composition is not part of the public RN AI API.', - }, { name: 'LiveServerGoingAwayNotice', reason: @@ -250,26 +245,6 @@ const config: PackageConfig = { reason: 'RN Firebase does not currently expose `maxSequentalFunctionCalls`, so its request options are limited to timeout and base URL.', }, - { - name: 'Schema', - reason: - 'RN Firebase schema-builder requires an explicit `type` and does not expose the JS SDK `anyOf` helper, so the public schema shape differs.', - }, - { - name: 'SchemaInterface', - reason: - 'RN Firebase schema interfaces require an explicit `type`, whereas the JS SDK declaration leaves `type` optional in the base interface.', - }, - { - name: 'SchemaRequest', - reason: - 'RN Firebase request-shaped schemas require an explicit `type`, whereas the JS SDK declaration leaves `type` optional.', - }, - { - name: 'SchemaShared', - reason: - 'RN Firebase shared schema typing omits the JS SDK `anyOf` property because `AnyOfSchema` is not currently part of the public RN API.', - }, { name: 'TemplateGenerativeModel', reason: @@ -280,11 +255,6 @@ const config: PackageConfig = { reason: 'RN Firebase template Imagen model methods do not currently accept per-call `SingleRequestOptions`, so request overrides are limited to model-level `RequestOptions`.', }, - { - name: 'TypedSchema', - reason: - 'RN Firebase typed schema unions do not currently include `AnyOfSchema`, so the exported union is smaller than the JS SDK version.', - }, { name: 'UsageMetadata', reason: diff --git a/packages/ai/__tests__/exported-types.test.ts b/packages/ai/__tests__/exported-types.test.ts index c974ffc79c..8f0c09e77b 100644 --- a/packages/ai/__tests__/exported-types.test.ts +++ b/packages/ai/__tests__/exported-types.test.ts @@ -18,6 +18,7 @@ import { describe, expect, it } from '@jest/globals'; import { // Runtime values (classes, functions, constants) + AnyOfSchema, BackendType, POSSIBLE_ROLES, AIError, @@ -118,6 +119,10 @@ describe('AI', function () { expect(POSSIBLE_ROLES).toBeDefined(); }); + it('`AnyOfSchema` class is properly exposed to end user', function () { + expect(AnyOfSchema).toBeDefined(); + }); + it('`ThinkingLevel` constant is properly exposed to end user', function () { expect(ThinkingLevel).toBeDefined(); expect(ThinkingLevel.MINIMAL).toBe('MINIMAL'); diff --git a/packages/ai/__tests__/schema-builder.test.ts b/packages/ai/__tests__/schema-builder.test.ts index 738bd17a21..52b4de7f25 100644 --- a/packages/ai/__tests__/schema-builder.test.ts +++ b/packages/ai/__tests__/schema-builder.test.ts @@ -15,7 +15,7 @@ * limitations under the License. */ import { describe, expect, it } from '@jest/globals'; -import { Schema } from '../lib/requests/schema-builder'; +import { AnyOfSchema, Schema } from '../lib/requests/schema-builder'; import { AIErrorCode } from '../lib/types'; describe('Schema builder', () => { @@ -129,6 +129,80 @@ describe('Schema builder', () => { }); }); + it('builds an anyOf schema', () => { + const schema = Schema.anyOf({ + description: 'string or object', + anyOf: [ + Schema.string(), + Schema.object({ + properties: { + count: Schema.integer(), + }, + }), + ], + }); + + expect(schema).toBeInstanceOf(AnyOfSchema); + expect(schema.toJSON()).toEqual({ + type: undefined, + description: 'string or object', + anyOf: [ + { + type: 'string', + nullable: false, + }, + { + type: 'object', + nullable: false, + properties: { + count: { + type: 'integer', + nullable: false, + }, + }, + required: ['count'], + }, + ], + nullable: false, + }); + }); + + it('serializes anyOf schemas nested inside object schemas', () => { + const schema = Schema.object({ + properties: { + value: Schema.anyOf({ + anyOf: [Schema.string(), Schema.number()], + }), + }, + }); + + expect(schema.toJSON()).toEqual({ + type: 'object', + nullable: false, + properties: { + value: { + type: undefined, + anyOf: [ + { + type: 'string', + nullable: false, + }, + { + type: 'number', + nullable: false, + }, + ], + nullable: false, + }, + }, + required: ['value'], + }); + }); + + it('throws if anyOf is empty', () => { + expect(() => Schema.anyOf({ anyOf: [] })).toThrow(AIErrorCode.INVALID_SCHEMA); + }); + it('builds layered schema - partially filled out', () => { const schema = Schema.array({ items: Schema.object({ diff --git a/packages/ai/lib/requests/schema-builder.ts b/packages/ai/lib/requests/schema-builder.ts index 21c5605cb7..6bf9676820 100644 --- a/packages/ai/lib/requests/schema-builder.ts +++ b/packages/ai/lib/requests/schema-builder.ts @@ -34,10 +34,11 @@ import { */ export abstract class Schema implements SchemaInterface { /** - * Optional. The type of the property. {@link - * SchemaType}. + * Optional. The type of the property. + * This can only be undefined when using `anyOf` schemas, which do not have an + * explicit type in the {@link https://swagger.io/docs/specification/v3_0/data-models/data-types/#any-type | OpenAPI specification}. */ - type: SchemaType; + type?: SchemaType; /** Optional. The format of the property. * Supported formats:
*
    @@ -69,7 +70,6 @@ export abstract class Schema implements SchemaInterface { for (const paramKey in schemaParams) { this[paramKey] = schemaParams[paramKey]; } - // Ensure these are explicitly set to avoid TS errors. this.type = schemaParams.type; this.nullable = schemaParams.hasOwnProperty('nullable') ? !!schemaParams.nullable : false; } @@ -80,7 +80,7 @@ export abstract class Schema implements SchemaInterface { * @internal */ toJSON(): SchemaRequest { - const obj: { type: SchemaType; [key: string]: unknown } = { + const obj: { type?: SchemaType; [key: string]: unknown } = { type: this.type, }; for (const prop in this) { @@ -127,6 +127,10 @@ export abstract class Schema implements SchemaInterface { static boolean(booleanParams?: SchemaParams): BooleanSchema { return new BooleanSchema(booleanParams); } + + static anyOf(anyOfParams: SchemaParams & { anyOf: TypedSchema[] }): AnyOfSchema { + return new AnyOfSchema(anyOfParams); + } } /** @@ -139,7 +143,8 @@ export type TypedSchema = | StringSchema | BooleanSchema | ObjectSchema - | ArraySchema; + | ArraySchema + | AnyOfSchema; /** * Schema class for "integer" types. @@ -285,3 +290,34 @@ export class ObjectSchema extends Schema { return obj as SchemaRequest; } } + +/** + * Schema class representing a value that can conform to any of the provided sub-schemas. This is + * useful when a field can accept multiple distinct types or structures. + * @public + */ +export class AnyOfSchema extends Schema { + anyOf: TypedSchema[]; + + constructor(schemaParams: SchemaParams & { anyOf: TypedSchema[] }) { + if (schemaParams.anyOf.length === 0) { + throw new AIError(AIErrorCode.INVALID_SCHEMA, "The 'anyOf' array must not be empty."); + } + super({ + ...schemaParams, + type: undefined, // anyOf schemas do not have an explicit type + }); + this.anyOf = schemaParams.anyOf; + } + + /** + * @internal + */ + toJSON(): SchemaRequest { + const obj = super.toJSON(); + if (this.anyOf && Array.isArray(this.anyOf)) { + obj.anyOf = this.anyOf.map(s => s.toJSON()); + } + return obj; + } +} diff --git a/packages/ai/lib/types/schema.ts b/packages/ai/lib/types/schema.ts index 378134df16..15e4b1e527 100644 --- a/packages/ai/lib/types/schema.ts +++ b/packages/ai/lib/types/schema.ts @@ -42,6 +42,12 @@ export enum SchemaType { * @public */ export interface SchemaShared { + /** + * An array of {@link Schema}. The generated data must be valid against any of the schemas + * listed in this array. This allows specifying multiple possible structures or types for a + * single field. + */ + anyOf?: T[]; /** Optional. The format of the property. * When using the Gemini Developer API ({@link GoogleAIBackend}), this must be either `'enum'` or * `'date-time'`, otherwise requests will fail. @@ -93,10 +99,10 @@ export interface SchemaParams extends SchemaShared {} */ export interface SchemaRequest extends SchemaShared { /** - * The type of the property. {@link - * SchemaType}. + * The type of the property. this can only be undefined when using `anyOf` schemas, + * which do not have an explicit type in the {@link https://swagger.io/docs/specification/v3_0/data-models/data-types/#any-type | OpenAPI specification }. */ - type: SchemaType; + type?: SchemaType; /** Optional. Array of required property. */ required?: string[]; } @@ -107,10 +113,10 @@ export interface SchemaRequest extends SchemaShared { */ export interface SchemaInterface extends SchemaShared { /** - * The type of the property. {@link - * SchemaType}. + * The type of the property. this can only be undefined when using `anyof` schemas, + * which do not have an explicit type in the {@link https://swagger.io/docs/specification/v3_0/data-models/data-types/#any-type | OpenAPI Specification}. */ - type: SchemaType; + type?: SchemaType; } /** From 7b8750050c8f93b0e13be53c2f4ddac975ac2e32 Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 22:20:17 -0500 Subject: [PATCH 03/15] feat(ai): expose LiveServerGoingAwayNotice --- .github/scripts/compare-types/configs/ai.ts | 15 --------- packages/ai/__tests__/exported-types.test.ts | 14 ++++++++ packages/ai/__tests__/live-session.test.ts | 34 +++++++++++++++++++- packages/ai/lib/methods/live-session.ts | 22 ++++++++++++- packages/ai/lib/types/responses.ts | 14 ++++++++ 5 files changed, 82 insertions(+), 17 deletions(-) diff --git a/.github/scripts/compare-types/configs/ai.ts b/.github/scripts/compare-types/configs/ai.ts index 1f97c9c860..5e044c3831 100644 --- a/.github/scripts/compare-types/configs/ai.ts +++ b/.github/scripts/compare-types/configs/ai.ts @@ -101,11 +101,6 @@ const config: PackageConfig = { reason: 'Chrome Prompt API prompt options type used by browser-only on-device language model integration.', }, - { - name: 'LiveServerGoingAwayNotice', - reason: - 'RN Firebase live sessions do not currently surface the server `goingAwayNotice` message type in the public API.', - }, { name: 'ObjectSchemaRequest', reason: @@ -230,16 +225,6 @@ const config: PackageConfig = { reason: 'RN Firebase Imagen model requests do not currently accept per-call `SingleRequestOptions`, so request overrides are limited to model-level `RequestOptions`.', }, - { - name: 'LiveResponseType', - reason: - 'RN Firebase live response typing omits `GOING_AWAY_NOTICE` because `LiveServerGoingAwayNotice` is not currently surfaced in the public API.', - }, - { - name: 'LiveSession', - reason: - 'RN Firebase live sessions do not currently expose `LiveServerGoingAwayNotice` from `receive()`, so the response union is smaller than the JS SDK.', - }, { name: 'RequestOptions', reason: diff --git a/packages/ai/__tests__/exported-types.test.ts b/packages/ai/__tests__/exported-types.test.ts index 8f0c09e77b..97d73ed8d1 100644 --- a/packages/ai/__tests__/exported-types.test.ts +++ b/packages/ai/__tests__/exported-types.test.ts @@ -65,6 +65,7 @@ import { GroundingAttribution, GroundingMetadata, InlineDataPart, + LiveServerGoingAwayNotice, ModalityTokenCount, ModelParams, ObjectSchemaInterface, @@ -102,6 +103,7 @@ import { HarmCategory, HarmProbability, HarmSeverity, + LiveResponseType, Modality, SchemaType, } from '../lib'; @@ -527,5 +529,17 @@ describe('AI', function () { 'URL_RETRIEVAL_STATUS_UNSPECIFIED', ); }); + + it('`LiveResponseType.GOING_AWAY_NOTICE` constant is properly exposed to end user', function () { + expect(LiveResponseType.GOING_AWAY_NOTICE).toBe('goingAwayNotice'); + }); + + it('`LiveServerGoingAwayNotice` type is properly exposed to end user', function () { + const _typeCheck: LiveServerGoingAwayNotice = { + type: 'goingAwayNotice', + timeLeft: 10, + }; + expect(typeof _typeCheck).toBe('object'); + }); }); }); diff --git a/packages/ai/__tests__/live-session.test.ts b/packages/ai/__tests__/live-session.test.ts index 1a638e38a4..2ac7d97856 100644 --- a/packages/ai/__tests__/live-session.test.ts +++ b/packages/ai/__tests__/live-session.test.ts @@ -20,6 +20,7 @@ import { FunctionResponse, LiveResponseType, LiveServerContent, + LiveServerGoingAwayNotice, LiveServerToolCall, LiveServerToolCallCancellation, } from '../lib/types'; @@ -232,6 +233,9 @@ describe('LiveSession', function () { mockHandler.simulateServerMessage({ toolCallCancellation: { functionIds: ['123'] }, }); + mockHandler.simulateServerMessage({ + goAway: { timeLeft: '10.500s' }, + }); mockHandler.simulateServerMessage({ serverContent: { turnComplete: true }, }); @@ -239,7 +243,7 @@ describe('LiveSession', function () { mockHandler.endStream(); const responses = await receivePromise; - expect(responses).toHaveLength(4); + expect(responses).toHaveLength(5); expect(responses[0]).toEqual({ type: LiveResponseType.SERVER_CONTENT, modelTurn: { parts: [{ text: 'response 1' }] }, @@ -252,6 +256,34 @@ describe('LiveSession', function () { type: LiveResponseType.TOOL_CALL_CANCELLATION, functionIds: ['123'], } as LiveServerToolCallCancellation); + expect(responses[3]).toEqual({ + type: LiveResponseType.GOING_AWAY_NOTICE, + timeLeft: 10.5, + } as LiveServerGoingAwayNotice); + }); + + it('should default malformed goAway timeLeft values to zero', async function () { + const receivePromise = (async () => { + const responses = []; + for await (const response of session.receive()) { + responses.push(response); + } + return responses; + })(); + + mockHandler.simulateServerMessage({ + goAway: { timeLeft: 'invalid' }, + }); + await new Promise(r => setTimeout(() => r(), 10)); // Wait for the listener to process messages + mockHandler.endStream(); + + const responses = await receivePromise; + expect(responses).toEqual([ + { + type: LiveResponseType.GOING_AWAY_NOTICE, + timeLeft: 0, + }, + ]); }); it('should log a warning and skip messages that are not objects', async function () { diff --git a/packages/ai/lib/methods/live-session.ts b/packages/ai/lib/methods/live-session.ts index f3cd2d099d..2e766b7c2e 100644 --- a/packages/ai/lib/methods/live-session.ts +++ b/packages/ai/lib/methods/live-session.ts @@ -21,6 +21,7 @@ import { GenerativeContentBlob, LiveResponseType, LiveServerContent, + LiveServerGoingAwayNotice, LiveServerToolCall, LiveServerToolCallCancellation, Part, @@ -223,7 +224,10 @@ export class LiveSession { * @beta */ async *receive(): AsyncGenerator< - LiveServerContent | LiveServerToolCall | LiveServerToolCallCancellation + | LiveServerContent + | LiveServerToolCall + | LiveServerToolCallCancellation + | LiveServerGoingAwayNotice > { if (this.isClosed) { throw new AIError( @@ -252,6 +256,12 @@ export class LiveSession { } ).toolCallCancellation, } as LiveServerToolCallCancellation; + } else if ('goAway' in message) { + const notice = (message as { goAway: { timeLeft?: string } }).goAway; + yield { + type: LiveResponseType.GOING_AWAY_NOTICE, + timeLeft: parseDuration(notice.timeLeft), + } as LiveServerGoingAwayNotice; } else { logger.warn( `Received an unknown message type from the server: ${JSON.stringify(message)}`, @@ -342,3 +352,13 @@ export class LiveSession { } } } + +/** + * Parses a duration string (e.g. "3.000000001s") into a number of seconds. + */ +function parseDuration(duration?: string): number { + if (!duration || !duration.endsWith('s')) { + return 0; + } + return Number(duration.slice(0, -1)); +} diff --git a/packages/ai/lib/types/responses.ts b/packages/ai/lib/types/responses.ts index 378eb5dd8e..2c88b78d5e 100644 --- a/packages/ai/lib/types/responses.ts +++ b/packages/ai/lib/types/responses.ts @@ -591,6 +591,19 @@ export interface LiveServerToolCallCancellation { functionIds: string[]; } +/** + * Notification that the server will not be able to service the client soon. + * + * @beta + */ +export interface LiveServerGoingAwayNotice { + type: 'goingAwayNotice'; + /** + * The remaining time (in seconds) before the connection will be terminated. + */ + timeLeft: number; +} + /** * The types of responses that can be returned by {@link LiveSession.receive}. * @@ -600,6 +613,7 @@ export const LiveResponseType = { SERVER_CONTENT: 'serverContent', TOOL_CALL: 'toolCall', TOOL_CALL_CANCELLATION: 'toolCallCancellation', + GOING_AWAY_NOTICE: 'goingAwayNotice', }; /** From 86b41773012ac25952ea259a47a0817081bfd0f2 Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 22:26:14 -0500 Subject: [PATCH 04/15] feat(ai): expose ObjectSchemaRequest --- .github/scripts/compare-types/configs/ai.ts | 5 ---- packages/ai/__tests__/exported-types.test.ts | 26 ++++++++++++++++++++ packages/ai/lib/types/requests.ts | 4 +-- packages/ai/lib/types/schema.ts | 17 +++++++++++++ 4 files changed, 45 insertions(+), 7 deletions(-) diff --git a/.github/scripts/compare-types/configs/ai.ts b/.github/scripts/compare-types/configs/ai.ts index 5e044c3831..af94214218 100644 --- a/.github/scripts/compare-types/configs/ai.ts +++ b/.github/scripts/compare-types/configs/ai.ts @@ -101,11 +101,6 @@ const config: PackageConfig = { reason: 'Chrome Prompt API prompt options type used by browser-only on-device language model integration.', }, - { - name: 'ObjectSchemaRequest', - reason: - 'RN Firebase exposes `ObjectSchemaInterface` for schema helper typing, but does not separately export the raw request-shape `ObjectSchemaRequest` type.', - }, { name: 'SingleRequestOptions', reason: diff --git a/packages/ai/__tests__/exported-types.test.ts b/packages/ai/__tests__/exported-types.test.ts index 97d73ed8d1..6f14dad581 100644 --- a/packages/ai/__tests__/exported-types.test.ts +++ b/packages/ai/__tests__/exported-types.test.ts @@ -69,6 +69,7 @@ import { ModalityTokenCount, ModelParams, ObjectSchemaInterface, + ObjectSchemaRequest, PromptFeedback, RequestOptions, RetrievedContextAttribution, @@ -361,6 +362,31 @@ describe('AI', function () { expect(typeof _typeCheck).toBeDefined(); }); + it('`ObjectSchemaRequest` type is properly exposed to end user', function () { + const _typeCheck: ObjectSchemaRequest = { + type: 'object', + properties: {}, + }; + expect(typeof _typeCheck).toBe('object'); + }); + + it('`FunctionDeclaration.parameters` accepts ObjectSchemaRequest', function () { + const _typeCheck: FunctionDeclaration = { + name: 'getWeather', + description: 'Gets weather for a city.', + parameters: { + type: 'object', + properties: { + city: { + type: SchemaType.STRING, + }, + }, + required: ['city'], + }, + }; + expect(typeof _typeCheck).toBe('object'); + }); + it('`PromptFeedback` type is properly exposed to end user', function () { const _typeCheck: PromptFeedback = {} as PromptFeedback; expect(typeof _typeCheck).toBeDefined(); diff --git a/packages/ai/lib/types/requests.ts b/packages/ai/lib/types/requests.ts index 41b7c625e9..261ca52eb0 100644 --- a/packages/ai/lib/types/requests.ts +++ b/packages/ai/lib/types/requests.ts @@ -25,7 +25,7 @@ import { ResponseModality, ThinkingLevel, } from './enums'; -import { ObjectSchemaInterface, SchemaRequest } from './schema'; +import { ObjectSchemaInterface, ObjectSchemaRequest, SchemaRequest } from './schema'; /** * Base parameters for a number of methods. @@ -279,7 +279,7 @@ export interface FunctionDeclaration { * format. Reflects the Open API 3.03 Parameter Object. Parameter names are * case-sensitive. For a function with no parameters, this can be left unset. */ - parameters?: ObjectSchemaInterface; + parameters?: ObjectSchemaInterface | ObjectSchemaRequest; } /** diff --git a/packages/ai/lib/types/schema.ts b/packages/ai/lib/types/schema.ts index 15e4b1e527..0fd37204f4 100644 --- a/packages/ai/lib/types/schema.ts +++ b/packages/ai/lib/types/schema.ts @@ -107,6 +107,23 @@ export interface SchemaRequest extends SchemaShared { required?: string[]; } +/** + * Interface for JSON parameters in a schema of {@link SchemaType.OBJECT} + * when not using the `Schema.object()` helper. + * @public + */ +export interface ObjectSchemaRequest extends Omit { + type: 'object'; + /** + * This is not a property accepted in the final request to the backend, but is + * a client-side convenience property that is only usable by constructing + * a schema through the `Schema.object()` helper method. Populating this + * property will cause response errors if the object is not wrapped with + * `Schema.object()`. + */ + optionalProperties?: never; +} + /** * Interface for {@link Schema} class. * @public From 7096ed8985abb69a656e2421ce4b2b6fa00632fe Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 22:46:07 -0500 Subject: [PATCH 05/15] feat(ai): support per-call request options --- .github/scripts/compare-types/configs/ai.ts | 26 +------ packages/ai/__tests__/chat-session.test.ts | 48 ++++++++++++ packages/ai/__tests__/exported-types.test.ts | 11 +++ .../ai/__tests__/generative-model.test.ts | 76 ++++++++++++++++++- packages/ai/__tests__/imagen-model.test.ts | 37 +++++++++ packages/ai/__tests__/request.test.ts | 65 ++++++++++++++++ .../template-generative-model.test.ts | 54 +++++++++++++ .../__tests__/template-imagen-model.test.ts | 40 ++++++++++ packages/ai/lib/methods/chat-session.ts | 17 ++++- packages/ai/lib/methods/count-tokens.ts | 6 +- packages/ai/lib/methods/generate-content.ts | 18 ++--- packages/ai/lib/models/generative-model.ts | 16 +++- packages/ai/lib/models/imagen-model.ts | 12 ++- .../lib/models/template-generative-model.ts | 9 ++- .../ai/lib/models/template-imagen-model.ts | 6 +- packages/ai/lib/requests/request-options.ts | 31 ++++++++ packages/ai/lib/requests/request.ts | 50 +++++++++--- packages/ai/lib/types/requests.ts | 15 ++++ 18 files changed, 477 insertions(+), 60 deletions(-) create mode 100644 packages/ai/lib/requests/request-options.ts diff --git a/.github/scripts/compare-types/configs/ai.ts b/.github/scripts/compare-types/configs/ai.ts index af94214218..7d1cd40f09 100644 --- a/.github/scripts/compare-types/configs/ai.ts +++ b/.github/scripts/compare-types/configs/ai.ts @@ -101,11 +101,6 @@ const config: PackageConfig = { reason: 'Chrome Prompt API prompt options type used by browser-only on-device language model integration.', }, - { - name: 'SingleRequestOptions', - reason: - 'RN Firebase does not currently expose per-call request overrides such as `AbortSignal`; requests are configured via model-level `RequestOptions` only.', - }, { name: 'ChatSessionBase', reason: @@ -193,7 +188,7 @@ const config: PackageConfig = { { name: 'ChatSession', reason: - 'RN Firebase chat sessions do not currently accept per-call `SingleRequestOptions`, so `sendMessage` and `sendMessageStream` expose fewer parameters.', + 'RN Firebase chat sessions expose `getHistory()` directly on the concrete class while the firebase-js-sdk inherits it from `ChatSessionBase`.', }, { name: 'FunctionDeclaration', @@ -210,30 +205,15 @@ const config: PackageConfig = { reason: 'RN Firebase does not currently expose the JS SDK `responseJsonSchema` generation config field.', }, - { - name: 'GenerativeModel', - reason: - 'RN Firebase generative model methods do not currently accept per-call `SingleRequestOptions`, so request overrides are limited to model-level `RequestOptions`.', - }, - { - name: 'ImagenModel', - reason: - 'RN Firebase Imagen model requests do not currently accept per-call `SingleRequestOptions`, so request overrides are limited to model-level `RequestOptions`.', - }, { name: 'RequestOptions', reason: - 'RN Firebase does not currently expose `maxSequentalFunctionCalls`, so its request options are limited to timeout and base URL.', + 'RN Firebase does not currently expose `maxSequentialFunctionCalls`, so its request options are limited to timeout and base URL.', }, { name: 'TemplateGenerativeModel', reason: - 'RN Firebase template generative model methods do not currently accept per-call `SingleRequestOptions`, so request overrides are limited to model-level `RequestOptions`.', - }, - { - name: 'TemplateImagenModel', - reason: - 'RN Firebase template Imagen model methods do not currently accept per-call `SingleRequestOptions`, so request overrides are limited to model-level `RequestOptions`.', + 'RN Firebase template generative models do not currently expose `startChat`, so template chat sessions remain absent.', }, { name: 'UsageMetadata', diff --git a/packages/ai/__tests__/chat-session.test.ts b/packages/ai/__tests__/chat-session.test.ts index 10b025c62a..8f06b0d033 100644 --- a/packages/ai/__tests__/chat-session.test.ts +++ b/packages/ai/__tests__/chat-session.test.ts @@ -57,6 +57,29 @@ describe('ChatSession', () => { requestOptions, ); }); + + it('merges per-call request options over session request options', async () => { + const controller = new AbortController(); + const generateContentStub = jest + .spyOn(generateContentMethods, 'generateContent') + .mockResolvedValue({ response: { candidates: [] } } as any); + const chatSession = new ChatSession(fakeApiSettings, 'a-model', {}, requestOptions); + + await chatSession.sendMessage('hello', { + timeout: 2000, + signal: controller.signal, + }); + + expect(generateContentStub).toHaveBeenCalledWith( + fakeApiSettings, + 'a-model', + expect.anything(), + { + timeout: 2000, + signal: controller.signal, + }, + ); + }); }); describe('sendMessageStream()', () => { @@ -81,6 +104,31 @@ describe('ChatSession', () => { jest.useRealTimers(); }); + it('merges per-call request options over session request options', async () => { + const controller = new AbortController(); + const generateContentStreamStub = jest + .spyOn(generateContentMethods, 'generateContentStream') + .mockResolvedValue({ + response: Promise.resolve({ candidates: [] }), + } as unknown as GenerateContentStreamResult); + const chatSession = new ChatSession(fakeApiSettings, 'a-model', {}, requestOptions); + + await chatSession.sendMessageStream('hello', { + timeout: 2000, + signal: controller.signal, + }); + + expect(generateContentStreamStub).toHaveBeenCalledWith( + fakeApiSettings, + 'a-model', + expect.anything(), + { + timeout: 2000, + signal: controller.signal, + }, + ); + }); + it('downstream sendPromise errors should log but not throw', async () => { const consoleStub = jest.spyOn(console, 'error').mockImplementation(() => {}); // make response undefined so that response.candidates errors diff --git a/packages/ai/__tests__/exported-types.test.ts b/packages/ai/__tests__/exported-types.test.ts index 6f14dad581..d96ad6b27a 100644 --- a/packages/ai/__tests__/exported-types.test.ts +++ b/packages/ai/__tests__/exported-types.test.ts @@ -81,6 +81,7 @@ import { SchemaShared, SearchEntrypoint, Segment, + SingleRequestOptions, StartChatParams, TextPart, ThinkingConfig, @@ -397,6 +398,16 @@ describe('AI', function () { expect(typeof _typeCheck).toBeDefined(); }); + it('`SingleRequestOptions` type is properly exposed to end user', function () { + const controller = new AbortController(); + const _typeCheck: SingleRequestOptions = { + timeout: 1000, + baseUrl: 'https://example.com', + signal: controller.signal, + }; + expect(typeof _typeCheck).toBe('object'); + }); + it('`RetrievedContextAttribution` type is properly exposed to end user', function () { const _typeCheck: RetrievedContextAttribution = {} as RetrievedContextAttribution; expect(typeof _typeCheck).toBeDefined(); diff --git a/packages/ai/__tests__/generative-model.test.ts b/packages/ai/__tests__/generative-model.test.ts index 9754657f51..833059ca5a 100644 --- a/packages/ai/__tests__/generative-model.test.ts +++ b/packages/ai/__tests__/generative-model.test.ts @@ -124,6 +124,44 @@ describe('GenerativeModel', () => { makeRequestStub.mockRestore(); }); + it('merges per-call request options over model request options for generateContent', async () => { + const controller = new AbortController(); + const genModel = new GenerativeModel( + fakeAI, + { + model: 'my-model', + }, + { + timeout: 1000, + baseUrl: 'https://model.example.com', + }, + ); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-basic-reply-short.json', + ); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + + await genModel.generateContent('hello', { + timeout: 2000, + signal: controller.signal, + }); + + expect(makeRequestStub).toHaveBeenCalledWith( + expect.objectContaining({ + requestOptions: { + timeout: 2000, + baseUrl: 'https://model.example.com', + signal: controller.signal, + }, + }), + expect.any(String), + ); + makeRequestStub.mockRestore(); + }); + it('passes thinkingLevel through to generateContent', async () => { const genModel = new GenerativeModel(fakeAI, { model: 'my-model', @@ -175,6 +213,42 @@ describe('GenerativeModel', () => { makeRequestStub.mockRestore(); }); + it('merges per-call request options over model request options for countTokens', async () => { + const controller = new AbortController(); + const genModel = new GenerativeModel( + fakeAI, + { + model: 'my-model', + }, + { + timeout: 1000, + baseUrl: 'https://model.example.com', + }, + ); + const makeRequestStub = jest.spyOn(request, 'makeRequest').mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ totalTokens: 1 }), + } as Response); + + await genModel.countTokens('hello', { + timeout: 2000, + signal: controller.signal, + }); + + expect(makeRequestStub).toHaveBeenCalledWith( + expect.objectContaining({ + task: request.Task.COUNT_TOKENS, + requestOptions: { + timeout: 2000, + baseUrl: 'https://model.example.com', + signal: controller.signal, + }, + }), + expect.any(String), + ); + makeRequestStub.mockRestore(); + }); + it('passes text-only systemInstruction through to generateContent', async () => { const genModel = new GenerativeModel(fakeAI, { model: 'my-model', @@ -433,7 +507,7 @@ describe('GenerativeModel', () => { task: request.Task.COUNT_TOKENS, apiSettings: expect.anything(), stream: false, - requestOptions: undefined, + requestOptions: {}, }), expect.stringContaining('hello'), ); diff --git a/packages/ai/__tests__/imagen-model.test.ts b/packages/ai/__tests__/imagen-model.test.ts index 9198581364..29c23842c2 100644 --- a/packages/ai/__tests__/imagen-model.test.ts +++ b/packages/ai/__tests__/imagen-model.test.ts @@ -141,6 +141,43 @@ describe('ImagenModel', () => { ); }); + it('generateImages merges per-call request options over model request options', async () => { + const controller = new AbortController(); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-generate-images-base64.json', + ); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + const imagenModel = new ImagenModel( + fakeAI, + { + model: 'my-model', + }, + { + timeout: 1000, + baseUrl: 'https://model.example.com', + }, + ); + + await imagenModel.generateImages('A toy boat.', { + timeout: 2000, + signal: controller.signal, + }); + + expect(makeRequestStub).toHaveBeenCalledWith( + expect.objectContaining({ + requestOptions: { + timeout: 2000, + baseUrl: 'https://model.example.com', + signal: controller.signal, + }, + }), + expect.any(String), + ); + }); + it('throws if prompt blocked', async () => { const mockResponse = getMockResponse( BackendName.VertexAI, diff --git a/packages/ai/__tests__/request.test.ts b/packages/ai/__tests__/request.test.ts index d00fbabc12..f7d59d446e 100644 --- a/packages/ai/__tests__/request.test.ts +++ b/packages/ai/__tests__/request.test.ts @@ -31,6 +31,15 @@ const fakeApiSettings: ApiSettings = { backend: new VertexAIBackend(), }; +function createAbortErrorForTest(reason?: unknown): Error { + if (typeof DOMException !== 'undefined') { + return new DOMException(reason == null ? 'Aborted' : String(reason), 'AbortError'); + } + const error = new Error(reason == null ? 'Aborted' : String(reason)); + error.name = 'AbortError'; + return error; +} + describe('request methods', () => { afterEach(() => { jest.restoreAllMocks(); // Use Jest's restoreAllMocks @@ -244,6 +253,62 @@ describe('request methods', () => { }); describe('makeRequest', () => { + it('throws an AbortError without fetching if the external signal is already aborted', async () => { + const fetchMock = jest.spyOn(globalThis, 'fetch'); + const controller = new AbortController(); + (controller.abort as (reason?: unknown) => void)('user cancelled'); + + await expect( + makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: { + signal: controller.signal, + }, + }, + '', + ), + ).rejects.toMatchObject({ + name: 'AbortError', + message: 'user cancelled', + }); + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it('aborts the fetch signal when the external signal aborts', async () => { + const controller = new AbortController(); + const fetchMock = jest.spyOn(globalThis, 'fetch').mockImplementation((_url, init) => { + return new Promise((_resolve, reject) => { + (init?.signal as AbortSignal).addEventListener('abort', () => { + reject(createAbortErrorForTest()); + }); + (controller.abort as (reason?: unknown) => void)('cancelled during fetch'); + }); + }); + + await expect( + makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: { + signal: controller.signal, + }, + }, + '', + ), + ).rejects.toMatchObject({ + name: 'AbortError', + message: expect.any(String), + }); + expect(fetchMock).toHaveBeenCalledTimes(1); + }); + it('no error', async () => { const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue({ ok: true, diff --git a/packages/ai/__tests__/template-generative-model.test.ts b/packages/ai/__tests__/template-generative-model.test.ts index 1f1ddee409..95476efd3e 100644 --- a/packages/ai/__tests__/template-generative-model.test.ts +++ b/packages/ai/__tests__/template-generative-model.test.ts @@ -70,6 +70,33 @@ describe('TemplateGenerativeModel', function () { { timeout: 5000 }, ); }); + + it('should merge per-call request options over model request options', async function () { + const controller = new AbortController(); + const templateGenerateContentSpy = jest + .spyOn(generateContentMethods, 'templateGenerateContent') + .mockResolvedValue({} as any); + const model = new TemplateGenerativeModel(fakeAI, { + timeout: 5000, + baseUrl: 'https://model.example.com', + }); + + await model.generateContent(TEMPLATE_ID, TEMPLATE_VARS, { + timeout: 2000, + signal: controller.signal, + }); + + expect(templateGenerateContentSpy).toHaveBeenCalledWith( + model._apiSettings, + TEMPLATE_ID, + { inputs: TEMPLATE_VARS }, + { + timeout: 2000, + baseUrl: 'https://model.example.com', + signal: controller.signal, + }, + ); + }); }); describe('generateContentStream', function () { @@ -89,5 +116,32 @@ describe('TemplateGenerativeModel', function () { { timeout: 5000 }, ); }); + + it('should merge per-call request options over model request options', async function () { + const controller = new AbortController(); + const templateGenerateContentStreamSpy = jest + .spyOn(generateContentMethods, 'templateGenerateContentStream') + .mockResolvedValue({} as any); + const model = new TemplateGenerativeModel(fakeAI, { + timeout: 5000, + baseUrl: 'https://model.example.com', + }); + + await model.generateContentStream(TEMPLATE_ID, TEMPLATE_VARS, { + timeout: 2000, + signal: controller.signal, + }); + + expect(templateGenerateContentStreamSpy).toHaveBeenCalledWith( + model._apiSettings, + TEMPLATE_ID, + { inputs: TEMPLATE_VARS }, + { + timeout: 2000, + baseUrl: 'https://model.example.com', + signal: controller.signal, + }, + ); + }); }); }); diff --git a/packages/ai/__tests__/template-imagen-model.test.ts b/packages/ai/__tests__/template-imagen-model.test.ts index e4ff6e855b..67387f0494 100644 --- a/packages/ai/__tests__/template-imagen-model.test.ts +++ b/packages/ai/__tests__/template-imagen-model.test.ts @@ -86,6 +86,46 @@ describe('TemplateImagenModel', function () { ); }); + it('should merge per-call request options over model request options', async function () { + const controller = new AbortController(); + const makeRequestSpy = jest.spyOn(request, 'makeRequest').mockResolvedValue({ + json: () => + Promise.resolve({ + predictions: [ + { + bytesBase64Encoded: + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==', + mimeType: 'image/png', + }, + ], + }), + } as Response); + const model = new TemplateImagenModel(fakeAI, { + timeout: 5000, + baseUrl: 'https://model.example.com', + }); + + await model.generateImages(TEMPLATE_ID, TEMPLATE_VARS, { + timeout: 2000, + signal: controller.signal, + }); + + expect(makeRequestSpy).toHaveBeenCalledWith( + { + task: ServerPromptTemplateTask.TEMPLATE_PREDICT, + templateId: TEMPLATE_ID, + apiSettings: model._apiSettings, + stream: false, + requestOptions: { + timeout: 2000, + baseUrl: 'https://model.example.com', + signal: controller.signal, + }, + }, + JSON.stringify({ inputs: TEMPLATE_VARS }), + ); + }); + it('should return the result of handlePredictResponse', async function () { const mockPrediction = { bytesBase64Encoded: diff --git a/packages/ai/lib/methods/chat-session.ts b/packages/ai/lib/methods/chat-session.ts index 6bbb6f526c..7f367e81ba 100644 --- a/packages/ai/lib/methods/chat-session.ts +++ b/packages/ai/lib/methods/chat-session.ts @@ -22,6 +22,7 @@ import { GenerateContentStreamResult, Part, RequestOptions, + SingleRequestOptions, StartChatParams, EnhancedGenerateContentResponse, } from '../types'; @@ -31,6 +32,7 @@ import { validateChatHistory } from './chat-session-helpers'; import { generateContent, generateContentStream } from './generate-content'; import { ApiSettings } from '../types/internal'; import { logger } from '../logger'; +import { mergeRequestOptions } from '../requests/request-options'; /** * Do not log a message for this error. @@ -75,7 +77,10 @@ export class ChatSession { * Sends a chat message and receives a non-streaming * {@link GenerateContentResult} */ - async sendMessage(request: string | Array): Promise { + async sendMessage( + request: string | Array, + singleRequestOptions?: SingleRequestOptions, + ): Promise { await this._sendPromise; const newContent = formatNewContent(request); const generateContentRequest: GenerateContentRequest = { @@ -90,7 +95,12 @@ export class ChatSession { // Add onto the chain. this._sendPromise = this._sendPromise .then(() => - generateContent(this._apiSettings, this.model, generateContentRequest, this.requestOptions), + generateContent( + this._apiSettings, + this.model, + generateContentRequest, + mergeRequestOptions(this.requestOptions, singleRequestOptions), + ), ) .then((result: GenerateContentResult) => { if (result.response.candidates && result.response.candidates.length > 0) { @@ -122,6 +132,7 @@ export class ChatSession { */ async sendMessageStream( request: string | Array, + singleRequestOptions?: SingleRequestOptions, ): Promise { await this._sendPromise; const newContent = formatNewContent(request); @@ -137,7 +148,7 @@ export class ChatSession { this._apiSettings, this.model, generateContentRequest, - this.requestOptions, + mergeRequestOptions(this.requestOptions, singleRequestOptions), ); // Add onto the chain. diff --git a/packages/ai/lib/methods/count-tokens.ts b/packages/ai/lib/methods/count-tokens.ts index ebb9e31ee6..c0e022fa87 100644 --- a/packages/ai/lib/methods/count-tokens.ts +++ b/packages/ai/lib/methods/count-tokens.ts @@ -15,7 +15,7 @@ * limitations under the License. */ -import { CountTokensRequest, CountTokensResponse, RequestOptions } from '../types'; +import { CountTokensRequest, CountTokensResponse, SingleRequestOptions } from '../types'; import { Task, makeRequest } from '../requests/request'; import { ApiSettings } from '../types/internal'; import { BackendType } from '../public-types'; @@ -27,14 +27,14 @@ import * as GoogleAIMapper from '../googleai-mappers'; * @param apiSettings The {@link ApiSettings} to use for the request. * @param model The model to use for the request. * @param params The {@link CountTokensRequest} to send. - * @param requestOptions The {@link RequestOptions} to use for the request. + * @param requestOptions The {@link SingleRequestOptions} to use for the request. * @returns The {@link CountTokensResponse} from the request. */ export async function countTokens( apiSettings: ApiSettings, model: string, params: CountTokensRequest, - requestOptions?: RequestOptions, + requestOptions?: SingleRequestOptions, ): Promise { let body: string = ''; switch (apiSettings.backend.backendType) { diff --git a/packages/ai/lib/methods/generate-content.ts b/packages/ai/lib/methods/generate-content.ts index 4139e8c423..3e680460a8 100644 --- a/packages/ai/lib/methods/generate-content.ts +++ b/packages/ai/lib/methods/generate-content.ts @@ -22,7 +22,7 @@ import { GenerateContentResult, GenerateContentStreamResult, GenerationConfig, - RequestOptions, + SingleRequestOptions, } from '../types'; import { Task, makeRequest, ServerPromptTemplateTask } from '../requests/request'; import { createEnhancedContentResponse } from '../requests/response-helpers'; @@ -55,14 +55,14 @@ function validateGenerationConfig(generationConfig?: GenerationConfig): void { * @param apiSettings The {@link ApiSettings} to use for the request. * @param model The model to use for the request. * @param params The {@link GenerateContentRequest} to send. - * @param requestOptions The {@link RequestOptions} to use for the request. + * @param requestOptions The {@link SingleRequestOptions} to use for the request. * @returns The {@link GenerateContentStreamResult} from the request. */ export async function generateContentStream( apiSettings: ApiSettings, model: string, params: GenerateContentRequest, - requestOptions?: RequestOptions, + requestOptions?: SingleRequestOptions, ): Promise { validateGenerationConfig(params.generationConfig); if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { @@ -87,7 +87,7 @@ export async function generateContentStream( * @param apiSettings The {@link ApiSettings} to use for the request. * @param model The model to use for the request. * @param params The {@link GenerateContentRequest} to send. - * @param requestOptions The {@link RequestOptions} to use for the request. + * @param requestOptions The {@link SingleRequestOptions} to use for the request. * @returns The {@link GenerateContentResult} from the request. */ @@ -95,7 +95,7 @@ export async function generateContent( apiSettings: ApiSettings, model: string, params: GenerateContentRequest, - requestOptions?: RequestOptions, + requestOptions?: SingleRequestOptions, ): Promise { validateGenerationConfig(params.generationConfig); if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { @@ -143,7 +143,7 @@ async function processGenerateContentResponse( * @param apiSettings The {@link ApiSettings} to use for the request. * @param templateId The ID of the server-side template to execute. * @param templateParams The parameters to populate the template with. - * @param requestOptions The {@link RequestOptions} to use for the request. + * @param requestOptions The {@link SingleRequestOptions} to use for the request. * @returns The {@link GenerateContentResult} from the request. * * @beta @@ -152,7 +152,7 @@ export async function templateGenerateContent( apiSettings: ApiSettings, templateId: string, templateParams: object, - requestOptions?: RequestOptions, + requestOptions?: SingleRequestOptions, ): Promise { const response = await makeRequest( { @@ -177,7 +177,7 @@ export async function templateGenerateContent( * @param apiSettings The {@link ApiSettings} to use for the request. * @param templateId The ID of the server-side template to execute. * @param templateParams The parameters to populate the template with. - * @param requestOptions The {@link RequestOptions} to use for the request. + * @param requestOptions The {@link SingleRequestOptions} to use for the request. * @returns The {@link GenerateContentStreamResult} from the request. * * @beta @@ -186,7 +186,7 @@ export async function templateGenerateContentStream( apiSettings: ApiSettings, templateId: string, templateParams: object, - requestOptions?: RequestOptions, + requestOptions?: SingleRequestOptions, ): Promise { const response = await makeRequest( { diff --git a/packages/ai/lib/models/generative-model.ts b/packages/ai/lib/models/generative-model.ts index c3bba041e0..30d358e0e6 100644 --- a/packages/ai/lib/models/generative-model.ts +++ b/packages/ai/lib/models/generative-model.ts @@ -28,6 +28,7 @@ import { Part, RequestOptions, SafetySetting, + SingleRequestOptions, StartChatParams, Tool, ToolConfig, @@ -35,6 +36,7 @@ import { import { ChatSession } from '../methods/chat-session'; import { countTokens } from '../methods/count-tokens'; import { formatGenerateContentInput, formatSystemInstruction } from '../requests/request-helpers'; +import { mergeRequestOptions } from '../requests/request-options'; import { AIModel } from './ai-model'; import { AI } from '../public-types'; @@ -66,6 +68,7 @@ export class GenerativeModel extends AIModel { */ async generateContent( request: GenerateContentRequest | string | Array, + singleRequestOptions?: SingleRequestOptions, ): Promise { const formattedParams = formatGenerateContentInput(request); return generateContent( @@ -79,7 +82,7 @@ export class GenerativeModel extends AIModel { systemInstruction: this.systemInstruction, ...formattedParams, }, - this.requestOptions, + mergeRequestOptions(this.requestOptions, singleRequestOptions), ); } @@ -91,6 +94,7 @@ export class GenerativeModel extends AIModel { */ async generateContentStream( request: GenerateContentRequest | string | Array, + singleRequestOptions?: SingleRequestOptions, ): Promise { const formattedParams = formatGenerateContentInput(request); return generateContentStream( @@ -104,7 +108,7 @@ export class GenerativeModel extends AIModel { systemInstruction: this.systemInstruction, ...formattedParams, }, - this.requestOptions, + mergeRequestOptions(this.requestOptions, singleRequestOptions), ); } @@ -138,8 +142,14 @@ export class GenerativeModel extends AIModel { */ async countTokens( request: CountTokensRequest | string | Array, + singleRequestOptions?: SingleRequestOptions, ): Promise { const formattedParams = formatGenerateContentInput(request); - return countTokens(this._apiSettings, this.model, formattedParams); + return countTokens( + this._apiSettings, + this.model, + formattedParams, + mergeRequestOptions(this.requestOptions, singleRequestOptions), + ); } } diff --git a/packages/ai/lib/models/imagen-model.ts b/packages/ai/lib/models/imagen-model.ts index 1cd41aeb56..2c5193f061 100644 --- a/packages/ai/lib/models/imagen-model.ts +++ b/packages/ai/lib/models/imagen-model.ts @@ -24,11 +24,13 @@ import { ImagenGenerationConfig, ImagenInlineImage, RequestOptions, + SingleRequestOptions, ImagenModelParams, ImagenGenerationResponse, ImagenSafetySettings, } from '../types'; import { AIModel } from './ai-model'; +import { mergeRequestOptions } from '../requests/request-options'; /** * Class for Imagen model APIs. @@ -101,7 +103,10 @@ export class ImagenModel extends AIModel { * * @beta */ - async generateImages(prompt: string): Promise> { + async generateImages( + prompt: string, + singleRequestOptions?: SingleRequestOptions, + ): Promise> { const body = createPredictRequestBody(prompt, { ...this.generationConfig, ...this.safetySettings, @@ -112,7 +117,7 @@ export class ImagenModel extends AIModel { task: Task.PREDICT, apiSettings: this._apiSettings, stream: false, - requestOptions: this.requestOptions, + requestOptions: mergeRequestOptions(this.requestOptions, singleRequestOptions), }, JSON.stringify(body), ); @@ -141,6 +146,7 @@ export class ImagenModel extends AIModel { async generateImagesGCS( prompt: string, gcsURI: string, + singleRequestOptions?: SingleRequestOptions, ): Promise> { const body = createPredictRequestBody(prompt, { gcsURI, @@ -153,7 +159,7 @@ export class ImagenModel extends AIModel { task: Task.PREDICT, apiSettings: this._apiSettings, stream: false, - requestOptions: this.requestOptions, + requestOptions: mergeRequestOptions(this.requestOptions, singleRequestOptions), }, JSON.stringify(body), ); diff --git a/packages/ai/lib/models/template-generative-model.ts b/packages/ai/lib/models/template-generative-model.ts index d97dcbcb30..d7d735929d 100644 --- a/packages/ai/lib/models/template-generative-model.ts +++ b/packages/ai/lib/models/template-generative-model.ts @@ -19,9 +19,10 @@ import { templateGenerateContent, templateGenerateContentStream, } from '../methods/generate-content'; -import { GenerateContentResult, RequestOptions } from '../types'; +import { GenerateContentResult, RequestOptions, SingleRequestOptions } from '../types'; import { AI, GenerateContentStreamResult } from '../public-types'; import { ApiSettings } from '../types/internal'; +import { mergeRequestOptions } from '../requests/request-options'; import { initApiSettings } from './utils'; /** @@ -63,12 +64,13 @@ export class TemplateGenerativeModel { async generateContent( templateId: string, templateVariables: object, // anything! + singleRequestOptions?: SingleRequestOptions, ): Promise { return templateGenerateContent( this._apiSettings, templateId, { inputs: templateVariables }, - this.requestOptions, + mergeRequestOptions(this.requestOptions, singleRequestOptions), ); } @@ -87,12 +89,13 @@ export class TemplateGenerativeModel { async generateContentStream( templateId: string, templateVariables: object, + singleRequestOptions?: SingleRequestOptions, ): Promise { return templateGenerateContentStream( this._apiSettings, templateId, { inputs: templateVariables }, - this.requestOptions, + mergeRequestOptions(this.requestOptions, singleRequestOptions), ); } } diff --git a/packages/ai/lib/models/template-imagen-model.ts b/packages/ai/lib/models/template-imagen-model.ts index a788c91180..6fb0d1e3a3 100644 --- a/packages/ai/lib/models/template-imagen-model.ts +++ b/packages/ai/lib/models/template-imagen-model.ts @@ -15,10 +15,11 @@ * limitations under the License. */ -import { RequestOptions } from '../types'; +import { RequestOptions, SingleRequestOptions } from '../types'; import { AI, ImagenGenerationResponse, ImagenInlineImage } from '../public-types'; import { ApiSettings } from '../types/internal'; import { makeRequest, ServerPromptTemplateTask } from '../requests/request'; +import { mergeRequestOptions } from '../requests/request-options'; import { handlePredictResponse } from '../requests/response-helpers'; import { initApiSettings } from './utils'; @@ -61,6 +62,7 @@ export class TemplateImagenModel { async generateImages( templateId: string, templateVariables: object, + singleRequestOptions?: SingleRequestOptions, ): Promise> { const response = await makeRequest( { @@ -68,7 +70,7 @@ export class TemplateImagenModel { templateId, apiSettings: this._apiSettings, stream: false, - requestOptions: this.requestOptions, + requestOptions: mergeRequestOptions(this.requestOptions, singleRequestOptions), }, JSON.stringify({ inputs: templateVariables }), ); diff --git a/packages/ai/lib/requests/request-options.ts b/packages/ai/lib/requests/request-options.ts new file mode 100644 index 0000000000..0f69c3530e --- /dev/null +++ b/packages/ai/lib/requests/request-options.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { RequestOptions, SingleRequestOptions } from '../types'; + +export function mergeRequestOptions( + requestOptions?: RequestOptions, + singleRequestOptions?: SingleRequestOptions, +): SingleRequestOptions | undefined { + if (!requestOptions && !singleRequestOptions) { + return undefined; + } + return { + ...requestOptions, + ...singleRequestOptions, + }; +} diff --git a/packages/ai/lib/requests/request.ts b/packages/ai/lib/requests/request.ts index 1a6b978b7d..84caf395c4 100644 --- a/packages/ai/lib/requests/request.ts +++ b/packages/ai/lib/requests/request.ts @@ -15,7 +15,7 @@ * limitations under the License. */ import { Platform } from 'react-native'; -import { AIErrorCode, ErrorDetails, RequestOptions } from '../types'; +import { AIErrorCode, ErrorDetails, SingleRequestOptions } from '../types'; import { AIError } from '../errors'; import { ApiSettings } from '../types/internal'; import { @@ -48,7 +48,7 @@ export class RequestUrl { public task: Task, public apiSettings: ApiSettings, public stream: boolean, - public requestOptions?: RequestOptions, + public requestOptions?: SingleRequestOptions, ) {} toString(): string { // @ts-ignore @@ -112,13 +112,29 @@ export class RequestUrl { } } +function createAbortError(reason?: unknown): Error { + if (typeof DOMException !== 'undefined') { + return new DOMException( + reason == null ? 'Aborted' : String(reason), + 'AbortError', + ) as unknown as Error; + } + const error = new Error(reason == null ? 'Aborted' : String(reason)); + error.name = 'AbortError'; + return error; +} + +function getAbortSignalReason(signal?: AbortSignal): unknown { + return (signal as AbortSignal & { reason?: unknown } | undefined)?.reason; +} + export class TemplateRequestUrl { constructor( public templateId: string, public task: ServerPromptTemplateTask, public apiSettings: ApiSettings, public stream: boolean, - public requestOptions?: RequestOptions, + public requestOptions?: SingleRequestOptions, ) {} toString(): string { @@ -227,7 +243,7 @@ export async function constructRequest( apiSettings: ApiSettings, stream: boolean, body: string, - requestOptions?: RequestOptions, + requestOptions?: SingleRequestOptions, ): Promise<{ url: string; fetchOptions: RequestInit }> { const url = new RequestUrl(model, task, apiSettings, stream, requestOptions); return { @@ -246,7 +262,7 @@ export async function constructTemplateRequest( apiSettings: ApiSettings, stream: boolean, body: string, - requestOptions?: RequestOptions, + requestOptions?: SingleRequestOptions, ): Promise<{ url: string; fetchOptions: RequestInit }> { const url = new TemplateRequestUrl(templateId, task, apiSettings, stream, requestOptions); return { @@ -266,7 +282,7 @@ export async function makeRequest( task: Task; apiSettings: ApiSettings; stream: boolean; - requestOptions?: RequestOptions; + requestOptions?: SingleRequestOptions; }, body: string, ): Promise; @@ -277,7 +293,7 @@ export async function makeRequest( task: ServerPromptTemplateTask; apiSettings: ApiSettings; stream: boolean; - requestOptions?: RequestOptions; + requestOptions?: SingleRequestOptions; }, body: string, ): Promise; @@ -289,14 +305,14 @@ export async function makeRequest( task: Task; apiSettings: ApiSettings; stream: boolean; - requestOptions?: RequestOptions; + requestOptions?: SingleRequestOptions; } | { templateId: string; task: ServerPromptTemplateTask; apiSettings: ApiSettings; stream: boolean; - requestOptions?: RequestOptions; + requestOptions?: SingleRequestOptions; }, body: string, ): Promise { @@ -320,6 +336,17 @@ export async function makeRequest( let response; let fetchTimeoutId: string | number | NodeJS.Timeout | undefined; + const externalSignal = params.requestOptions?.signal; + let externalAbortReason: unknown; + if (externalSignal?.aborted) { + throw createAbortError(getAbortSignalReason(externalSignal)); + } + const abortController = new AbortController(); + const abortFromExternalSignal = (): void => { + externalAbortReason = getAbortSignalReason(externalSignal); + abortController.abort(); + }; + externalSignal?.addEventListener('abort', abortFromExternalSignal, { once: true }); try { const request = isTemplateRequest ? await constructTemplateRequest( @@ -343,7 +370,6 @@ export async function makeRequest( params.requestOptions?.timeout != null && params.requestOptions.timeout >= 0 ? params.requestOptions.timeout : DEFAULT_FETCH_TIMEOUT_MS; - const abortController = new AbortController(); fetchTimeoutId = setTimeout(() => abortController.abort(), timeoutMillis); request.fetchOptions.signal = abortController.signal; const fetchOptions = params.stream @@ -406,6 +432,9 @@ export async function makeRequest( } } catch (e) { let err = e as Error; + if (e instanceof Error && e.name === 'AbortError') { + throw createAbortError(externalAbortReason); + } if ( (e as AIError).code !== AIErrorCode.FETCH_ERROR && (e as AIError).code !== AIErrorCode.API_NOT_ENABLED && @@ -420,6 +449,7 @@ export async function makeRequest( if (fetchTimeoutId) { clearTimeout(fetchTimeoutId); } + externalSignal?.removeEventListener('abort', abortFromExternalSignal); } return response; } diff --git a/packages/ai/lib/types/requests.ts b/packages/ai/lib/types/requests.ts index 261ca52eb0..748bd4ab04 100644 --- a/packages/ai/lib/types/requests.ts +++ b/packages/ai/lib/types/requests.ts @@ -247,6 +247,21 @@ export interface RequestOptions { baseUrl?: string; } +/** + * Options that can be provided per-request. + * + * Options specified here will override any default {@link RequestOptions} + * configured on a model. + * + * @public + */ +export interface SingleRequestOptions extends RequestOptions { + /** + * An `AbortSignal` instance that allows cancelling ongoing requests. + */ + signal?: AbortSignal; +} + /** * Defines a tool that model can call to access external knowledge. * @public From 6651e01594666e94315d20550f247f715bcbbbcb Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 23:01:34 -0500 Subject: [PATCH 06/15] refactor(ai): move getHistory to ChatSessionBase --- .github/scripts/compare-types/configs/ai.ts | 10 ------ packages/ai/__tests__/exported-types.test.ts | 5 +++ packages/ai/lib/index.ts | 2 +- packages/ai/lib/methods/chat-session.ts | 35 ++++++++++++++------ 4 files changed, 30 insertions(+), 22 deletions(-) diff --git a/.github/scripts/compare-types/configs/ai.ts b/.github/scripts/compare-types/configs/ai.ts index 7d1cd40f09..6e041c7ea2 100644 --- a/.github/scripts/compare-types/configs/ai.ts +++ b/.github/scripts/compare-types/configs/ai.ts @@ -101,11 +101,6 @@ const config: PackageConfig = { reason: 'Chrome Prompt API prompt options type used by browser-only on-device language model integration.', }, - { - name: 'ChatSessionBase', - reason: - 'Base class used by the firebase-js-sdk template chat implementation. RN Firebase exposes its concrete chat session surface instead.', - }, { name: 'StartTemplateChatParams', reason: @@ -185,11 +180,6 @@ const config: PackageConfig = { reason: 'Both packages expose the same URL retrieval status constants, but the generated declaration text differs (`string`-valued object in JS SDK vs readonly literal constants in RN).', }, - { - name: 'ChatSession', - reason: - 'RN Firebase chat sessions expose `getHistory()` directly on the concrete class while the firebase-js-sdk inherits it from `ChatSessionBase`.', - }, { name: 'FunctionDeclaration', reason: diff --git a/packages/ai/__tests__/exported-types.test.ts b/packages/ai/__tests__/exported-types.test.ts index d96ad6b27a..b81d76d6d6 100644 --- a/packages/ai/__tests__/exported-types.test.ts +++ b/packages/ai/__tests__/exported-types.test.ts @@ -27,6 +27,7 @@ import { getAI, getGenerativeModel, ChatSession, + ChatSessionBase, GoogleAIBackend, VertexAIBackend, // Types that exist - imported for type checking @@ -159,6 +160,10 @@ describe('AI', function () { expect(ChatSession).toBeDefined(); }); + it('`ChatSessionBase` class is properly exposed to end user', function () { + expect(ChatSessionBase).toBeDefined(); + }); + it('`GoogleAIBackend` class is properly exposed to end user', function () { expect(GoogleAIBackend).toBeDefined(); }); diff --git a/packages/ai/lib/index.ts b/packages/ai/lib/index.ts index 67fccfdb88..5a8688a929 100644 --- a/packages/ai/lib/index.ts +++ b/packages/ai/lib/index.ts @@ -32,7 +32,7 @@ import { import { WebSocketHandlerImpl } from './websocket'; export * from './public-types'; -export { ChatSession } from './methods/chat-session'; +export { ChatSession, ChatSessionBase } from './methods/chat-session'; export { LiveSession } from './methods/live-session'; export * from './requests/schema-builder'; export { ImagenImageFormat } from './requests/imagen-image-format'; diff --git a/packages/ai/lib/methods/chat-session.ts b/packages/ai/lib/methods/chat-session.ts index 7f367e81ba..6a9348ada3 100644 --- a/packages/ai/lib/methods/chat-session.ts +++ b/packages/ai/lib/methods/chat-session.ts @@ -45,18 +45,11 @@ const SILENT_ERROR = 'SILENT_ERROR'; * * @public */ -export class ChatSession { - private _apiSettings: ApiSettings; - private _history: Content[] = []; - private _sendPromise: Promise = Promise.resolve(); +export class ChatSessionBase { + protected _history: Content[] = []; + protected _sendPromise: Promise = Promise.resolve(); - constructor( - apiSettings: ApiSettings, - public model: string, - public params?: StartChatParams, - public requestOptions?: RequestOptions, - ) { - this._apiSettings = apiSettings; + constructor(public params?: ParamsType, public requestOptions?: RequestOptions) { if (params?.history) { validateChatHistory(params.history); this._history = params.history; @@ -72,6 +65,26 @@ export class ChatSession { await this._sendPromise; return this._history; } +} + +/** + * ChatSession class that enables sending chat messages and stores + * history of sent and received messages so far. + * + * @public + */ +export class ChatSession extends ChatSessionBase { + private _apiSettings: ApiSettings; + + constructor( + apiSettings: ApiSettings, + public model: string, + public params?: StartChatParams, + public requestOptions?: RequestOptions, + ) { + super(params, requestOptions); + this._apiSettings = apiSettings; + } /** * Sends a chat message and receives a non-streaming From e22aac1646b59873b542d9ba8f60e8de841ea1dc Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 23:07:43 -0500 Subject: [PATCH 07/15] feat(ai): support FunctionResponse parts --- .github/scripts/compare-types/configs/ai.ts | 5 --- packages/ai/__tests__/exported-types.test.ts | 9 +++++ .../ai/__tests__/generate-content.test.ts | 36 +++++++++++++++++++ packages/ai/lib/types/content.ts | 1 + 4 files changed, 46 insertions(+), 5 deletions(-) diff --git a/.github/scripts/compare-types/configs/ai.ts b/.github/scripts/compare-types/configs/ai.ts index 6e041c7ea2..c315d05005 100644 --- a/.github/scripts/compare-types/configs/ai.ts +++ b/.github/scripts/compare-types/configs/ai.ts @@ -185,11 +185,6 @@ const config: PackageConfig = { reason: 'RN Firebase function declarations accept `ObjectSchemaInterface` only and do not expose the JS SDK `functionReference` auto-calling hook.', }, - { - name: 'FunctionResponse', - reason: - 'RN Firebase function responses omit the optional `parts` field from the JS SDK declaration and only expose the structured response payload.', - }, { name: 'GenerationConfig', reason: diff --git a/packages/ai/__tests__/exported-types.test.ts b/packages/ai/__tests__/exported-types.test.ts index b81d76d6d6..0603020ec9 100644 --- a/packages/ai/__tests__/exported-types.test.ts +++ b/packages/ai/__tests__/exported-types.test.ts @@ -298,6 +298,15 @@ describe('AI', function () { expect(typeof _typeCheck).toBeDefined(); }); + it('`FunctionResponse.parts` type is properly exposed to end user', function () { + const _typeCheck: FunctionResponse = { + name: 'getWeather', + response: { temperature: 72 }, + parts: [{ text: 'Weather lookup complete.' }], + }; + expect(typeof _typeCheck).toBe('object'); + }); + it('`FunctionResponsePart` type is properly exposed to end user', function () { const _typeCheck: FunctionResponsePart = {} as FunctionResponsePart; expect(typeof _typeCheck).toBeDefined(); diff --git a/packages/ai/__tests__/generate-content.test.ts b/packages/ai/__tests__/generate-content.test.ts index 30c9428b2d..46dad1413b 100644 --- a/packages/ai/__tests__/generate-content.test.ts +++ b/packages/ai/__tests__/generate-content.test.ts @@ -127,6 +127,42 @@ describe('generateContent()', () => { ); }); + it('passes FunctionResponse.parts through in request bodies', async () => { + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-basic-reply-short.json', + ); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + const params: GenerateContentRequest = { + contents: [ + { + role: 'function', + parts: [ + { + functionResponse: { + name: 'getWeather', + response: { temperature: 72 }, + parts: [{ text: 'Weather lookup complete.' }], + }, + }, + ], + }, + ], + }; + + await generateContent(fakeApiSettings, 'model', params); + + expect(makeRequestStub).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'model', + task: Task.GENERATE_CONTENT, + }), + expect.stringContaining('"parts":[{"text":"Weather lookup complete."}]'), + ); + }); + it('long response with token details', async () => { const mockResponse = getMockResponse( BackendName.VertexAI, diff --git a/packages/ai/lib/types/content.ts b/packages/ai/lib/types/content.ts index 62bdef6231..6c32d549e1 100644 --- a/packages/ai/lib/types/content.ts +++ b/packages/ai/lib/types/content.ts @@ -192,6 +192,7 @@ export interface FunctionResponse { id?: string; name: string; response: object; + parts?: Part[]; } /** From a883afdd39a180bad4df3119fccbf9a3dcd1c070 Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 23:16:55 -0500 Subject: [PATCH 08/15] feat(ai): expose TemplateChat APIs --- .github/scripts/compare-types/configs/ai.ts | 30 ---- packages/ai/__tests__/exported-types.test.ts | 52 ++++++ .../template-generative-model.test.ts | 90 +++++++++++ packages/ai/lib/index.ts | 2 +- packages/ai/lib/methods/chat-session.ts | 150 +++++++++++++++++- .../lib/models/template-generative-model.ts | 25 ++- packages/ai/lib/types/requests.ts | 65 +++++++- 7 files changed, 377 insertions(+), 37 deletions(-) diff --git a/.github/scripts/compare-types/configs/ai.ts b/.github/scripts/compare-types/configs/ai.ts index c315d05005..7e36c2d0b4 100644 --- a/.github/scripts/compare-types/configs/ai.ts +++ b/.github/scripts/compare-types/configs/ai.ts @@ -101,31 +101,6 @@ const config: PackageConfig = { reason: 'Chrome Prompt API prompt options type used by browser-only on-device language model integration.', }, - { - name: 'StartTemplateChatParams', - reason: - 'Template chat startup parameters are part of the firebase-js-sdk template chat API, which RN Firebase does not currently expose.', - }, - { - name: 'TemplateChatSession', - reason: - 'Template chat sessions are not currently part of the RN Firebase public AI API.', - }, - { - name: 'TemplateFunctionDeclaration', - reason: - 'Template function declaration helpers are part of firebase-js-sdk template tooling that RN Firebase does not currently expose.', - }, - { - name: 'TemplateFunctionDeclarationsTool', - reason: - 'Template function declaration tools are part of firebase-js-sdk template tooling that RN Firebase does not currently expose.', - }, - { - name: 'TemplateTool', - reason: - 'Template tool unions are part of firebase-js-sdk template tooling that RN Firebase does not currently expose.', - }, ], extraInRN: [ { @@ -195,11 +170,6 @@ const config: PackageConfig = { reason: 'RN Firebase does not currently expose `maxSequentialFunctionCalls`, so its request options are limited to timeout and base URL.', }, - { - name: 'TemplateGenerativeModel', - reason: - 'RN Firebase template generative models do not currently expose `startChat`, so template chat sessions remain absent.', - }, { name: 'UsageMetadata', reason: diff --git a/packages/ai/__tests__/exported-types.test.ts b/packages/ai/__tests__/exported-types.test.ts index 0603020ec9..03a0b54143 100644 --- a/packages/ai/__tests__/exported-types.test.ts +++ b/packages/ai/__tests__/exported-types.test.ts @@ -28,6 +28,8 @@ import { getGenerativeModel, ChatSession, ChatSessionBase, + TemplateChatSession, + TemplateGenerativeModel, GoogleAIBackend, VertexAIBackend, // Types that exist - imported for type checking @@ -84,6 +86,10 @@ import { Segment, SingleRequestOptions, StartChatParams, + StartTemplateChatParams, + TemplateFunctionDeclaration, + TemplateFunctionDeclarationsTool, + TemplateTool, TextPart, ThinkingConfig, ThinkingLevel, @@ -164,6 +170,14 @@ describe('AI', function () { expect(ChatSessionBase).toBeDefined(); }); + it('`TemplateChatSession` class is properly exposed to end user', function () { + expect(TemplateChatSession).toBeDefined(); + }); + + it('`TemplateGenerativeModel` class is properly exposed to end user', function () { + expect(TemplateGenerativeModel).toBeDefined(); + }); + it('`GoogleAIBackend` class is properly exposed to end user', function () { expect(GoogleAIBackend).toBeDefined(); }); @@ -472,6 +486,44 @@ describe('AI', function () { expect(typeof _typeCheck).toBeDefined(); }); + it('`StartTemplateChatParams` type is properly exposed to end user', function () { + const _typeCheck: StartTemplateChatParams = { + templateId: 'my-template', + templateVariables: { city: 'London' }, + }; + expect(typeof _typeCheck).toBe('object'); + }); + + it('`TemplateFunctionDeclaration` type is properly exposed to end user', function () { + const _typeCheck: TemplateFunctionDeclaration = { + name: 'getWeather', + parameters: { + type: 'object', + properties: { + city: { + type: SchemaType.STRING, + }, + }, + }, + functionReference: () => ({ temperature: 72 }), + }; + expect(typeof _typeCheck).toBe('object'); + }); + + it('`TemplateFunctionDeclarationsTool` type is properly exposed to end user', function () { + const _typeCheck: TemplateFunctionDeclarationsTool = { + functionDeclarations: [{ name: 'getWeather' }], + }; + expect(typeof _typeCheck).toBe('object'); + }); + + it('`TemplateTool` type is properly exposed to end user', function () { + const _typeCheck: TemplateTool = { + functionDeclarations: [{ name: 'getWeather' }], + }; + expect(typeof _typeCheck).toBe('object'); + }); + it('`TextPart` type is properly exposed to end user', function () { const _typeCheck: TextPart = {} as TextPart; expect(typeof _typeCheck).toBeDefined(); diff --git a/packages/ai/__tests__/template-generative-model.test.ts b/packages/ai/__tests__/template-generative-model.test.ts index 95476efd3e..1c973d1e0b 100644 --- a/packages/ai/__tests__/template-generative-model.test.ts +++ b/packages/ai/__tests__/template-generative-model.test.ts @@ -144,4 +144,94 @@ describe('TemplateGenerativeModel', function () { ); }); }); + + describe('startChat', function () { + it('should create a template chat session with the configured request options', function () { + const model = new TemplateGenerativeModel(fakeAI, { timeout: 5000 }); + const chat = model.startChat({ + templateId: TEMPLATE_ID, + templateVariables: TEMPLATE_VARS, + }); + + expect(chat.params).toEqual({ + templateId: TEMPLATE_ID, + templateVariables: TEMPLATE_VARS, + }); + expect(chat.requestOptions).toEqual({ timeout: 5000 }); + }); + + it('should call templateGenerateContent with template chat parameters', async function () { + const templateGenerateContentSpy = jest + .spyOn(generateContentMethods, 'templateGenerateContent') + .mockResolvedValue({ + response: { + candidates: [{ content: { parts: [{ text: 'hello back' }] } }], + }, + } as any); + const model = new TemplateGenerativeModel(fakeAI, { + timeout: 5000, + baseUrl: 'https://model.example.com', + }); + const chat = model.startChat({ + templateId: TEMPLATE_ID, + templateVariables: TEMPLATE_VARS, + }); + const controller = new AbortController(); + + await chat.sendMessage('hello', { + timeout: 2000, + signal: controller.signal, + }); + + expect(templateGenerateContentSpy).toHaveBeenCalledWith( + model._apiSettings, + TEMPLATE_ID, + expect.objectContaining({ + inputs: TEMPLATE_VARS, + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + }), + { + timeout: 2000, + baseUrl: 'https://model.example.com', + signal: controller.signal, + }, + ); + await expect(chat.getHistory()).resolves.toEqual([ + { role: 'user', parts: [{ text: 'hello' }] }, + { role: 'model', parts: [{ text: 'hello back' }] }, + ]); + }); + + it('should call templateGenerateContentStream with template chat parameters', async function () { + const templateGenerateContentStreamSpy = jest + .spyOn(generateContentMethods, 'templateGenerateContentStream') + .mockResolvedValue({ + response: Promise.resolve({ + candidates: [{ content: { parts: [{ text: 'stream back' }] } }], + }), + } as any); + const model = new TemplateGenerativeModel(fakeAI, { timeout: 5000 }); + const chat = model.startChat({ + templateId: TEMPLATE_ID, + templateVariables: TEMPLATE_VARS, + }); + + const result = await chat.sendMessageStream('hello'); + await result.response; + + expect(templateGenerateContentStreamSpy).toHaveBeenCalledWith( + model._apiSettings, + TEMPLATE_ID, + expect.objectContaining({ + inputs: TEMPLATE_VARS, + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + }), + { timeout: 5000 }, + ); + await expect(chat.getHistory()).resolves.toEqual([ + { role: 'user', parts: [{ text: 'hello' }] }, + { role: 'model', parts: [{ text: 'stream back' }] }, + ]); + }); + }); }); diff --git a/packages/ai/lib/index.ts b/packages/ai/lib/index.ts index 5a8688a929..acd94942ce 100644 --- a/packages/ai/lib/index.ts +++ b/packages/ai/lib/index.ts @@ -32,7 +32,7 @@ import { import { WebSocketHandlerImpl } from './websocket'; export * from './public-types'; -export { ChatSession, ChatSessionBase } from './methods/chat-session'; +export { ChatSession, ChatSessionBase, TemplateChatSession } from './methods/chat-session'; export { LiveSession } from './methods/live-session'; export * from './requests/schema-builder'; export { ImagenImageFormat } from './requests/imagen-image-format'; diff --git a/packages/ai/lib/methods/chat-session.ts b/packages/ai/lib/methods/chat-session.ts index 6a9348ada3..f31c59a9b3 100644 --- a/packages/ai/lib/methods/chat-session.ts +++ b/packages/ai/lib/methods/chat-session.ts @@ -24,12 +24,18 @@ import { RequestOptions, SingleRequestOptions, StartChatParams, + StartTemplateChatParams, EnhancedGenerateContentResponse, } from '../types'; import { formatNewContent } from '../requests/request-helpers'; import { formatBlockErrorMessage } from '../requests/response-helpers'; import { validateChatHistory } from './chat-session-helpers'; -import { generateContent, generateContentStream } from './generate-content'; +import { + generateContent, + generateContentStream, + templateGenerateContent, + templateGenerateContentStream, +} from './generate-content'; import { ApiSettings } from '../types/internal'; import { logger } from '../logger'; import { mergeRequestOptions } from '../requests/request-options'; @@ -49,7 +55,10 @@ export class ChatSessionBase = Promise.resolve(); - constructor(public params?: ParamsType, public requestOptions?: RequestOptions) { + constructor( + public params?: ParamsType, + public requestOptions?: RequestOptions, + ) { if (params?.history) { validateChatHistory(params.history); this._history = params.history; @@ -204,3 +213,140 @@ export class ChatSession extends ChatSessionBase { return streamPromise; } } + +/** + * ChatSession class for server-side templates that enables sending chat + * messages and stores history of sent and received messages so far. + * + * @beta + */ +export class TemplateChatSession extends ChatSessionBase { + private _apiSettings: ApiSettings; + + constructor( + apiSettings: ApiSettings, + public params: StartTemplateChatParams, + public requestOptions?: RequestOptions, + ) { + super(params, requestOptions); + this._apiSettings = apiSettings; + } + + /** + * Sends a chat message and receives a non-streaming + * {@link GenerateContentResult} + */ + async sendMessage( + request: string | Array, + singleRequestOptions?: SingleRequestOptions, + ): Promise { + await this._sendPromise; + const newContent = formatNewContent(request); + const templateParams = this._buildTemplateChatRequest(newContent); + let finalResult = {} as GenerateContentResult; + // Add onto the chain. + this._sendPromise = this._sendPromise + .then(() => + templateGenerateContent( + this._apiSettings, + this.params.templateId, + templateParams, + mergeRequestOptions(this.requestOptions, singleRequestOptions), + ), + ) + .then((result: GenerateContentResult) => { + if (result.response.candidates && result.response.candidates.length > 0) { + this._history.push(newContent); + const responseContent: Content = { + parts: result.response.candidates?.[0]?.content.parts || [], + // Response seems to come back without a role set. + role: result.response.candidates?.[0]?.content.role || 'model', + }; + this._history.push(responseContent); + } else { + const blockErrorMessage = formatBlockErrorMessage(result.response); + if (blockErrorMessage) { + logger.warn( + `sendMessage() was unsuccessful. ${blockErrorMessage}. Inspect response object for details.`, + ); + } + } + finalResult = result; + }); + await this._sendPromise; + return finalResult; + } + + /** + * Sends a chat message and receives the response as a + * {@link GenerateContentStreamResult} containing an iterable stream + * and a response promise. + */ + async sendMessageStream( + request: string | Array, + singleRequestOptions?: SingleRequestOptions, + ): Promise { + await this._sendPromise; + const newContent = formatNewContent(request); + const templateParams = this._buildTemplateChatRequest(newContent); + const streamPromise = templateGenerateContentStream( + this._apiSettings, + this.params.templateId, + templateParams, + mergeRequestOptions(this.requestOptions, singleRequestOptions), + ); + + // Add onto the chain. + this._sendPromise = this._sendPromise + .then(() => streamPromise) + // This must be handled to avoid unhandled rejection, but jump + // to the final catch block with a label to not log this error. + .catch(_ignored => { + throw new Error(SILENT_ERROR); + }) + .then(streamResult => streamResult.response) + .then((response: EnhancedGenerateContentResponse) => { + if (response.candidates && response.candidates.length > 0) { + this._history.push(newContent); + const responseContent = { ...response.candidates[0]?.content }; + // Response seems to come back without a role set. + if (!responseContent.role) { + responseContent.role = 'model'; + } + this._history.push(responseContent as Content); + } else { + const blockErrorMessage = formatBlockErrorMessage(response); + if (blockErrorMessage) { + logger.warn( + `sendMessageStream() was unsuccessful. ${blockErrorMessage}. Inspect response object for details.`, + ); + } + } + }) + .catch(e => { + // Errors in streamPromise are already catchable by the user as + // streamPromise is returned. + // Avoid duplicating the error message in logs. + if (e.message !== SILENT_ERROR) { + // Users do not have access to _sendPromise to catch errors + // downstream from streamPromise, so they should not throw. + logger.error(e); + } + }); + return streamPromise; + } + + private _buildTemplateChatRequest(newContent: Content): object { + return { + ...(this.params.templateVariables !== undefined + ? { inputs: this.params.templateVariables } + : {}), + safetySettings: this.params.safetySettings, + generationConfig: this.params.generationConfig, + tools: this.params.tools, + toolConfig: this.params.toolConfig, + systemInstruction: this.params.systemInstruction, + contents: [...this._history, newContent], + }; + } +} diff --git a/packages/ai/lib/models/template-generative-model.ts b/packages/ai/lib/models/template-generative-model.ts index d7d735929d..59de48f16a 100644 --- a/packages/ai/lib/models/template-generative-model.ts +++ b/packages/ai/lib/models/template-generative-model.ts @@ -19,11 +19,17 @@ import { templateGenerateContent, templateGenerateContentStream, } from '../methods/generate-content'; -import { GenerateContentResult, RequestOptions, SingleRequestOptions } from '../types'; +import { + GenerateContentResult, + RequestOptions, + SingleRequestOptions, + StartTemplateChatParams, +} from '../types'; import { AI, GenerateContentStreamResult } from '../public-types'; import { ApiSettings } from '../types/internal'; import { mergeRequestOptions } from '../requests/request-options'; import { initApiSettings } from './utils'; +import { TemplateChatSession } from '../methods/chat-session'; /** * {@link GenerativeModel} APIs that execute on a server-side template. @@ -63,7 +69,7 @@ export class TemplateGenerativeModel { */ async generateContent( templateId: string, - templateVariables: object, // anything! + templateVariables: Record, singleRequestOptions?: SingleRequestOptions, ): Promise { return templateGenerateContent( @@ -88,7 +94,7 @@ export class TemplateGenerativeModel { */ async generateContentStream( templateId: string, - templateVariables: object, + templateVariables: Record, singleRequestOptions?: SingleRequestOptions, ): Promise { return templateGenerateContentStream( @@ -98,4 +104,17 @@ export class TemplateGenerativeModel { mergeRequestOptions(this.requestOptions, singleRequestOptions), ); } + + /** + * Starts a {@link TemplateChatSession} that will use this template to + * respond to messages. + * + * @param params - Configurations for the chat, including the template + * ID and input variables. + * + * @beta + */ + startChat(params: StartTemplateChatParams): TemplateChatSession { + return new TemplateChatSession(this._apiSettings, params, this.requestOptions); + } } diff --git a/packages/ai/lib/types/requests.ts b/packages/ai/lib/types/requests.ts index 748bd4ab04..2ba29c6f4a 100644 --- a/packages/ai/lib/types/requests.ts +++ b/packages/ai/lib/types/requests.ts @@ -15,7 +15,7 @@ * limitations under the License. */ -import { TypedSchema } from '../requests/schema-builder'; +import { ObjectSchema, TypedSchema } from '../requests/schema-builder'; import { Content, Part } from './content'; import { FunctionCallingMode, @@ -212,6 +212,22 @@ export interface StartChatParams extends BaseParams { systemInstruction?: string | Part | Content; } +/** + * Params for {@link TemplateGenerativeModel.startChat}. + * @beta + */ +export interface StartTemplateChatParams extends Omit { + /** + * The ID of the server-side template to execute. + */ + templateId: string; + /** + * A key-value map of variables to populate the template with. + */ + templateVariables?: Record; + tools?: TemplateTool[]; +} + /** * Params for calling {@link GenerativeModel.countTokens} * @public @@ -385,6 +401,53 @@ export interface FunctionDeclarationsTool { functionDeclarations?: FunctionDeclaration[]; } +/** + * Structured representation of a template function declaration. + * @beta + */ +export interface TemplateFunctionDeclaration { + /** + * The name of the function to call. Must start with a letter or an + * underscore. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with + * a max length of 64. + */ + name: string; + /** + * Description is intentionally unsupported for template function declarations. + */ + description?: never; + /** + * Optional. Describes the parameters to this function in JSON Schema Object + * format. + */ + parameters?: ObjectSchema | ObjectSchemaRequest; + /** + * Reference to an actual function to call. Specifying this will cause the + * function to be called automatically when requested by the model. + */ + // eslint-disable-next-line @typescript-eslint/no-unsafe-function-type -- Matches firebase-js-sdk public API. + functionReference?: Function; +} + +/** + * A piece of code that enables the system to interact with external systems. + * @beta + */ +export interface TemplateFunctionDeclarationsTool { + /** + * Optional. One or more function declarations + * to be passed to the server-side template execution. + */ + functionDeclarations?: TemplateFunctionDeclaration[]; +} + +/** + * Defines a tool that a {@link TemplateGenerativeModel} can call + * to access external knowledge. + * @beta + */ +export type TemplateTool = TemplateFunctionDeclarationsTool; + /** * Tool config. This config is shared for all tools provided in the request. * @public From 90613ad93e3bd43b310cb2ab7a3f7b0801563cb4 Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 23:27:03 -0500 Subject: [PATCH 09/15] feat(ai): support responseJsonSchema generation config --- .github/scripts/compare-types/configs/ai.ts | 8 +--- packages/ai/__tests__/exported-types.test.ts | 12 +++++- .../ai/__tests__/generative-model.test.ts | 37 +++++++++++++++++++ packages/ai/lib/types/requests.ts | 10 +++++ 4 files changed, 58 insertions(+), 9 deletions(-) diff --git a/.github/scripts/compare-types/configs/ai.ts b/.github/scripts/compare-types/configs/ai.ts index 7e36c2d0b4..384e9755f5 100644 --- a/.github/scripts/compare-types/configs/ai.ts +++ b/.github/scripts/compare-types/configs/ai.ts @@ -68,8 +68,7 @@ const config: PackageConfig = { }, { name: 'LanguageModelExpected', - reason: - 'Chrome Prompt API type tied to browser-only on-device language model integration.', + reason: 'Chrome Prompt API type tied to browser-only on-device language model integration.', }, { name: 'LanguageModelMessage', @@ -160,11 +159,6 @@ const config: PackageConfig = { reason: 'RN Firebase function declarations accept `ObjectSchemaInterface` only and do not expose the JS SDK `functionReference` auto-calling hook.', }, - { - name: 'GenerationConfig', - reason: - 'RN Firebase does not currently expose the JS SDK `responseJsonSchema` generation config field.', - }, { name: 'RequestOptions', reason: diff --git a/packages/ai/__tests__/exported-types.test.ts b/packages/ai/__tests__/exported-types.test.ts index 03a0b54143..f8a50cdcf6 100644 --- a/packages/ai/__tests__/exported-types.test.ts +++ b/packages/ai/__tests__/exported-types.test.ts @@ -352,8 +352,16 @@ describe('AI', function () { }); it('`GenerationConfig` type is properly exposed to end user', function () { - const _typeCheck: GenerationConfig = {} as GenerationConfig; - expect(typeof _typeCheck).toBeDefined(); + const _typeCheck: GenerationConfig = { + responseMimeType: 'application/json', + responseJsonSchema: { + type: 'object', + properties: { + answer: { type: 'string' }, + }, + }, + }; + expect(typeof _typeCheck).toBe('object'); }); it('`GenerativeContentBlob` type is properly exposed to end user', function () { diff --git a/packages/ai/__tests__/generative-model.test.ts b/packages/ai/__tests__/generative-model.test.ts index 833059ca5a..02b2600233 100644 --- a/packages/ai/__tests__/generative-model.test.ts +++ b/packages/ai/__tests__/generative-model.test.ts @@ -194,6 +194,43 @@ describe('GenerativeModel', () => { makeRequestStub.mockRestore(); }); + it('passes responseJsonSchema through to generateContent', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + generationConfig: { + responseMimeType: 'application/json', + responseJsonSchema: { + type: 'object', + properties: { + answer: { type: 'string' }, + }, + required: ['answer'], + }, + }, + }); + const mockResponse = getMockResponse( + BackendName.VertexAI, + 'unary-success-basic-reply-short.json', + ); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValue(mockResponse as Response); + + await genModel.generateContent('hello'); + + expect(makeRequestStub).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: expect.anything(), + stream: false, + requestOptions: {}, + }), + expect.stringContaining('"responseJsonSchema":{"type":"object"'), + ); + makeRequestStub.mockRestore(); + }); + it('throws when thinkingBudget and thinkingLevel are both set', async () => { const genModel = new GenerativeModel(fakeAI, { model: 'my-model', diff --git a/packages/ai/lib/types/requests.ts b/packages/ai/lib/types/requests.ts index 2ba29c6f4a..501bce1030 100644 --- a/packages/ai/lib/types/requests.ts +++ b/packages/ai/lib/types/requests.ts @@ -116,6 +116,16 @@ export interface GenerationConfig { * this is limited to `application/json` and `text/x.enum`. */ responseSchema?: TypedSchema | SchemaRequest; + /** + * Output schema of the generated response. This is an alternative to + * `responseSchema` that accepts [JSON Schema](https://json-schema.org/). + * + * If set, `responseSchema` must be omitted, but `responseMimeType` + * is required and must be set to `application/json`. + */ + responseJsonSchema?: { + [key: string]: unknown; + }; /** * Generation modalities to be returned in generation responses. * From 02f929f3b41f8bbe2eb020ffbab77c7cb5fe2230 Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 23:34:34 -0500 Subject: [PATCH 10/15] feat(ai): expose UsageMetadata token details --- .github/scripts/compare-types/configs/ai.ts | 5 ---- packages/ai/__tests__/exported-types.test.ts | 12 ++++++++-- .../ai/__tests__/generate-content.test.ts | 23 +++++++++++++++---- .../ai/__tests__/googleai-mappers.test.ts | 5 ++++ packages/ai/lib/types/responses.ts | 20 ++++++++++++++++ 5 files changed, 54 insertions(+), 11 deletions(-) diff --git a/.github/scripts/compare-types/configs/ai.ts b/.github/scripts/compare-types/configs/ai.ts index 384e9755f5..d4e754fe15 100644 --- a/.github/scripts/compare-types/configs/ai.ts +++ b/.github/scripts/compare-types/configs/ai.ts @@ -164,11 +164,6 @@ const config: PackageConfig = { reason: 'RN Firebase does not currently expose `maxSequentialFunctionCalls`, so its request options are limited to timeout and base URL.', }, - { - name: 'UsageMetadata', - reason: - 'RN Firebase usage metadata does not currently surface tool-use and cache token accounting fields that are present in the JS SDK declaration.', - }, ], }; diff --git a/packages/ai/__tests__/exported-types.test.ts b/packages/ai/__tests__/exported-types.test.ts index f8a50cdcf6..82a5c843fe 100644 --- a/packages/ai/__tests__/exported-types.test.ts +++ b/packages/ai/__tests__/exported-types.test.ts @@ -563,8 +563,16 @@ describe('AI', function () { }); it('`UsageMetadata` type is properly exposed to end user', function () { - const _typeCheck: UsageMetadata = {} as UsageMetadata; - expect(typeof _typeCheck).toBeDefined(); + const _typeCheck: UsageMetadata = { + promptTokenCount: 5, + candidatesTokenCount: 7, + totalTokenCount: 12, + toolUsePromptTokenCount: 3, + toolUsePromptTokensDetails: [{ modality: Modality.TEXT, tokenCount: 3 }], + cachedContentTokenCount: 2, + cacheTokensDetails: [{ modality: Modality.TEXT, tokenCount: 2 }], + }; + expect(typeof _typeCheck).toBe('object'); }); it('`VideoMetadata` type is properly exposed to end user', function () { diff --git a/packages/ai/__tests__/generate-content.test.ts b/packages/ai/__tests__/generate-content.test.ts index 46dad1413b..d62a95feac 100644 --- a/packages/ai/__tests__/generate-content.test.ts +++ b/packages/ai/__tests__/generate-content.test.ts @@ -167,10 +167,19 @@ describe('generateContent()', () => { const mockResponse = getMockResponse( BackendName.VertexAI, 'unary-success-basic-response-long-usage-metadata.json', - ); - const makeRequestStub = jest - .spyOn(request, 'makeRequest') - .mockResolvedValue(mockResponse as Response); + ) as Response; + const mockResponseJson = await mockResponse.json(); + mockResponseJson.usageMetadata = { + ...mockResponseJson.usageMetadata, + toolUsePromptTokenCount: 3, + toolUsePromptTokensDetails: [{ modality: 'TEXT', tokenCount: 3 }], + cachedContentTokenCount: 5, + cacheTokensDetails: [{ modality: 'TEXT', tokenCount: 5 }], + }; + const makeRequestStub = jest.spyOn(request, 'makeRequest').mockResolvedValue({ + ...mockResponse, + json: () => Promise.resolve(mockResponseJson), + } as Response); const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); expect(result.response.usageMetadata?.totalTokenCount).toEqual(1913); expect(result.response.usageMetadata?.candidatesTokenCount).toEqual(76); @@ -178,6 +187,12 @@ describe('generateContent()', () => { expect(result.response.usageMetadata?.promptTokensDetails?.[0]?.tokenCount).toEqual(1806); expect(result.response.usageMetadata?.candidatesTokensDetails?.[0]?.modality).toEqual('TEXT'); expect(result.response.usageMetadata?.candidatesTokensDetails?.[0]?.tokenCount).toEqual(76); + expect(result.response.usageMetadata).toMatchObject({ + toolUsePromptTokenCount: 3, + toolUsePromptTokensDetails: [{ modality: 'TEXT', tokenCount: 3 }], + cachedContentTokenCount: 5, + cacheTokensDetails: [{ modality: 'TEXT', tokenCount: 5 }], + }); expect(makeRequestStub).toHaveBeenCalledWith( expect.objectContaining({ model: 'model', diff --git a/packages/ai/__tests__/googleai-mappers.test.ts b/packages/ai/__tests__/googleai-mappers.test.ts index ac1d77a9c2..3afcd0dcbb 100644 --- a/packages/ai/__tests__/googleai-mappers.test.ts +++ b/packages/ai/__tests__/googleai-mappers.test.ts @@ -37,6 +37,7 @@ import { HarmCategory, HarmProbability, HarmSeverity, + Modality, PromptFeedback, SafetyRating, URLRetrievalStatus, @@ -168,7 +169,11 @@ describe('Google AI Mappers', () => { usageMetadata: { promptTokenCount: 5, candidatesTokenCount: 0, + toolUsePromptTokenCount: 3, totalTokenCount: 5, + toolUsePromptTokensDetails: [{ modality: Modality.TEXT, tokenCount: 3 }], + cachedContentTokenCount: 2, + cacheTokensDetails: [{ modality: Modality.TEXT, tokenCount: 2 }], }, }; const mappedResponse = mapGenerateContentResponse(googleAIResponse); diff --git a/packages/ai/lib/types/responses.ts b/packages/ai/lib/types/responses.ts index 2c88b78d5e..ed7e0fa823 100644 --- a/packages/ai/lib/types/responses.ts +++ b/packages/ai/lib/types/responses.ts @@ -117,8 +117,28 @@ export interface UsageMetadata { */ thoughtsTokenCount?: number; totalTokenCount: number; + /** + * The number of tokens used by tools. + */ + toolUsePromptTokenCount?: number; promptTokensDetails?: ModalityTokenCount[]; candidatesTokensDetails?: ModalityTokenCount[]; + /** + * A list of tokens used by tools, broken down by modality. + */ + toolUsePromptTokensDetails?: ModalityTokenCount[]; + /** + * The number of tokens in the prompt that were served from the cache. + * If implicit caching is not active or no content was cached, + * this will be 0. + */ + cachedContentTokenCount?: number; + /** + * Detailed breakdown of the cached tokens by modality (for example, text or + * image). This list provides granular insight into which parts of + * the content were cached. + */ + cacheTokensDetails?: ModalityTokenCount[]; } /** From 016907f1cb3464d26e273bf9e190580a7ffda1b5 Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 23:40:39 -0500 Subject: [PATCH 11/15] feat(ai): expose automatic function calling options --- .github/scripts/compare-types/configs/ai.ts | 10 -------- packages/ai/__tests__/exported-types.test.ts | 25 ++++++++++++++++++-- packages/ai/lib/types/requests.ts | 19 +++++++++++++-- 3 files changed, 40 insertions(+), 14 deletions(-) diff --git a/.github/scripts/compare-types/configs/ai.ts b/.github/scripts/compare-types/configs/ai.ts index d4e754fe15..cffe2612d7 100644 --- a/.github/scripts/compare-types/configs/ai.ts +++ b/.github/scripts/compare-types/configs/ai.ts @@ -154,16 +154,6 @@ const config: PackageConfig = { reason: 'Both packages expose the same URL retrieval status constants, but the generated declaration text differs (`string`-valued object in JS SDK vs readonly literal constants in RN).', }, - { - name: 'FunctionDeclaration', - reason: - 'RN Firebase function declarations accept `ObjectSchemaInterface` only and do not expose the JS SDK `functionReference` auto-calling hook.', - }, - { - name: 'RequestOptions', - reason: - 'RN Firebase does not currently expose `maxSequentialFunctionCalls`, so its request options are limited to timeout and base URL.', - }, ], }; diff --git a/packages/ai/__tests__/exported-types.test.ts b/packages/ai/__tests__/exported-types.test.ts index 82a5c843fe..75ff3449c3 100644 --- a/packages/ai/__tests__/exported-types.test.ts +++ b/packages/ai/__tests__/exported-types.test.ts @@ -73,6 +73,7 @@ import { ModelParams, ObjectSchemaInterface, ObjectSchemaRequest, + ObjectSchema, PromptFeedback, RequestOptions, RetrievedContextAttribution, @@ -82,6 +83,7 @@ import { SchemaParams, SchemaRequest, SchemaShared, + Schema, SearchEntrypoint, Segment, SingleRequestOptions, @@ -424,14 +426,33 @@ describe('AI', function () { expect(typeof _typeCheck).toBe('object'); }); + it('`FunctionDeclaration` accepts ObjectSchema and functionReference', function () { + const parameters: ObjectSchema = Schema.object({ + properties: { + city: Schema.string(), + }, + }); + const _typeCheck: FunctionDeclaration = { + name: 'getWeather', + description: 'Gets weather for a city.', + parameters, + functionReference: () => ({ temperature: 72 }), + }; + expect(typeof _typeCheck).toBe('object'); + }); + it('`PromptFeedback` type is properly exposed to end user', function () { const _typeCheck: PromptFeedback = {} as PromptFeedback; expect(typeof _typeCheck).toBeDefined(); }); it('`RequestOptions` type is properly exposed to end user', function () { - const _typeCheck: RequestOptions = {} as RequestOptions; - expect(typeof _typeCheck).toBeDefined(); + const _typeCheck: RequestOptions = { + timeout: 1000, + baseUrl: 'https://example.com', + maxSequentialFunctionCalls: 3, + }; + expect(typeof _typeCheck).toBe('object'); }); it('`SingleRequestOptions` type is properly exposed to end user', function () { diff --git a/packages/ai/lib/types/requests.ts b/packages/ai/lib/types/requests.ts index 501bce1030..e9a3dd13aa 100644 --- a/packages/ai/lib/types/requests.ts +++ b/packages/ai/lib/types/requests.ts @@ -25,7 +25,7 @@ import { ResponseModality, ThinkingLevel, } from './enums'; -import { ObjectSchemaInterface, ObjectSchemaRequest, SchemaRequest } from './schema'; +import { ObjectSchemaRequest, SchemaRequest } from './schema'; /** * Base parameters for a number of methods. @@ -271,6 +271,15 @@ export interface RequestOptions { * Base url for endpoint. Defaults to https://firebasevertexai.googleapis.com */ baseUrl?: string; + /** + * Limits amount of sequential function calls the SDK can make during automatic + * function calling, in order to prevent infinite loops. If not specified, + * this value defaults to 10. + * + * When it reaches this limit, it will return the last response received + * from the model, whether it is a text response or further function calls. + */ + maxSequentialFunctionCalls?: number; } /** @@ -320,7 +329,13 @@ export interface FunctionDeclaration { * format. Reflects the Open API 3.03 Parameter Object. Parameter names are * case-sensitive. For a function with no parameters, this can be left unset. */ - parameters?: ObjectSchemaInterface | ObjectSchemaRequest; + parameters?: ObjectSchema | ObjectSchemaRequest; + /** + * Reference to an actual function to call. Specifying this will cause the + * function to be called automatically when requested by the model. + */ + // eslint-disable-next-line @typescript-eslint/no-unsafe-function-type -- Matches firebase-js-sdk public API. + functionReference?: Function; } /** From e595077272ab171bdb098ffda842f7d6f1ca69d2 Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 23:43:16 -0500 Subject: [PATCH 12/15] feat(ai): support generateContent function auto-calling --- .../ai/__tests__/generative-model.test.ts | 150 ++++++++++++++++++ packages/ai/lib/models/generative-model.ts | 117 ++++++++++++-- 2 files changed, 254 insertions(+), 13 deletions(-) diff --git a/packages/ai/__tests__/generative-model.test.ts b/packages/ai/__tests__/generative-model.test.ts index 02b2600233..aeab574ba6 100644 --- a/packages/ai/__tests__/generative-model.test.ts +++ b/packages/ai/__tests__/generative-model.test.ts @@ -21,6 +21,7 @@ import { AI, FunctionCallingMode, ThinkingLevel } from '../lib/public-types'; import * as request from '../lib/requests/request'; import { BackendName, getMockResponse } from './test-utils/mock-response'; import { VertexAIBackend } from '../lib/backend'; +import { GenerateContentResponse } from '../lib'; const fakeAI: AI = { app: { @@ -36,6 +37,10 @@ const fakeAI: AI = { location: 'us-central1', }; +function responseFromJson(json: GenerateContentResponse): Response { + return { json: async () => json } as Response; +} + describe('GenerativeModel', () => { it('passes CodeExecutionTool and URLContextTool with other tools through to generateContent', async function () { const genModel = new GenerativeModel(fakeAI, { @@ -362,6 +367,151 @@ describe('GenerativeModel', () => { makeRequestStub.mockRestore(); }); + it('automatically calls functionReference from generateContent function calls', async () => { + const getWeather = jest.fn<(args: object) => object>().mockReturnValue({ temperature: 72 }); + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + tools: [ + { + functionDeclarations: [ + { + name: 'getWeather', + description: 'Gets weather for a city.', + functionReference: getWeather, + }, + ], + }, + ], + }); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValueOnce( + responseFromJson({ + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [ + { + functionCall: { + id: 'call-1', + name: 'getWeather', + args: { city: 'London' }, + }, + }, + ], + }, + }, + ], + }), + ) + .mockResolvedValueOnce( + responseFromJson({ + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [{ text: 'It is 72 degrees.' }], + }, + }, + ], + }), + ); + + const result = await genModel.generateContent('weather in London'); + + expect(result.response.text()).toBe('It is 72 degrees.'); + expect(getWeather).toHaveBeenCalledWith({ city: 'London' }); + expect(makeRequestStub).toHaveBeenCalledTimes(2); + const followUpBody = JSON.parse(makeRequestStub.mock.calls[1]![1] as string); + expect(followUpBody.contents).toEqual([ + { + role: 'user', + parts: [{ text: 'weather in London' }], + }, + { + role: 'model', + parts: [ + { + functionCall: { + id: 'call-1', + name: 'getWeather', + args: { city: 'London' }, + }, + }, + ], + }, + { + role: 'function', + parts: [ + { + functionResponse: { + id: 'call-1', + name: 'getWeather', + response: { temperature: 72 }, + }, + }, + ], + }, + ]); + makeRequestStub.mockRestore(); + }); + + it('returns the latest response when maxSequentialFunctionCalls is reached', async () => { + const getWeather = jest.fn<(args: object) => object>().mockReturnValue({ temperature: 72 }); + const genModel = new GenerativeModel( + fakeAI, + { + model: 'my-model', + tools: [ + { + functionDeclarations: [ + { + name: 'getWeather', + description: 'Gets weather for a city.', + functionReference: getWeather, + }, + ], + }, + ], + }, + { maxSequentialFunctionCalls: 1 }, + ); + const functionCallResponse: GenerateContentResponse = { + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [ + { + functionCall: { + name: 'getWeather', + args: { city: 'London' }, + }, + }, + ], + }, + }, + ], + }; + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValueOnce(responseFromJson(functionCallResponse)) + .mockResolvedValueOnce(responseFromJson(functionCallResponse)); + + const result = await genModel.generateContent('weather in London'); + + expect(result.response.functionCalls()).toEqual([ + { name: 'getWeather', args: { city: 'London' } }, + ]); + expect(getWeather).toHaveBeenCalledTimes(1); + expect(makeRequestStub).toHaveBeenCalledTimes(2); + makeRequestStub.mockRestore(); + }); + it('passes base model params through to ChatSession when there are no startChatParams', async () => { const genModel = new GenerativeModel(fakeAI, { model: 'my-model', diff --git a/packages/ai/lib/models/generative-model.ts b/packages/ai/lib/models/generative-model.ts index 30d358e0e6..fad8435fbc 100644 --- a/packages/ai/lib/models/generative-model.ts +++ b/packages/ai/lib/models/generative-model.ts @@ -20,6 +20,9 @@ import { Content, CountTokensRequest, CountTokensResponse, + FunctionCall, + FunctionDeclaration, + FunctionResponse, GenerateContentRequest, GenerateContentResult, GenerateContentStreamResult, @@ -40,6 +43,8 @@ import { mergeRequestOptions } from '../requests/request-options'; import { AIModel } from './ai-model'; import { AI } from '../public-types'; +const DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS = 10; + /** * Class for generative model APIs. * @public @@ -71,19 +76,17 @@ export class GenerativeModel extends AIModel { singleRequestOptions?: SingleRequestOptions, ): Promise { const formattedParams = formatGenerateContentInput(request); - return generateContent( - this._apiSettings, - this.model, - { - generationConfig: this.generationConfig, - safetySettings: this.safetySettings, - tools: this.tools, - toolConfig: this.toolConfig, - systemInstruction: this.systemInstruction, - ...formattedParams, - }, - mergeRequestOptions(this.requestOptions, singleRequestOptions), - ); + const params: GenerateContentRequest = { + generationConfig: this.generationConfig, + safetySettings: this.safetySettings, + tools: this.tools, + toolConfig: this.toolConfig, + systemInstruction: this.systemInstruction, + ...formattedParams, + }; + const requestOptions = mergeRequestOptions(this.requestOptions, singleRequestOptions); + const result = await generateContent(this._apiSettings, this.model, params, requestOptions); + return this._generateContentWithAutomaticFunctionCalling(params, result, requestOptions); } /** @@ -152,4 +155,92 @@ export class GenerativeModel extends AIModel { mergeRequestOptions(this.requestOptions, singleRequestOptions), ); } + + private async _generateContentWithAutomaticFunctionCalling( + params: GenerateContentRequest, + result: GenerateContentResult, + requestOptions?: SingleRequestOptions, + ): Promise { + let remainingFunctionCalls = + requestOptions?.maxSequentialFunctionCalls ?? DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS; + let currentParams = params; + let currentResult = result; + + while (remainingFunctionCalls > 0) { + const functionCalls = currentResult.response.functionCalls?.(); + if (!functionCalls?.length) { + return currentResult; + } + + const functionResponses = await this._callFunctionReferences( + currentParams.tools, + functionCalls, + ); + if (!functionResponses) { + return currentResult; + } + + const responseContent = currentResult.response.candidates?.[0]?.content; + if (!responseContent) { + return currentResult; + } + + remainingFunctionCalls -= 1; + currentParams = { + ...currentParams, + contents: [ + ...currentParams.contents, + responseContent, + { + role: 'function', + parts: functionResponses.map(functionResponse => ({ functionResponse })), + }, + ], + }; + currentResult = await generateContent( + this._apiSettings, + this.model, + currentParams, + requestOptions, + ); + } + + return currentResult; + } + + private async _callFunctionReferences( + tools: Tool[] | undefined, + functionCalls: FunctionCall[], + ): Promise { + const declarations = this._getFunctionDeclarationsWithReferences(tools); + if (!declarations.length) { + return undefined; + } + + const functionResponses: FunctionResponse[] = []; + for (const functionCall of functionCalls) { + const declaration = declarations.find(candidate => candidate.name === functionCall.name); + if (!declaration?.functionReference) { + return undefined; + } + + const response = (await declaration.functionReference(functionCall.args)) as object; + functionResponses.push({ + id: functionCall.id, + name: functionCall.name, + response, + }); + } + return functionResponses; + } + + private _getFunctionDeclarationsWithReferences(tools: Tool[] | undefined): FunctionDeclaration[] { + return ( + tools?.flatMap(tool => + 'functionDeclarations' in tool + ? (tool.functionDeclarations?.filter(declaration => declaration.functionReference) ?? []) + : [], + ) ?? [] + ); + } } From 130b5dd288453b09368c3b0cf167b618d3f3ba90 Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 23:48:25 -0500 Subject: [PATCH 13/15] feat(ai): support chat function auto-calling --- .../ai/__tests__/generative-model.test.ts | 94 +++++++++++++ .../lib/methods/automatic-function-calling.ts | 128 ++++++++++++++++++ packages/ai/lib/methods/chat-session.ts | 22 ++- packages/ai/lib/models/generative-model.ts | 104 ++------------ 4 files changed, 248 insertions(+), 100 deletions(-) create mode 100644 packages/ai/lib/methods/automatic-function-calling.ts diff --git a/packages/ai/__tests__/generative-model.test.ts b/packages/ai/__tests__/generative-model.test.ts index aeab574ba6..b4cf0a7eb8 100644 --- a/packages/ai/__tests__/generative-model.test.ts +++ b/packages/ai/__tests__/generative-model.test.ts @@ -573,6 +573,100 @@ describe('GenerativeModel', () => { makeRequestStub.mockRestore(); }); + it('automatically calls functionReference from chat.sendMessage function calls', async () => { + const getWeather = jest.fn<(args: object) => object>().mockReturnValue({ temperature: 72 }); + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + tools: [ + { + functionDeclarations: [ + { + name: 'getWeather', + description: 'Gets weather for a city.', + functionReference: getWeather, + }, + ], + }, + ], + }); + const makeRequestStub = jest + .spyOn(request, 'makeRequest') + .mockResolvedValueOnce( + responseFromJson({ + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [ + { + functionCall: { + name: 'getWeather', + args: { city: 'London' }, + }, + }, + ], + }, + }, + ], + }), + ) + .mockResolvedValueOnce( + responseFromJson({ + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [{ text: 'It is 72 degrees.' }], + }, + }, + ], + }), + ); + const chat = genModel.startChat(); + + const result = await chat.sendMessage('weather in London'); + const history = await chat.getHistory(); + + expect(result.response.text()).toBe('It is 72 degrees.'); + expect(getWeather).toHaveBeenCalledWith({ city: 'London' }); + expect(makeRequestStub).toHaveBeenCalledTimes(2); + expect(history).toEqual([ + { + role: 'user', + parts: [{ text: 'weather in London' }], + }, + { + role: 'model', + parts: [ + { + functionCall: { + name: 'getWeather', + args: { city: 'London' }, + }, + }, + ], + }, + { + role: 'function', + parts: [ + { + functionResponse: { + name: 'getWeather', + response: { temperature: 72 }, + }, + }, + ], + }, + { + role: 'model', + parts: [{ text: 'It is 72 degrees.' }], + }, + ]); + makeRequestStub.mockRestore(); + }); + it('passes CodeExecutionTool through to chat.sendMessage', async function () { const genModel = new GenerativeModel(fakeAI, { model: 'my-model', diff --git a/packages/ai/lib/methods/automatic-function-calling.ts b/packages/ai/lib/methods/automatic-function-calling.ts new file mode 100644 index 0000000000..948a37ce2f --- /dev/null +++ b/packages/ai/lib/methods/automatic-function-calling.ts @@ -0,0 +1,128 @@ +/** + * @license + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + Content, + FunctionCall, + FunctionDeclaration, + FunctionResponse, + GenerateContentRequest, + GenerateContentResult, + SingleRequestOptions, + Tool, +} from '../types'; +import { ApiSettings } from '../types/internal'; +import { generateContent } from './generate-content'; + +const DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS = 10; + +export interface AutomaticFunctionCallingResult { + result: GenerateContentResult; + addedContents: Content[]; +} + +export async function generateContentWithAutomaticFunctionCalling( + apiSettings: ApiSettings, + model: string, + params: GenerateContentRequest, + result: GenerateContentResult, + requestOptions?: SingleRequestOptions, +): Promise { + let remainingFunctionCalls = + requestOptions?.maxSequentialFunctionCalls ?? DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS; + let currentParams = params; + let currentResult = result; + const addedContents: Content[] = []; + + while (remainingFunctionCalls > 0) { + const functionCalls = currentResult.response.functionCalls?.(); + if (!functionCalls?.length) { + return { result: currentResult, addedContents }; + } + + const functionResponses = await callFunctionReferences(currentParams.tools, functionCalls); + if (!functionResponses) { + return { result: currentResult, addedContents }; + } + + const responseContent = getModelResponseContent(currentResult); + if (!responseContent) { + return { result: currentResult, addedContents }; + } + + remainingFunctionCalls -= 1; + const functionResponseContent: Content = { + role: 'function', + parts: functionResponses.map(functionResponse => ({ functionResponse })), + }; + addedContents.push(responseContent, functionResponseContent); + currentParams = { + ...currentParams, + contents: [...currentParams.contents, responseContent, functionResponseContent], + }; + currentResult = await generateContent(apiSettings, model, currentParams, requestOptions); + } + + return { result: currentResult, addedContents }; +} + +function getModelResponseContent(result: GenerateContentResult): Content | undefined { + const responseContent = result.response.candidates?.[0]?.content; + if (!responseContent) { + return undefined; + } + return { + parts: responseContent.parts || [], + role: responseContent.role || 'model', + }; +} + +async function callFunctionReferences( + tools: Tool[] | undefined, + functionCalls: FunctionCall[], +): Promise { + const declarations = getFunctionDeclarationsWithReferences(tools); + if (!declarations.length) { + return undefined; + } + + const functionResponses: FunctionResponse[] = []; + for (const functionCall of functionCalls) { + const declaration = declarations.find(candidate => candidate.name === functionCall.name); + if (!declaration?.functionReference) { + return undefined; + } + + const response = (await declaration.functionReference(functionCall.args)) as object; + functionResponses.push({ + id: functionCall.id, + name: functionCall.name, + response, + }); + } + return functionResponses; +} + +function getFunctionDeclarationsWithReferences(tools: Tool[] | undefined): FunctionDeclaration[] { + return ( + tools?.flatMap(tool => + 'functionDeclarations' in tool + ? (tool.functionDeclarations?.filter(declaration => declaration.functionReference) ?? []) + : [], + ) ?? [] + ); +} diff --git a/packages/ai/lib/methods/chat-session.ts b/packages/ai/lib/methods/chat-session.ts index f31c59a9b3..43cd355c20 100644 --- a/packages/ai/lib/methods/chat-session.ts +++ b/packages/ai/lib/methods/chat-session.ts @@ -36,6 +36,7 @@ import { templateGenerateContent, templateGenerateContentStream, } from './generate-content'; +import { generateContentWithAutomaticFunctionCalling } from './automatic-function-calling'; import { ApiSettings } from '../types/internal'; import { logger } from '../logger'; import { mergeRequestOptions } from '../requests/request-options'; @@ -116,17 +117,26 @@ export class ChatSession extends ChatSessionBase { let finalResult = {} as GenerateContentResult; // Add onto the chain. this._sendPromise = this._sendPromise - .then(() => - generateContent( + .then(async () => { + const requestOptions = mergeRequestOptions(this.requestOptions, singleRequestOptions); + const result = await generateContent( this._apiSettings, this.model, generateContentRequest, - mergeRequestOptions(this.requestOptions, singleRequestOptions), - ), - ) - .then((result: GenerateContentResult) => { + requestOptions, + ); + return generateContentWithAutomaticFunctionCalling( + this._apiSettings, + this.model, + generateContentRequest, + result, + requestOptions, + ); + }) + .then(({ result, addedContents }) => { if (result.response.candidates && result.response.candidates.length > 0) { this._history.push(newContent); + this._history.push(...addedContents); const responseContent: Content = { parts: result.response.candidates?.[0]?.content.parts || [], // Response seems to come back without a role set. diff --git a/packages/ai/lib/models/generative-model.ts b/packages/ai/lib/models/generative-model.ts index fad8435fbc..e30f88511d 100644 --- a/packages/ai/lib/models/generative-model.ts +++ b/packages/ai/lib/models/generative-model.ts @@ -20,9 +20,6 @@ import { Content, CountTokensRequest, CountTokensResponse, - FunctionCall, - FunctionDeclaration, - FunctionResponse, GenerateContentRequest, GenerateContentResult, GenerateContentStreamResult, @@ -42,8 +39,7 @@ import { formatGenerateContentInput, formatSystemInstruction } from '../requests import { mergeRequestOptions } from '../requests/request-options'; import { AIModel } from './ai-model'; import { AI } from '../public-types'; - -const DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS = 10; +import { generateContentWithAutomaticFunctionCalling } from '../methods/automatic-function-calling'; /** * Class for generative model APIs. @@ -86,7 +82,15 @@ export class GenerativeModel extends AIModel { }; const requestOptions = mergeRequestOptions(this.requestOptions, singleRequestOptions); const result = await generateContent(this._apiSettings, this.model, params, requestOptions); - return this._generateContentWithAutomaticFunctionCalling(params, result, requestOptions); + return ( + await generateContentWithAutomaticFunctionCalling( + this._apiSettings, + this.model, + params, + result, + requestOptions, + ) + ).result; } /** @@ -155,92 +159,4 @@ export class GenerativeModel extends AIModel { mergeRequestOptions(this.requestOptions, singleRequestOptions), ); } - - private async _generateContentWithAutomaticFunctionCalling( - params: GenerateContentRequest, - result: GenerateContentResult, - requestOptions?: SingleRequestOptions, - ): Promise { - let remainingFunctionCalls = - requestOptions?.maxSequentialFunctionCalls ?? DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS; - let currentParams = params; - let currentResult = result; - - while (remainingFunctionCalls > 0) { - const functionCalls = currentResult.response.functionCalls?.(); - if (!functionCalls?.length) { - return currentResult; - } - - const functionResponses = await this._callFunctionReferences( - currentParams.tools, - functionCalls, - ); - if (!functionResponses) { - return currentResult; - } - - const responseContent = currentResult.response.candidates?.[0]?.content; - if (!responseContent) { - return currentResult; - } - - remainingFunctionCalls -= 1; - currentParams = { - ...currentParams, - contents: [ - ...currentParams.contents, - responseContent, - { - role: 'function', - parts: functionResponses.map(functionResponse => ({ functionResponse })), - }, - ], - }; - currentResult = await generateContent( - this._apiSettings, - this.model, - currentParams, - requestOptions, - ); - } - - return currentResult; - } - - private async _callFunctionReferences( - tools: Tool[] | undefined, - functionCalls: FunctionCall[], - ): Promise { - const declarations = this._getFunctionDeclarationsWithReferences(tools); - if (!declarations.length) { - return undefined; - } - - const functionResponses: FunctionResponse[] = []; - for (const functionCall of functionCalls) { - const declaration = declarations.find(candidate => candidate.name === functionCall.name); - if (!declaration?.functionReference) { - return undefined; - } - - const response = (await declaration.functionReference(functionCall.args)) as object; - functionResponses.push({ - id: functionCall.id, - name: functionCall.name, - response, - }); - } - return functionResponses; - } - - private _getFunctionDeclarationsWithReferences(tools: Tool[] | undefined): FunctionDeclaration[] { - return ( - tools?.flatMap(tool => - 'functionDeclarations' in tool - ? (tool.functionDeclarations?.filter(declaration => declaration.functionReference) ?? []) - : [], - ) ?? [] - ); - } } From 22a741c6ca899486220849c50f42c291753d161c Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 23:53:35 -0500 Subject: [PATCH 14/15] feat(ai): support streaming chat function auto-calling --- packages/ai/__tests__/chat-session.test.ts | 107 +++++++++++++++++- .../lib/methods/automatic-function-calling.ts | 66 ++++++++++- packages/ai/lib/methods/chat-session.ts | 25 +++- 3 files changed, 189 insertions(+), 9 deletions(-) diff --git a/packages/ai/__tests__/chat-session.test.ts b/packages/ai/__tests__/chat-session.test.ts index 8f06b0d033..8701a31a76 100644 --- a/packages/ai/__tests__/chat-session.test.ts +++ b/packages/ai/__tests__/chat-session.test.ts @@ -17,7 +17,7 @@ import { describe, expect, it, afterEach, jest } from '@jest/globals'; import * as generateContentMethods from '../lib/methods/generate-content'; -import { GenerateContentStreamResult } from '../lib/types'; +import { EnhancedGenerateContentResponse, GenerateContentStreamResult } from '../lib/types'; import { ChatSession } from '../lib/methods/chat-session'; import { ApiSettings } from '../lib/types/internal'; import { RequestOptions } from '../lib/types/requests'; @@ -35,6 +35,15 @@ const requestOptions: RequestOptions = { timeout: 1000, }; +function streamResult(response: EnhancedGenerateContentResponse): GenerateContentStreamResult { + return { + stream: (async function* () { + yield response; + })(), + response: Promise.resolve(response), + }; +} + describe('ChatSession', () => { afterEach(() => { jest.restoreAllMocks(); @@ -129,6 +138,102 @@ describe('ChatSession', () => { ); }); + it('automatically calls functionReference from stream function calls', async () => { + const getWeather = jest.fn<(args: object) => object>().mockReturnValue({ temperature: 72 }); + const functionCallResponse = { + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [ + { + functionCall: { + name: 'getWeather', + args: { city: 'London' }, + }, + }, + ], + }, + }, + ], + functionCalls: () => [{ name: 'getWeather', args: { city: 'London' } }], + } as EnhancedGenerateContentResponse; + const finalResponse = { + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [{ text: 'It is 72 degrees.' }], + }, + }, + ], + functionCalls: () => undefined, + } as EnhancedGenerateContentResponse; + const generateContentStreamStub = jest + .spyOn(generateContentMethods, 'generateContentStream') + .mockResolvedValueOnce(streamResult(functionCallResponse)) + .mockResolvedValueOnce(streamResult(finalResponse)); + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + { + tools: [ + { + functionDeclarations: [ + { + name: 'getWeather', + description: 'Gets weather for a city.', + functionReference: getWeather, + }, + ], + }, + ], + }, + requestOptions, + ); + + const result = await chatSession.sendMessageStream('weather in London'); + await result.response; + const history = await chatSession.getHistory(); + + expect(getWeather).toHaveBeenCalledWith({ city: 'London' }); + expect(generateContentStreamStub).toHaveBeenCalledTimes(2); + expect(history).toEqual([ + { + role: 'user', + parts: [{ text: 'weather in London' }], + }, + { + role: 'model', + parts: [ + { + functionCall: { + name: 'getWeather', + args: { city: 'London' }, + }, + }, + ], + }, + { + role: 'function', + parts: [ + { + functionResponse: { + name: 'getWeather', + response: { temperature: 72 }, + }, + }, + ], + }, + { + role: 'model', + parts: [{ text: 'It is 72 degrees.' }], + }, + ]); + }); + it('downstream sendPromise errors should log but not throw', async () => { const consoleStub = jest.spyOn(console, 'error').mockImplementation(() => {}); // make response undefined so that response.candidates errors diff --git a/packages/ai/lib/methods/automatic-function-calling.ts b/packages/ai/lib/methods/automatic-function-calling.ts index 948a37ce2f..e55832d17b 100644 --- a/packages/ai/lib/methods/automatic-function-calling.ts +++ b/packages/ai/lib/methods/automatic-function-calling.ts @@ -22,11 +22,13 @@ import { FunctionResponse, GenerateContentRequest, GenerateContentResult, + GenerateContentStreamResult, + EnhancedGenerateContentResponse, SingleRequestOptions, Tool, } from '../types'; import { ApiSettings } from '../types/internal'; -import { generateContent } from './generate-content'; +import { generateContent, generateContentStream } from './generate-content'; const DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS = 10; @@ -35,6 +37,11 @@ export interface AutomaticFunctionCallingResult { addedContents: Content[]; } +export interface AutomaticFunctionCallingStreamResult { + result: GenerateContentStreamResult; + addedContents: Content[]; +} + export async function generateContentWithAutomaticFunctionCalling( apiSettings: ApiSettings, model: string, @@ -80,8 +87,61 @@ export async function generateContentWithAutomaticFunctionCalling( return { result: currentResult, addedContents }; } -function getModelResponseContent(result: GenerateContentResult): Content | undefined { - const responseContent = result.response.candidates?.[0]?.content; +export async function generateContentStreamWithAutomaticFunctionCalling( + apiSettings: ApiSettings, + model: string, + params: GenerateContentRequest, + result: GenerateContentStreamResult, + requestOptions?: SingleRequestOptions, +): Promise { + if (!getFunctionDeclarationsWithReferences(params.tools).length) { + return { result, addedContents: [] }; + } + + let remainingFunctionCalls = + requestOptions?.maxSequentialFunctionCalls ?? DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS; + let currentParams = params; + let currentResult = result; + const addedContents: Content[] = []; + + while (remainingFunctionCalls > 0) { + const response = await currentResult.response; + const functionCalls = response.functionCalls?.(); + if (!functionCalls?.length) { + return { result: currentResult, addedContents }; + } + + const functionResponses = await callFunctionReferences(currentParams.tools, functionCalls); + if (!functionResponses) { + return { result: currentResult, addedContents }; + } + + const responseContent = getModelResponseContent(response); + if (!responseContent) { + return { result: currentResult, addedContents }; + } + + remainingFunctionCalls -= 1; + const functionResponseContent: Content = { + role: 'function', + parts: functionResponses.map(functionResponse => ({ functionResponse })), + }; + addedContents.push(responseContent, functionResponseContent); + currentParams = { + ...currentParams, + contents: [...currentParams.contents, responseContent, functionResponseContent], + }; + currentResult = await generateContentStream(apiSettings, model, currentParams, requestOptions); + } + + return { result: currentResult, addedContents }; +} + +function getModelResponseContent( + responseOrResult: GenerateContentResult | EnhancedGenerateContentResponse, +): Content | undefined { + const response = 'response' in responseOrResult ? responseOrResult.response : responseOrResult; + const responseContent = response.candidates?.[0]?.content; if (!responseContent) { return undefined; } diff --git a/packages/ai/lib/methods/chat-session.ts b/packages/ai/lib/methods/chat-session.ts index 43cd355c20..7d85038998 100644 --- a/packages/ai/lib/methods/chat-session.ts +++ b/packages/ai/lib/methods/chat-session.ts @@ -36,7 +36,10 @@ import { templateGenerateContent, templateGenerateContentStream, } from './generate-content'; -import { generateContentWithAutomaticFunctionCalling } from './automatic-function-calling'; +import { + generateContentStreamWithAutomaticFunctionCalling, + generateContentWithAutomaticFunctionCalling, +} from './automatic-function-calling'; import { ApiSettings } from '../types/internal'; import { logger } from '../logger'; import { mergeRequestOptions } from '../requests/request-options'; @@ -176,11 +179,20 @@ export class ChatSession extends ChatSessionBase { systemInstruction: this.params?.systemInstruction, contents: [...this._history, newContent], }; + const requestOptions = mergeRequestOptions(this.requestOptions, singleRequestOptions); const streamPromise = generateContentStream( this._apiSettings, this.model, generateContentRequest, - mergeRequestOptions(this.requestOptions, singleRequestOptions), + requestOptions, + ).then(result => + generateContentStreamWithAutomaticFunctionCalling( + this._apiSettings, + this.model, + generateContentRequest, + result, + requestOptions, + ), ); // Add onto the chain. @@ -191,10 +203,13 @@ export class ChatSession extends ChatSessionBase { .catch(_ignored => { throw new Error(SILENT_ERROR); }) - .then(streamResult => streamResult.response) - .then((response: EnhancedGenerateContentResponse) => { + .then(({ result, addedContents }) => + result.response.then(response => ({ response, addedContents })), + ) + .then(({ response, addedContents }) => { if (response.candidates && response.candidates.length > 0) { this._history.push(newContent); + this._history.push(...addedContents); const responseContent = { ...response.candidates[0]?.content }; // Response seems to come back without a role set. if (!responseContent.role) { @@ -220,7 +235,7 @@ export class ChatSession extends ChatSessionBase { logger.error(e); } }); - return streamPromise; + return (await streamPromise).result; } } From 1cb7139482998f06d12cb55ea8fa23b4e298a052 Mon Sep 17 00:00:00 2001 From: Mike Hardy Date: Sun, 17 May 2026 23:57:38 -0500 Subject: [PATCH 15/15] feat(ai): support template chat function auto-calling --- .../template-generative-model.test.ts | 266 ++++++++++++++++++ .../lib/methods/automatic-function-calling.ts | 126 ++++++++- packages/ai/lib/methods/chat-session.ts | 47 +++- 3 files changed, 424 insertions(+), 15 deletions(-) diff --git a/packages/ai/__tests__/template-generative-model.test.ts b/packages/ai/__tests__/template-generative-model.test.ts index 1c973d1e0b..316b61a3a3 100644 --- a/packages/ai/__tests__/template-generative-model.test.ts +++ b/packages/ai/__tests__/template-generative-model.test.ts @@ -21,6 +21,7 @@ import { AI } from '../lib/public-types'; import { VertexAIBackend } from '../lib/backend'; import { TemplateGenerativeModel } from '../lib/models/template-generative-model'; import * as generateContentMethods from '../lib/methods/generate-content'; +import { EnhancedGenerateContentResponse, GenerateContentStreamResult } from '../lib/types'; const fakeAI: AI = { app: { @@ -39,6 +40,15 @@ const fakeAI: AI = { const TEMPLATE_ID = 'my-template'; const TEMPLATE_VARS = { a: 1, b: '2' }; +function streamResult(response: EnhancedGenerateContentResponse): GenerateContentStreamResult { + return { + stream: (async function* () { + yield response; + })(), + response: Promise.resolve(response), + }; +} + describe('TemplateGenerativeModel', function () { afterEach(function () { jest.restoreAllMocks(); @@ -202,6 +212,134 @@ describe('TemplateGenerativeModel', function () { ]); }); + it('automatically calls functionReference from template chat function calls', async function () { + const getWeather = jest.fn<(args: object) => object>().mockReturnValue({ temperature: 72 }); + const functionCallResponse = { + candidates: [ + { + content: { + role: 'model', + parts: [ + { + functionCall: { + name: 'getWeather', + args: { city: 'London' }, + }, + }, + ], + }, + }, + ], + functionCalls: () => [{ name: 'getWeather', args: { city: 'London' } }], + } as EnhancedGenerateContentResponse; + const finalResponse = { + candidates: [ + { + content: { + role: 'model', + parts: [{ text: 'It is 72 degrees.' }], + }, + }, + ], + text: () => 'It is 72 degrees.', + functionCalls: () => undefined, + } as EnhancedGenerateContentResponse; + const templateGenerateContentSpy = jest + .spyOn(generateContentMethods, 'templateGenerateContent') + .mockResolvedValueOnce({ response: functionCallResponse }) + .mockResolvedValueOnce({ response: finalResponse }); + const model = new TemplateGenerativeModel(fakeAI, { timeout: 5000 }); + const chat = model.startChat({ + templateId: TEMPLATE_ID, + templateVariables: TEMPLATE_VARS, + tools: [ + { + functionDeclarations: [ + { + name: 'getWeather', + functionReference: getWeather, + }, + ], + }, + ], + }); + + const result = await chat.sendMessage('weather in London'); + const history = await chat.getHistory(); + + expect(result.response.text()).toBe('It is 72 degrees.'); + expect(getWeather).toHaveBeenCalledWith({ city: 'London' }); + expect(templateGenerateContentSpy).toHaveBeenCalledTimes(2); + expect(templateGenerateContentSpy).toHaveBeenLastCalledWith( + model._apiSettings, + TEMPLATE_ID, + expect.objectContaining({ + inputs: TEMPLATE_VARS, + contents: [ + { + role: 'user', + parts: [{ text: 'weather in London' }], + }, + { + role: 'model', + parts: [ + { + functionCall: { + name: 'getWeather', + args: { city: 'London' }, + }, + }, + ], + }, + { + role: 'function', + parts: [ + { + functionResponse: { + name: 'getWeather', + response: { temperature: 72 }, + }, + }, + ], + }, + ], + }), + { timeout: 5000 }, + ); + expect(history).toEqual([ + { + role: 'user', + parts: [{ text: 'weather in London' }], + }, + { + role: 'model', + parts: [ + { + functionCall: { + name: 'getWeather', + args: { city: 'London' }, + }, + }, + ], + }, + { + role: 'function', + parts: [ + { + functionResponse: { + name: 'getWeather', + response: { temperature: 72 }, + }, + }, + ], + }, + { + role: 'model', + parts: [{ text: 'It is 72 degrees.' }], + }, + ]); + }); + it('should call templateGenerateContentStream with template chat parameters', async function () { const templateGenerateContentStreamSpy = jest .spyOn(generateContentMethods, 'templateGenerateContentStream') @@ -233,5 +371,133 @@ describe('TemplateGenerativeModel', function () { { role: 'model', parts: [{ text: 'stream back' }] }, ]); }); + + it('automatically calls functionReference from streaming template chat function calls', async function () { + const getWeather = jest.fn<(args: object) => object>().mockReturnValue({ temperature: 72 }); + const functionCallResponse = { + candidates: [ + { + content: { + role: 'model', + parts: [ + { + functionCall: { + name: 'getWeather', + args: { city: 'London' }, + }, + }, + ], + }, + }, + ], + functionCalls: () => [{ name: 'getWeather', args: { city: 'London' } }], + } as EnhancedGenerateContentResponse; + const finalResponse = { + candidates: [ + { + content: { + role: 'model', + parts: [{ text: 'It is 72 degrees.' }], + }, + }, + ], + text: () => 'It is 72 degrees.', + functionCalls: () => undefined, + } as EnhancedGenerateContentResponse; + const templateGenerateContentStreamSpy = jest + .spyOn(generateContentMethods, 'templateGenerateContentStream') + .mockResolvedValueOnce(streamResult(functionCallResponse)) + .mockResolvedValueOnce(streamResult(finalResponse)); + const model = new TemplateGenerativeModel(fakeAI, { timeout: 5000 }); + const chat = model.startChat({ + templateId: TEMPLATE_ID, + templateVariables: TEMPLATE_VARS, + tools: [ + { + functionDeclarations: [ + { + name: 'getWeather', + functionReference: getWeather, + }, + ], + }, + ], + }); + + const result = await chat.sendMessageStream('weather in London'); + await result.response; + const history = await chat.getHistory(); + + expect(getWeather).toHaveBeenCalledWith({ city: 'London' }); + expect(templateGenerateContentStreamSpy).toHaveBeenCalledTimes(2); + expect(templateGenerateContentStreamSpy).toHaveBeenLastCalledWith( + model._apiSettings, + TEMPLATE_ID, + expect.objectContaining({ + inputs: TEMPLATE_VARS, + contents: [ + { + role: 'user', + parts: [{ text: 'weather in London' }], + }, + { + role: 'model', + parts: [ + { + functionCall: { + name: 'getWeather', + args: { city: 'London' }, + }, + }, + ], + }, + { + role: 'function', + parts: [ + { + functionResponse: { + name: 'getWeather', + response: { temperature: 72 }, + }, + }, + ], + }, + ], + }), + { timeout: 5000 }, + ); + expect(history).toEqual([ + { + role: 'user', + parts: [{ text: 'weather in London' }], + }, + { + role: 'model', + parts: [ + { + functionCall: { + name: 'getWeather', + args: { city: 'London' }, + }, + }, + ], + }, + { + role: 'function', + parts: [ + { + functionResponse: { + name: 'getWeather', + response: { temperature: 72 }, + }, + }, + ], + }, + { + role: 'model', + parts: [{ text: 'It is 72 degrees.' }], + }, + ]); + }); }); }); diff --git a/packages/ai/lib/methods/automatic-function-calling.ts b/packages/ai/lib/methods/automatic-function-calling.ts index e55832d17b..3c0852a6a5 100644 --- a/packages/ai/lib/methods/automatic-function-calling.ts +++ b/packages/ai/lib/methods/automatic-function-calling.ts @@ -25,10 +25,17 @@ import { GenerateContentStreamResult, EnhancedGenerateContentResponse, SingleRequestOptions, + TemplateFunctionDeclaration, + TemplateTool, Tool, } from '../types'; import { ApiSettings } from '../types/internal'; -import { generateContent, generateContentStream } from './generate-content'; +import { + generateContent, + generateContentStream, + templateGenerateContent, + templateGenerateContentStream, +} from './generate-content'; const DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS = 10; @@ -42,6 +49,12 @@ export interface AutomaticFunctionCallingStreamResult { addedContents: Content[]; } +export interface TemplateAutomaticFunctionCallingRequest { + contents: Content[]; + tools?: TemplateTool[]; + [key: string]: unknown; +} + export async function generateContentWithAutomaticFunctionCalling( apiSettings: ApiSettings, model: string, @@ -87,6 +100,56 @@ export async function generateContentWithAutomaticFunctionCalling( return { result: currentResult, addedContents }; } +export async function templateGenerateContentWithAutomaticFunctionCalling( + apiSettings: ApiSettings, + templateId: string, + params: TemplateAutomaticFunctionCallingRequest, + result: GenerateContentResult, + requestOptions?: SingleRequestOptions, +): Promise { + let remainingFunctionCalls = + requestOptions?.maxSequentialFunctionCalls ?? DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS; + let currentParams = params; + let currentResult = result; + const addedContents: Content[] = []; + + while (remainingFunctionCalls > 0) { + const functionCalls = currentResult.response.functionCalls?.(); + if (!functionCalls?.length) { + return { result: currentResult, addedContents }; + } + + const functionResponses = await callFunctionReferences(currentParams.tools, functionCalls); + if (!functionResponses) { + return { result: currentResult, addedContents }; + } + + const responseContent = getModelResponseContent(currentResult); + if (!responseContent) { + return { result: currentResult, addedContents }; + } + + remainingFunctionCalls -= 1; + const functionResponseContent: Content = { + role: 'function', + parts: functionResponses.map(functionResponse => ({ functionResponse })), + }; + addedContents.push(responseContent, functionResponseContent); + currentParams = { + ...currentParams, + contents: [...currentParams.contents, responseContent, functionResponseContent], + }; + currentResult = await templateGenerateContent( + apiSettings, + templateId, + currentParams, + requestOptions, + ); + } + + return { result: currentResult, addedContents }; +} + export async function generateContentStreamWithAutomaticFunctionCalling( apiSettings: ApiSettings, model: string, @@ -137,6 +200,61 @@ export async function generateContentStreamWithAutomaticFunctionCalling( return { result: currentResult, addedContents }; } +export async function templateGenerateContentStreamWithAutomaticFunctionCalling( + apiSettings: ApiSettings, + templateId: string, + params: TemplateAutomaticFunctionCallingRequest, + result: GenerateContentStreamResult, + requestOptions?: SingleRequestOptions, +): Promise { + if (!getFunctionDeclarationsWithReferences(params.tools).length) { + return { result, addedContents: [] }; + } + + let remainingFunctionCalls = + requestOptions?.maxSequentialFunctionCalls ?? DEFAULT_MAX_SEQUENTIAL_FUNCTION_CALLS; + let currentParams = params; + let currentResult = result; + const addedContents: Content[] = []; + + while (remainingFunctionCalls > 0) { + const response = await currentResult.response; + const functionCalls = response.functionCalls?.(); + if (!functionCalls?.length) { + return { result: currentResult, addedContents }; + } + + const functionResponses = await callFunctionReferences(currentParams.tools, functionCalls); + if (!functionResponses) { + return { result: currentResult, addedContents }; + } + + const responseContent = getModelResponseContent(response); + if (!responseContent) { + return { result: currentResult, addedContents }; + } + + remainingFunctionCalls -= 1; + const functionResponseContent: Content = { + role: 'function', + parts: functionResponses.map(functionResponse => ({ functionResponse })), + }; + addedContents.push(responseContent, functionResponseContent); + currentParams = { + ...currentParams, + contents: [...currentParams.contents, responseContent, functionResponseContent], + }; + currentResult = await templateGenerateContentStream( + apiSettings, + templateId, + currentParams, + requestOptions, + ); + } + + return { result: currentResult, addedContents }; +} + function getModelResponseContent( responseOrResult: GenerateContentResult | EnhancedGenerateContentResponse, ): Content | undefined { @@ -152,7 +270,7 @@ function getModelResponseContent( } async function callFunctionReferences( - tools: Tool[] | undefined, + tools: Tool[] | TemplateTool[] | undefined, functionCalls: FunctionCall[], ): Promise { const declarations = getFunctionDeclarationsWithReferences(tools); @@ -177,7 +295,9 @@ async function callFunctionReferences( return functionResponses; } -function getFunctionDeclarationsWithReferences(tools: Tool[] | undefined): FunctionDeclaration[] { +function getFunctionDeclarationsWithReferences( + tools: Tool[] | TemplateTool[] | undefined, +): Array { return ( tools?.flatMap(tool => 'functionDeclarations' in tool diff --git a/packages/ai/lib/methods/chat-session.ts b/packages/ai/lib/methods/chat-session.ts index 7d85038998..c7579a9e86 100644 --- a/packages/ai/lib/methods/chat-session.ts +++ b/packages/ai/lib/methods/chat-session.ts @@ -25,7 +25,6 @@ import { SingleRequestOptions, StartChatParams, StartTemplateChatParams, - EnhancedGenerateContentResponse, } from '../types'; import { formatNewContent } from '../requests/request-helpers'; import { formatBlockErrorMessage } from '../requests/response-helpers'; @@ -39,6 +38,9 @@ import { import { generateContentStreamWithAutomaticFunctionCalling, generateContentWithAutomaticFunctionCalling, + TemplateAutomaticFunctionCallingRequest, + templateGenerateContentStreamWithAutomaticFunctionCalling, + templateGenerateContentWithAutomaticFunctionCalling, } from './automatic-function-calling'; import { ApiSettings } from '../types/internal'; import { logger } from '../logger'; @@ -271,17 +273,26 @@ export class TemplateChatSession extends ChatSessionBase - templateGenerateContent( + .then(async () => { + const requestOptions = mergeRequestOptions(this.requestOptions, singleRequestOptions); + const result = await templateGenerateContent( this._apiSettings, this.params.templateId, templateParams, - mergeRequestOptions(this.requestOptions, singleRequestOptions), - ), - ) - .then((result: GenerateContentResult) => { + requestOptions, + ); + return templateGenerateContentWithAutomaticFunctionCalling( + this._apiSettings, + this.params.templateId, + templateParams, + result, + requestOptions, + ); + }) + .then(({ result, addedContents }) => { if (result.response.candidates && result.response.candidates.length > 0) { this._history.push(newContent); + this._history.push(...addedContents); const responseContent: Content = { parts: result.response.candidates?.[0]?.content.parts || [], // Response seems to come back without a role set. @@ -314,11 +325,20 @@ export class TemplateChatSession extends ChatSessionBase + templateGenerateContentStreamWithAutomaticFunctionCalling( + this._apiSettings, + this.params.templateId, + templateParams, + result, + requestOptions, + ), ); // Add onto the chain. @@ -329,10 +349,13 @@ export class TemplateChatSession extends ChatSessionBase { throw new Error(SILENT_ERROR); }) - .then(streamResult => streamResult.response) - .then((response: EnhancedGenerateContentResponse) => { + .then(({ result, addedContents }) => + result.response.then(response => ({ response, addedContents })), + ) + .then(({ response, addedContents }) => { if (response.candidates && response.candidates.length > 0) { this._history.push(newContent); + this._history.push(...addedContents); const responseContent = { ...response.candidates[0]?.content }; // Response seems to come back without a role set. if (!responseContent.role) { @@ -358,10 +381,10 @@ export class TemplateChatSession extends ChatSessionBase