From bbdaefed2d921c013c5fb24644f2e2368a768093 Mon Sep 17 00:00:00 2001 From: Fatima AlSaadeh Date: Mon, 17 Nov 2025 12:27:35 +0000 Subject: [PATCH 1/3] fix: maintain formatting and add more context like selection --- examples/superdoc-ai-quickstart/package.json | 2 +- packages/ai/package.json | 2 +- packages/ai/src/ai-actions-service.test.ts | 73 +++++--- packages/ai/src/ai-actions-service.ts | 45 +++-- packages/ai/src/ai-actions.test.ts | 94 +++++++++- packages/ai/src/ai-actions.ts | 131 +++++++++++++- packages/ai/src/editor-adapter.test.ts | 34 ++-- packages/ai/src/editor-adapter.ts | 174 ++++++++++++++++++- packages/ai/src/prompts.ts | 63 +++++-- packages/ai/src/types.ts | 45 +++++ 10 files changed, 571 insertions(+), 92 deletions(-) diff --git a/examples/superdoc-ai-quickstart/package.json b/examples/superdoc-ai-quickstart/package.json index defebf60ab..73ecaf8b13 100644 --- a/examples/superdoc-ai-quickstart/package.json +++ b/examples/superdoc-ai-quickstart/package.json @@ -9,7 +9,7 @@ "preview": "vite preview" }, "dependencies": { - "@superdoc-dev/ai": "^0.1.3", + "@superdoc-dev/ai": "latest", "superdoc": "0.28.0" }, "devDependencies": { diff --git a/packages/ai/package.json b/packages/ai/package.json index 776a566433..f9bb476a88 100644 --- a/packages/ai/package.json +++ b/packages/ai/package.json @@ -25,7 +25,7 @@ "superdoc": "*" }, "devDependencies": { - "superdoc": "^0.28.0", + "superdoc": "^0.29.0", "@types/node": "^20.0.0", "typescript": "^5.0.0", "eslint": "^8.0.0", diff --git a/packages/ai/src/ai-actions-service.test.ts b/packages/ai/src/ai-actions-service.test.ts index 466b4a2117..41eb35e9fe 100644 --- a/packages/ai/src/ai-actions-service.test.ts +++ b/packages/ai/src/ai-actions-service.test.ts @@ -1,6 +1,6 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; import { AIActionsService } from './ai-actions-service'; -import type { AIProvider, Editor } from './types'; +import type { AIProvider, ContextWindow, Editor } from './types'; const createChain = (commands?: any) => { const chainApi = { @@ -40,6 +40,23 @@ const createChain = (commands?: any) => { return { chainFn, chainApi }; }; +const createContextProvider = + (overrides?: Partial) => + () => ({ + scope: 'document', + primaryText: 'Sample document text for testing', + selection: undefined, + documentStats: { + wordCount: 5, + charCount: 32, + }, + metadata: {documentId: 'doc-123'}, + ...overrides, + }); + +const defaultContextProvider = createContextProvider(); +const emptyContextProvider = createContextProvider({primaryText: ''}); + describe('AIActionsService', () => { let mockProvider: AIProvider; let mockEditor: Editor; @@ -117,7 +134,7 @@ describe('AIActionsService', () => { .mockReturnValueOnce([{ from: 0, to: 6 }]) .mockReturnValueOnce([{ from: 7, to: 15 }]); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.find('find sample'); expect(result.success).toBe(true); @@ -130,7 +147,7 @@ describe('AIActionsService', () => { JSON.stringify({ success: false, results: [] }) ); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.find('find nothing'); expect(result.success).toBe(false); @@ -138,14 +155,14 @@ describe('AIActionsService', () => { }); it('should validate input query', async () => { - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); await expect(actions.find('')).rejects.toThrow('Query cannot be empty'); await expect(actions.find(' ')).rejects.toThrow('Query cannot be empty'); }); it('should return empty when no document context', async () => { - const actions = new AIActionsService(mockProvider, mockEditor, () => '', false); + const actions = new AIActionsService(mockProvider, mockEditor, emptyContextProvider, false); const result = await actions.find('query'); expect(result).toEqual({ success: false, results: [] }); @@ -170,7 +187,7 @@ describe('AIActionsService', () => { { from: 20, to: 24 } ]); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.findAll('find all test'); expect(result.success).toBe(true); @@ -188,7 +205,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([{ from: 5, to: 17 }]); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.highlight('highlight this'); expect(result.success).toBe(true); @@ -207,7 +224,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([{ from: 0, to: 4 }]); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); await actions.highlight('highlight', '#FF0000'); expect(chainApi.setHighlight).toHaveBeenCalledWith('#FF0000'); @@ -222,7 +239,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([]); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.highlight('highlight'); expect(result.success).toBe(false); @@ -242,7 +259,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([{ from: 0, to: 3 }]); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.replace('replace old with new'); expect(result.success).toBe(true); @@ -250,7 +267,7 @@ describe('AIActionsService', () => { }); it('should validate input', async () => { - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); await expect(actions.replace('')).rejects.toThrow('Query cannot be empty'); }); @@ -271,7 +288,7 @@ describe('AIActionsService', () => { .mockReturnValueOnce([{ from: 0, to: 3 }]) .mockReturnValueOnce([{ from: 10, to: 13 }]); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.replaceAll('replace all old with new'); expect(result.success).toBe(true); @@ -291,7 +308,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([{ from: 0, to: 8 }]); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.insertTrackedChange('suggest change'); expect(result.success).toBe(true); @@ -315,7 +332,7 @@ describe('AIActionsService', () => { .mockReturnValueOnce([{ from: 0, to: 5 }]) .mockReturnValueOnce([{ from: 10, to: 16 }]); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.insertTrackedChanges('suggest multiple changes'); expect(result.success).toBe(true); @@ -335,7 +352,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([{ from: 0, to: 4 }]); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.insertComment('add comment'); expect(result.success).toBe(true); @@ -360,7 +377,7 @@ describe('AIActionsService', () => { .mockReturnValueOnce([{ from: 0, to: 5 }]) .mockReturnValueOnce([{ from: 10, to: 15 }]); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.insertComments('add multiple comments'); expect(result.success).toBe(true); @@ -378,7 +395,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.summarize('summarize this document'); expect(result.success).toBe(true); @@ -386,7 +403,7 @@ describe('AIActionsService', () => { }); it('should return failure when no document context', async () => { - const actions = new AIActionsService(mockProvider, mockEditor, () => '', false); + const actions = new AIActionsService(mockProvider, mockEditor, emptyContextProvider, false); const result = await actions.summarize('summarize'); expect(result).toEqual({ results: [], success: false }); @@ -409,7 +426,7 @@ describe('AIActionsService', () => { const actions = new AIActionsService( mockProvider, mockEditor, - () => mockEditor.state.doc.textContent, + defaultContextProvider, false, undefined, false @@ -442,7 +459,7 @@ describe('AIActionsService', () => { const actions = new AIActionsService( mockProvider, mockEditor, - () => mockEditor.state.doc.textContent, + defaultContextProvider, false, onStreamChunk, true @@ -467,7 +484,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.insertContent('generate introduction'); expect(result.success).toBe(true); @@ -480,13 +497,13 @@ describe('AIActionsService', () => { }); it('should validate input', async () => { - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); await expect(actions.insertContent('')).rejects.toThrow('Query cannot be empty'); }); it('should return failure when no editor', async () => { - const actions = new AIActionsService(mockProvider, null, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, null, defaultContextProvider, false); const result = await actions.insertContent('insert content'); expect(result).toEqual({ success: false, results: [] }); @@ -500,7 +517,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.insertContent('insert content'); expect(result).toEqual({ success: false, results: [] }); @@ -533,7 +550,7 @@ describe('AIActionsService', () => { const actions = new AIActionsService( mockProvider, mockEditor, - () => mockEditor.state.doc.textContent, + defaultContextProvider, false, undefined, true @@ -565,7 +582,7 @@ describe('AIActionsService', () => { const actions = new AIActionsService( mockProvider, mockEditor, - () => mockEditor.state.doc.textContent, + defaultContextProvider, false, undefined, false @@ -587,7 +604,7 @@ describe('AIActionsService', () => { mockEditor.commands.search = vi.fn().mockReturnValue([{ from: 0, to: 4 }]); // Test with logging disabled - const actions1 = new AIActionsService(mockProvider, mockEditor, () => 'context', false); + const actions1 = new AIActionsService(mockProvider, mockEditor, createContextProvider({primaryText: 'context'}), false); const response1 = JSON.stringify({ success: true, results: [{ originalText: 'test', suggestedText: 'new' }] @@ -611,7 +628,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([]); - const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); + const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); const result = await actions.replace('replace text'); expect(result.results).toHaveLength(0); diff --git a/packages/ai/src/ai-actions-service.ts b/packages/ai/src/ai-actions-service.ts index 7bbbfa85ab..872a1a3e02 100644 --- a/packages/ai/src/ai-actions-service.ts +++ b/packages/ai/src/ai-actions-service.ts @@ -1,4 +1,13 @@ -import {AIProvider, Editor, Result, FoundMatch, DocumentPosition, AIMessage} from './types'; +import { + AIProvider, + Editor, + Result, + FoundMatch, + DocumentPosition, + AIMessage, + ContextWindow, + ContextScope, +} from './types'; import {EditorAdapter} from './editor-adapter'; import {validateInput, parseJSON} from './utils'; import { @@ -19,7 +28,7 @@ export class AIActionsService { constructor( private provider: AIProvider, private editor: Editor | null, - private documentContextProvider: () => string, + private contextProvider: (scope?: ContextScope) => ContextWindow, private enableLogging: boolean = false, private onStreamChunk?: (partialResult: string) => void, private streamPreference?: boolean, @@ -34,20 +43,20 @@ export class AIActionsService { } } - private getDocumentContext(): string { - if (!this.documentContextProvider) { - return ''; + private getContext(scope?: ContextScope): ContextWindow { + if (!this.contextProvider) { + return {scope: 'document', primaryText: ''}; } try { - return this.documentContextProvider(); + return this.contextProvider(scope); } catch (error) { if (this.enableLogging) { console.error( `Failed to retrieve document context: ${error instanceof Error ? error.message : 'Unknown error'}` ); } - return ''; + return {scope: 'document', primaryText: ''}; } } @@ -63,13 +72,13 @@ export class AIActionsService { throw new Error('Query cannot be empty'); } - const documentContext = this.getDocumentContext(); + const context = this.getContext('document'); - if (!documentContext) { + if (!context.primaryText?.trim()) { return {success: false, results: []}; } - const prompt = buildFindPrompt(query, documentContext, findAll); + const prompt = buildFindPrompt(query, context, findAll); const response = await this.runCompletion([ {role: 'system', content: SYSTEM_PROMPTS.SEARCH}, {role: 'user', content: prompt}, @@ -159,14 +168,14 @@ export class AIActionsService { multiple: boolean, operationFn: (adapter: EditorAdapter, position: DocumentPosition, replacement: FoundMatch) => Promise ): Promise { - const documentContext = this.getDocumentContext(); + const context = this.getContext('document'); - if (!documentContext) { + if (!context.primaryText?.trim()) { return []; } // Get AI query - const prompt = buildReplacePrompt(query, documentContext, multiple); + const prompt = buildReplacePrompt(query, context, multiple); const response = await this.runCompletion([ {role: 'system', content: SYSTEM_PROMPTS.EDIT}, {role: 'user', content: prompt}, @@ -355,12 +364,12 @@ export class AIActionsService { * Generates a summary of the document. */ async summarize(query: string): Promise { - const documentContext = this.getDocumentContext(); + const context = this.getContext('document'); - if (!documentContext) { + if (!context.primaryText?.trim()) { return {results: [], success: false}; } - const prompt = buildSummaryPrompt(query, documentContext); + const prompt = buildSummaryPrompt(query, context); const useStreaming = this.streamPreference !== false; let streamedLength = 0; @@ -399,8 +408,8 @@ export class AIActionsService { return {success: false, results: []}; } - const documentContext = this.getDocumentContext(); - const prompt = buildInsertContentPrompt(query, documentContext); + const context = this.getContext(); + const prompt = buildInsertContentPrompt(query, context); const useStreaming = this.streamPreference !== false; let streamingInsertedLength = 0; diff --git a/packages/ai/src/ai-actions.test.ts b/packages/ai/src/ai-actions.test.ts index 596d015e85..9cf61a7353 100644 --- a/packages/ai/src/ai-actions.test.ts +++ b/packages/ai/src/ai-actions.test.ts @@ -1,13 +1,19 @@ -import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import {TextSelection} from 'prosemirror-state'; import { AIActions } from './ai-actions'; +import {EditorAdapter} from './editor-adapter'; import type { AIProvider, AIActionsOptions, SuperDoc, Editor } from './types'; describe('AIActions', () => { let mockProvider: AIProvider; let mockEditor: Editor; let mockSuperdoc: SuperDoc; + let textSelectionSpy: any; + let scrollSpy: any; beforeEach(() => { + textSelectionSpy = vi.spyOn(TextSelection, 'create').mockReturnValue({} as any); + scrollSpy = vi.spyOn(EditorAdapter.prototype as any, 'scrollToPosition').mockImplementation(() => {}); mockProvider = { async *streamCompletion(messages, options) { yield 'chunk1'; @@ -21,15 +27,36 @@ describe('AIActions', () => { mockEditor = { state: { + selection: { + from: 0, + to: 6, + empty: false, + }, doc: { textContent: 'Sample document text', content: { size: 100 }, - resolve: vi.fn((pos) => ({ - pos, + textBetween: vi.fn(() => 'Sample'), + resolve: vi.fn((pos) => ({ + pos, + depth: 0, parent: { inlineContent: true }, - min: vi.fn(() => pos), - max: vi.fn(() => pos) - })) + nodeAfter: null, + nodeBefore: null, + start: vi.fn(() => pos), + end: vi.fn(() => pos), + })), + descendants: vi.fn((cb) => { + cb( + { + isBlock: true, + textContent: 'Sample paragraph text', + type: {name: 'paragraph'}, + attrs: {}, + nodeSize: 20, + }, + 0, + ); + }), }, tr: { setSelection: vi.fn().mockReturnThis(), @@ -64,6 +91,11 @@ describe('AIActions', () => { } as any; }); + afterEach(() => { + textSelectionSpy?.mockRestore(); + scrollSpy?.mockRestore(); + }); + describe('constructor', () => { it('should initialize with provider config', async () => { const options: AIActionsOptions = { @@ -261,6 +293,56 @@ describe('AIActions', () => { }); }); + describe('context window', () => { + it('should expose selection-scoped context by default', async () => { + const ai = new AIActions(mockSuperdoc, { + user: {displayName: 'AI Bot'}, + provider: mockProvider, + }); + await ai.waitUntilReady(); + + const contextWindow = ai.getContextWindow(); + expect(contextWindow.scope).toBe('selection'); + expect(contextWindow.primaryText).toBeTruthy(); + }); + + it('should strip internal context options before calling providers', async () => { + const ai = new AIActions(mockSuperdoc, { + user: {displayName: 'AI Bot'}, + provider: mockProvider, + }); + await ai.waitUntilReady(); + + const streamSpy = vi.spyOn(mockProvider, 'streamCompletion'); + await ai.streamCompletion('prompt', {contextScope: 'document', contextPaddingBlocks: 2}); + + expect(streamSpy).toHaveBeenCalledWith( + expect.any(Array), + expect.not.objectContaining({ + contextScope: expect.anything(), + }), + ); + }); + + it('should not forward context overrides for non-streaming completions', async () => { + const ai = new AIActions(mockSuperdoc, { + user: {displayName: 'AI Bot'}, + provider: mockProvider, + }); + await ai.waitUntilReady(); + + const completionSpy = vi.spyOn(mockProvider, 'getCompletion'); + await ai.getCompletion('prompt', {contextScope: 'document'}); + + expect(completionSpy).toHaveBeenCalledWith( + expect.any(Array), + expect.not.objectContaining({ + contextScope: expect.anything(), + }), + ); + }); + }); + describe('streamCompletion', () => { it('should stream completion chunks', async () => { const onStreamingStart = vi.fn(); diff --git a/packages/ai/src/ai-actions.ts b/packages/ai/src/ai-actions.ts index dab2549c2c..0543ef1a49 100644 --- a/packages/ai/src/ai-actions.ts +++ b/packages/ai/src/ai-actions.ts @@ -1,5 +1,8 @@ import type { CompletionOptions, + ContextScope, + ContextWindow, + ContextWindowConfig, Editor, Result, StreamOptions, @@ -11,6 +14,8 @@ import type { } from './types'; import {AIActionsService} from './ai-actions-service'; import {createAIProvider, isAIProvider} from './providers'; +import {EditorAdapter} from './editor-adapter'; +import {formatContextWindow} from './prompts'; /** * Primary entry point for SuperDoc AI capabilities. Wraps a SuperDoc instance, @@ -45,6 +50,10 @@ export class AIActions { private isReady = false; private initializationPromise: Promise | null = null; private readonly commands: AIActionsService; + private readonly contextWindowConfig: { + paddingBlocks: number; + maxChars: number; + }; public readonly action = { find: async (instruction: string) => { @@ -106,6 +115,12 @@ export class AIActions { provider: aiProvider, }; + const contextWindowDefaults: ContextWindowConfig | undefined = this.config.contextWindow; + this.contextWindowConfig = { + paddingBlocks: Math.max(0, contextWindowDefaults?.paddingBlocks ?? 1), + maxChars: Math.max(200, contextWindowDefaults?.maxChars ?? 2000), + }; + this.callbacks = { onReady, onStreamingStart, @@ -130,7 +145,7 @@ export class AIActions { this.commands = new AIActionsService( this.config.provider, editor, - () => this.getDocumentContext(), + (scope) => this.getContextWindow({scope}), this.config.enableLogging, (partial) => this.callbacks.onStreamingPartialResult?.({partialResult: partial}), streamResults, @@ -242,8 +257,11 @@ export class AIActions { throw new Error('AIActions is not ready yet. Call waitUntilReady() first.'); } - const documentContext = this.getDocumentContext(); - const userContent = documentContext ? `${prompt}\n\nDocument context:\n${documentContext}` : prompt; + const context = this.getContextWindow({ + scope: options?.contextScope, + paddingBlocks: options?.contextPaddingBlocks, + }); + const userContent = this.buildPromptWithContext(prompt, context); const messages = [ {role: 'system' as const, content: this.config.systemPrompt || ''}, @@ -251,11 +269,16 @@ export class AIActions { ]; let accumulated = ''; + const providerOptions = options ? {...options} : undefined; + if (providerOptions) { + delete (providerOptions as Partial).contextScope; + delete (providerOptions as Partial).contextPaddingBlocks; + } try { this.callbacks.onStreamingStart?.(); - const stream = this.config.provider.streamCompletion(messages, options); + const stream = this.config.provider.streamCompletion(messages, providerOptions); for await (const chunk of stream) { accumulated += chunk; @@ -284,16 +307,25 @@ export class AIActions { throw new Error('AIActions is not ready yet. Call waitUntilReady() first.'); } - const documentContext = this.getDocumentContext(); - const userContent = documentContext ? `${prompt}\n\nDocument context:\n${documentContext}` : prompt; + const context = this.getContextWindow({ + scope: options?.contextScope, + paddingBlocks: options?.contextPaddingBlocks, + }); + const userContent = this.buildPromptWithContext(prompt, context); const messages = [ {role: 'system' as const, content: this.config.systemPrompt || ''}, {role: 'user' as const, content: userContent}, ]; + const providerOptions = options ? {...options} : undefined; + if (providerOptions) { + delete (providerOptions as Partial).contextScope; + delete (providerOptions as Partial).contextPaddingBlocks; + } + try { - return await this.config.provider.getCompletion(messages, options); + return await this.config.provider.getCompletion(messages, providerOptions); } catch (error) { this.handleError(error as Error); throw error; @@ -307,12 +339,95 @@ export class AIActions { * @returns Document context string */ public getDocumentContext(): string { + return this.getContextWindow({scope: 'document'}).primaryText; + } + + /** + * Returns a scoped context window summarizing the current selection and neighbors. + */ + public getContextWindow(options?: {scope?: ContextScope; paddingBlocks?: number}): ContextWindow { + const rawWindow = this.buildContextWindow(options); + return this.applyContextConstraints(rawWindow); + } + + private buildContextWindow(options?: {scope?: ContextScope; paddingBlocks?: number}): ContextWindow { const editor = this.getEditor(); + const scope = options?.scope; + if (!editor) { + return { + scope: scope ?? 'document', + primaryText: '', + }; + } + + const adapter = new EditorAdapter(editor); + const paddingBlocks = this.resolvePaddingBlocks(options?.paddingBlocks); + + return adapter.getContextWindow(paddingBlocks, scope); + } + + private resolvePaddingBlocks(padding?: number): number { + if (typeof padding === 'number' && padding >= 0) { + return padding; + } + + return this.contextWindowConfig.paddingBlocks; + } + + private applyContextConstraints(context: ContextWindow): ContextWindow { + const clamp = (value?: string): string | undefined => { + if (!value) { + return value; + } + + const limit = this.contextWindowConfig.maxChars; + if (!limit || value.length <= limit) { + return value; + } + + return `${value.slice(0, limit)}...`; + }; + + const selection = context.selection + ? { + ...context.selection, + text: clamp(context.selection.text) ?? '', + block: context.selection.block + ? { + ...context.selection.block, + text: clamp(context.selection.block.text) ?? '', + } + : undefined, + surroundingBlocks: (context.selection.surroundingBlocks || []).map((block) => ({ + ...block, + text: clamp(block.text) ?? '', + })), + } + : undefined; + + return { + ...context, + primaryText: clamp(context.primaryText) ?? '', + selection, + }; + } + + private buildPromptWithContext(prompt: string, context: ContextWindow): string { + const formattedContext = this.serializeContextWindow(context); + if (!formattedContext) { + return prompt; + } + + return `${prompt}\n\nContext window:\n${formattedContext}`; + } + + private serializeContextWindow(context: ContextWindow): string { + if (!context.primaryText?.trim()) { return ''; } - return editor.state?.doc?.textContent?.trim() || ''; + return formatContextWindow(context); } /** diff --git a/packages/ai/src/editor-adapter.test.ts b/packages/ai/src/editor-adapter.test.ts index 6d3885b852..41d734185a 100644 --- a/packages/ai/src/editor-adapter.test.ts +++ b/packages/ai/src/editor-adapter.test.ts @@ -195,14 +195,18 @@ describe('EditorAdapter', () => { await mockAdapter.replaceText(0, 5, 'hello'); expect(mockEditor.commands.setTextSelection).toHaveBeenCalledWith({ from: 0, to: 5 }); + expect(mockEditor.commands.getSelectionMarks).toHaveBeenCalled(); expect(mockEditor.commands.deleteSelection).toHaveBeenCalled(); expect(mockEditor.commands.insertContent).toHaveBeenCalledWith({ - type: 'text', - text: 'hello', - marks: [ - { type: 'bold', attrs: {} }, - { type: 'textStyle', attrs: { fontSize: '14pt' } } - ] + type: 'paragraph', + content: [{ + type: 'text', + text: 'hello', + marks: [ + { type: 'bold', attrs: {} }, + { type: 'textStyle', attrs: { fontSize: '14pt' } } + ] + }] }); }); @@ -211,19 +215,25 @@ describe('EditorAdapter', () => { await mockAdapter.replaceText(0, 5, 'hello'); + expect(mockEditor.commands.setTextSelection).toHaveBeenCalledWith({ from: 0, to: 5 }); + expect(mockEditor.commands.getSelectionMarks).toHaveBeenCalled(); + expect(mockEditor.commands.deleteSelection).toHaveBeenCalled(); expect(mockEditor.commands.insertContent).toHaveBeenCalledWith('hello'); }); }); describe('createTrackedChange', () => { - it('should create tracked change with author', async () => { - + it('should create tracked change without marks', async () => { mockEditor.commands.getSelectionMarks = vi.fn().mockReturnValue([]); - const changeId = await mockAdapter.createTrackedChange(0, 5, 'new'); + const changeId = await mockAdapter.createTrackedChange(0, 5, 'new'); expect(changeId).toMatch(/^tracked-change-/); expect(mockEditor.commands.enableTrackChanges).toHaveBeenCalled(); + expect(mockEditor.commands.setTextSelection).toHaveBeenCalledWith({ from: 0, to: 5 }); + expect(mockEditor.commands.getSelectionMarks).toHaveBeenCalled(); + expect(mockEditor.commands.deleteSelection).toHaveBeenCalled(); + expect(mockEditor.commands.insertContent).toHaveBeenCalledWith('new'); expect(mockEditor.commands.disableTrackChanges).toHaveBeenCalled(); }); @@ -236,13 +246,17 @@ describe('EditorAdapter', () => { await mockAdapter.createTrackedChange(0, 5, 'new'); + expect(mockEditor.commands.enableTrackChanges).toHaveBeenCalled(); + expect(mockEditor.commands.setTextSelection).toHaveBeenCalledWith({ from: 0, to: 5 }); + expect(mockEditor.commands.getSelectionMarks).toHaveBeenCalled(); + expect(mockEditor.commands.deleteSelection).toHaveBeenCalled(); expect(mockEditor.commands.insertContent).toHaveBeenCalledWith({ type: 'text', text: 'new', marks: [{ type: 'italic', attrs: {} }] }); + expect(mockEditor.commands.disableTrackChanges).toHaveBeenCalled(); }); - }); describe('createComment', () => { diff --git a/packages/ai/src/editor-adapter.ts b/packages/ai/src/editor-adapter.ts index 1dc1ec20b5..a2385dd2f1 100644 --- a/packages/ai/src/editor-adapter.ts +++ b/packages/ai/src/editor-adapter.ts @@ -1,7 +1,20 @@ -import type { Editor, FoundMatch, MarkType } from './types'; -import {generateId} from "./utils"; +import type { + ContextBlock, + ContextScope, + ContextWindow, + Editor, + FoundMatch, + MarkType, + SelectionContext, +} from './types'; +import { generateId } from './utils'; import { TextSelection } from 'prosemirror-state'; +type TemplateNode = { + marks: MarkType[]; + length: number; +}; + /** * Adapter for SuperDoc editor operations * Encapsulates all editor-specific API calls @@ -59,13 +72,17 @@ export class EditorAdapter { this.editor.commands.deleteSelection(); if (marks.length > 0) { this.editor.commands.insertContent({ - type: 'text', - text: suggestedText, - marks: marks.map((mark: MarkType) => ({ - type: mark.type.name, - attrs: mark.attrs, - })), - }); + type: 'paragraph', + content: [{ + type: 'text', + text: suggestedText, + marks: marks.map((mark: MarkType) => ({ + type: mark.type.name, + attrs: mark.attrs, + })), + }], + }); + } else { this.editor.commands.insertContent(suggestedText); } @@ -133,4 +150,143 @@ export class EditorAdapter { })), }); } + + + getContextWindow(paddingBlocks: number = 1, scopeOverride?: ContextScope): ContextWindow { + const state = this.editor?.state; + const doc = state?.doc; + + if (!state || !doc) { + return { + scope: 'document', + primaryText: '', + }; + } + + const selection = state.selection; + const blocks = this.collectBlocks(); + const selectionInfo = this.buildSelectionContext(blocks, paddingBlocks); + const documentText = doc.textContent?.trim() ?? ''; + + const hasSelection = Boolean(selection && !selection.empty && selectionInfo?.text !== undefined); + const derivedPrimaryText = hasSelection + ? (selectionInfo?.text ?? '') + : (selectionInfo?.block?.text ?? documentText); + const primaryText = scopeOverride === 'document' ? documentText : derivedPrimaryText; + + const scope: ContextScope = scopeOverride + ? scopeOverride + : hasSelection + ? 'selection' + : selectionInfo?.block + ? 'block' + : 'document'; + + const metadata: Record = { + documentId: (this.editor.options as any)?.documentId, + }; + + return { + scope, + primaryText: primaryText ?? '', + selection: scopeOverride === 'document' ? undefined : (selectionInfo ?? undefined), + documentStats: { + wordCount: documentText ? documentText.split(/\s+/).filter(Boolean).length : 0, + charCount: documentText.length, + }, + metadata, + }; + } + + private collectBlocks(): ContextBlock[] { + const {doc} = this.editor.state; + const blocks: ContextBlock[] = []; + doc.descendants((node: any, pos: any) => { + if (!node.isBlock) { + return true; + } + + const text = node.textContent?.trim() ?? ''; + + blocks.push({ + type: node.type.name, + text, + from: pos, + to: pos + node.nodeSize, + attrs: node.attrs, + headingLevel: typeof node.attrs?.level === 'number' ? node.attrs.level : undefined, + title: typeof node.attrs?.title === 'string' ? node.attrs.title : undefined, + }); + + return true; + }); + + return blocks; + } + + private buildSelectionContext( + blocks: ContextBlock[], + paddingBlocks: number + ): SelectionContext | null { + if (!this.editor.state?.selection) { + return null; + } + + const selection = this.editor.state.selection; + const doc = this.editor.state.doc; + const text = + selection && !selection.empty ? doc.textBetween(selection.from, selection.to, '\n\n', '\n\n') : ''; + + const blockIndex = this.findBlockIndex(blocks, selection.from); + const block = blockIndex >= 0 ? blocks[blockIndex] : undefined; + const surroundingBlocks = blockIndex >= 0 ? this.getNeighborBlocks(blocks, blockIndex, paddingBlocks) : []; + const activeMarks = (this.editor.commands.getSelectionMarks?.() as MarkType[]) || []; + + const metadata: Record = {}; + if (block?.title) { + metadata.clauseTitle = block.title; + } + if (typeof block?.headingLevel === 'number') { + metadata.headingLevel = block.headingLevel; + } + + if (!text && !block) { + return null; + } + + return { + from: selection.from, + to: selection.to, + text, + normalizedText: text?.trim?.(), + block, + surroundingBlocks, + activeMarks: activeMarks.map((mark) => ({ + type: mark.type.name, + attrs: mark.attrs, + })), + metadata: Object.keys(metadata).length ? metadata : undefined, + }; + } + + private findBlockIndex(blocks: ContextBlock[], position: number): number { + return blocks.findIndex((block) => position >= block.from && position < block.to); + } + + private getNeighborBlocks(blocks: ContextBlock[], index: number, paddingBlocks: number): ContextBlock[] { + if (index < 0) { + return []; + } + + const neighbors: ContextBlock[] = []; + for (let i = Math.max(0, index - paddingBlocks); i <= Math.min(blocks.length - 1, index + paddingBlocks); i++) { + if (i === index) { + continue; + } + + neighbors.push(blocks[i]); + } + + return neighbors; + } } diff --git a/packages/ai/src/prompts.ts b/packages/ai/src/prompts.ts index 1a2def7760..348768144a 100644 --- a/packages/ai/src/prompts.ts +++ b/packages/ai/src/prompts.ts @@ -1,3 +1,5 @@ +import type {ContextBlock, ContextWindow} from './types'; + /** * AI prompt templates for document operations */ @@ -9,12 +11,50 @@ export const SYSTEM_PROMPTS = { CONTENT_GENERATION: 'You are a document content generation assistant. Always respond with valid JSON.', } as const; -export const buildFindPrompt = (query: string, documentContext: string, findAll: boolean): string => { +const describeBlock = (block: ContextBlock): string => { + const descriptor: string[] = [block.type]; + if (block.title) { + descriptor.push(`"${block.title}"`); + } + if (typeof block.headingLevel === 'number') { + descriptor.push(`(level ${block.headingLevel})`); + } + return `${descriptor.join(' ')}:\n${block.text}`; +}; + +export const formatContextWindow = (context: ContextWindow): string => { + const segments: string[] = [`Scope: ${context.scope}`]; + + if (context.selection?.text) { + segments.push(`Selected text:\n${context.selection.text}`); + } + + if (context.selection?.block?.text) { + segments.push(`Active block:\n${describeBlock(context.selection.block)}`); + } + + if (context.selection?.surroundingBlocks?.length) { + const nearby = context.selection.surroundingBlocks.map((block) => describeBlock(block)).join('\n---\n'); + segments.push(`Surrounding blocks:\n${nearby}`); + } + + if (!context.selection?.text && context.primaryText) { + segments.push(`Primary text:\n${context.primaryText}`); + } + + if (context.metadata?.documentId) { + segments.push(`Document ID: ${context.metadata.documentId}`); + } + + return segments.filter(Boolean).join('\n\n'); +}; + +export const buildFindPrompt = (query: string, context: ContextWindow, findAll: boolean): string => { const scope = findAll ? 'ALL occurrences' : 'FIRST occurrence ONLY'; return `apply this query for original text return the EXACT text from the doc no title or added text: ${query}, ${scope} - Document context: - ${documentContext} + Context window: + ${formatContextWindow(context)} Respond with JSON: { @@ -26,7 +66,7 @@ export const buildFindPrompt = (query: string, documentContext: string, findAll: }`; }; -export const buildReplacePrompt = (query: string, documentContext: string, replaceAll: boolean): string => { +export const buildReplacePrompt = (query: string, context: ContextWindow, replaceAll: boolean): string => { const scope = replaceAll ? 'ALL occurrences' : 'FIRST occurrence ONLY'; @@ -34,8 +74,8 @@ export const buildReplacePrompt = (query: string, documentContext: string, repla const finalQuery = `apply this query: ${query} if find and replace query then Find and replace the EXACT text of ${scope}`; return `${finalQuery} - Document context: - ${documentContext} + Context window: + ${formatContextWindow(context)} Respond with JSON: { @@ -47,10 +87,10 @@ export const buildReplacePrompt = (query: string, documentContext: string, repla }`; }; -export const buildSummaryPrompt = (query: string, documentContext: string): string => { +export const buildSummaryPrompt = (query: string, context: ContextWindow): string => { return `${query} - Document context: - ${documentContext} + Context window: + ${formatContextWindow(context)} Respond with JSON: { @@ -62,9 +102,10 @@ export const buildSummaryPrompt = (query: string, documentContext: string): stri }`; }; -export const buildInsertContentPrompt = (query: string, documentContext?: string): string => { +export const buildInsertContentPrompt = (query: string, context: ContextWindow): string => { return `${query} - ${documentContext ? `Current document:\n${documentContext}\n` : ''} + Context window: + ${formatContextWindow(context)} Respond with JSON: { "success": boolean, "results": [ { "suggestedText": string, diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts index caad28469b..7812eac4cc 100644 --- a/packages/ai/src/types.ts +++ b/packages/ai/src/types.ts @@ -6,6 +6,45 @@ export type NodeType = Node; export type Editor = InstanceType; export type SuperDocInstance = typeof SuperDoc | SuperDoc; +export type ContextScope = 'selection' | 'block' | 'document'; + +export type ContextBlock = { + type: string; + text: string; + from: number; + to: number; + attrs?: Record; + headingLevel?: number; + title?: string; +}; + +export type SelectionContext = { + from: number; + to: number; + text: string; + normalizedText?: string; + block?: ContextBlock; + surroundingBlocks: ContextBlock[]; + activeMarks?: {type: string; attrs?: Record}[]; + metadata?: Record; +}; + +export type ContextWindow = { + scope: ContextScope; + primaryText: string; + selection?: SelectionContext; + documentStats?: { + wordCount: number; + charCount: number; + }; + metadata?: Record; +}; + +export type ContextWindowConfig = { + paddingBlocks?: number; + maxChars?: number; +}; + /** * Represents a position range in the document */ @@ -61,6 +100,10 @@ export type StreamOptions = { documentId?: string; /** Force streaming (true) or disable it (false). Defaults to true when supported. */ stream?: boolean; + /** Override the context scope used when building prompts */ + contextScope?: ContextScope; + /** Override the number of surrounding blocks included in the context */ + contextPaddingBlocks?: number; } /** @@ -104,6 +147,8 @@ export type AIActionsConfig = { systemPrompt?: string; /** Enable debug logging */ enableLogging?: boolean; + /** Context window configuration */ + contextWindow?: ContextWindowConfig; } /** From 9184bf4bb7432fb7f52e614b33f4465c4ae9f55d Mon Sep 17 00:00:00 2001 From: Fatima AlSaadeh Date: Mon, 17 Nov 2025 13:16:16 +0000 Subject: [PATCH 2/3] fix: consider the selection context --- packages/ai/src/ai-actions-service.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/ai/src/ai-actions-service.ts b/packages/ai/src/ai-actions-service.ts index 872a1a3e02..f190858425 100644 --- a/packages/ai/src/ai-actions-service.ts +++ b/packages/ai/src/ai-actions-service.ts @@ -72,7 +72,7 @@ export class AIActionsService { throw new Error('Query cannot be empty'); } - const context = this.getContext('document'); + const context = this.getContext(); if (!context.primaryText?.trim()) { return {success: false, results: []}; @@ -168,7 +168,7 @@ export class AIActionsService { multiple: boolean, operationFn: (adapter: EditorAdapter, position: DocumentPosition, replacement: FoundMatch) => Promise ): Promise { - const context = this.getContext('document'); + const context = this.getContext(); if (!context.primaryText?.trim()) { return []; From 946e751ac2c2995098de1a1cb4838182a29015f7 Mon Sep 17 00:00:00 2001 From: Fatima AlSaadeh Date: Mon, 17 Nov 2025 15:15:07 +0000 Subject: [PATCH 3/3] feat: add selected context to each ai action --- packages/ai/src/ai-actions-service.test.ts | 73 +++----- packages/ai/src/ai-actions-service.ts | 74 ++++---- packages/ai/src/ai-actions.test.ts | 94 +--------- packages/ai/src/ai-actions.ts | 185 +++++++------------- packages/ai/src/editor-adapter.test.ts | 34 ++-- packages/ai/src/editor-adapter.ts | 190 +++------------------ packages/ai/src/prompts.ts | 63 ++----- packages/ai/src/types.ts | 45 ----- 8 files changed, 185 insertions(+), 573 deletions(-) diff --git a/packages/ai/src/ai-actions-service.test.ts b/packages/ai/src/ai-actions-service.test.ts index 41eb35e9fe..466b4a2117 100644 --- a/packages/ai/src/ai-actions-service.test.ts +++ b/packages/ai/src/ai-actions-service.test.ts @@ -1,6 +1,6 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; import { AIActionsService } from './ai-actions-service'; -import type { AIProvider, ContextWindow, Editor } from './types'; +import type { AIProvider, Editor } from './types'; const createChain = (commands?: any) => { const chainApi = { @@ -40,23 +40,6 @@ const createChain = (commands?: any) => { return { chainFn, chainApi }; }; -const createContextProvider = - (overrides?: Partial) => - () => ({ - scope: 'document', - primaryText: 'Sample document text for testing', - selection: undefined, - documentStats: { - wordCount: 5, - charCount: 32, - }, - metadata: {documentId: 'doc-123'}, - ...overrides, - }); - -const defaultContextProvider = createContextProvider(); -const emptyContextProvider = createContextProvider({primaryText: ''}); - describe('AIActionsService', () => { let mockProvider: AIProvider; let mockEditor: Editor; @@ -134,7 +117,7 @@ describe('AIActionsService', () => { .mockReturnValueOnce([{ from: 0, to: 6 }]) .mockReturnValueOnce([{ from: 7, to: 15 }]); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.find('find sample'); expect(result.success).toBe(true); @@ -147,7 +130,7 @@ describe('AIActionsService', () => { JSON.stringify({ success: false, results: [] }) ); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.find('find nothing'); expect(result.success).toBe(false); @@ -155,14 +138,14 @@ describe('AIActionsService', () => { }); it('should validate input query', async () => { - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); await expect(actions.find('')).rejects.toThrow('Query cannot be empty'); await expect(actions.find(' ')).rejects.toThrow('Query cannot be empty'); }); it('should return empty when no document context', async () => { - const actions = new AIActionsService(mockProvider, mockEditor, emptyContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => '', false); const result = await actions.find('query'); expect(result).toEqual({ success: false, results: [] }); @@ -187,7 +170,7 @@ describe('AIActionsService', () => { { from: 20, to: 24 } ]); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.findAll('find all test'); expect(result.success).toBe(true); @@ -205,7 +188,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([{ from: 5, to: 17 }]); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.highlight('highlight this'); expect(result.success).toBe(true); @@ -224,7 +207,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([{ from: 0, to: 4 }]); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); await actions.highlight('highlight', '#FF0000'); expect(chainApi.setHighlight).toHaveBeenCalledWith('#FF0000'); @@ -239,7 +222,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([]); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.highlight('highlight'); expect(result.success).toBe(false); @@ -259,7 +242,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([{ from: 0, to: 3 }]); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.replace('replace old with new'); expect(result.success).toBe(true); @@ -267,7 +250,7 @@ describe('AIActionsService', () => { }); it('should validate input', async () => { - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); await expect(actions.replace('')).rejects.toThrow('Query cannot be empty'); }); @@ -288,7 +271,7 @@ describe('AIActionsService', () => { .mockReturnValueOnce([{ from: 0, to: 3 }]) .mockReturnValueOnce([{ from: 10, to: 13 }]); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.replaceAll('replace all old with new'); expect(result.success).toBe(true); @@ -308,7 +291,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([{ from: 0, to: 8 }]); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.insertTrackedChange('suggest change'); expect(result.success).toBe(true); @@ -332,7 +315,7 @@ describe('AIActionsService', () => { .mockReturnValueOnce([{ from: 0, to: 5 }]) .mockReturnValueOnce([{ from: 10, to: 16 }]); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.insertTrackedChanges('suggest multiple changes'); expect(result.success).toBe(true); @@ -352,7 +335,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([{ from: 0, to: 4 }]); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.insertComment('add comment'); expect(result.success).toBe(true); @@ -377,7 +360,7 @@ describe('AIActionsService', () => { .mockReturnValueOnce([{ from: 0, to: 5 }]) .mockReturnValueOnce([{ from: 10, to: 15 }]); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.insertComments('add multiple comments'); expect(result.success).toBe(true); @@ -395,7 +378,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.summarize('summarize this document'); expect(result.success).toBe(true); @@ -403,7 +386,7 @@ describe('AIActionsService', () => { }); it('should return failure when no document context', async () => { - const actions = new AIActionsService(mockProvider, mockEditor, emptyContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => '', false); const result = await actions.summarize('summarize'); expect(result).toEqual({ results: [], success: false }); @@ -426,7 +409,7 @@ describe('AIActionsService', () => { const actions = new AIActionsService( mockProvider, mockEditor, - defaultContextProvider, + () => mockEditor.state.doc.textContent, false, undefined, false @@ -459,7 +442,7 @@ describe('AIActionsService', () => { const actions = new AIActionsService( mockProvider, mockEditor, - defaultContextProvider, + () => mockEditor.state.doc.textContent, false, onStreamChunk, true @@ -484,7 +467,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.insertContent('generate introduction'); expect(result.success).toBe(true); @@ -497,13 +480,13 @@ describe('AIActionsService', () => { }); it('should validate input', async () => { - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); await expect(actions.insertContent('')).rejects.toThrow('Query cannot be empty'); }); it('should return failure when no editor', async () => { - const actions = new AIActionsService(mockProvider, null, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, null, () => mockEditor.state.doc.textContent, false); const result = await actions.insertContent('insert content'); expect(result).toEqual({ success: false, results: [] }); @@ -517,7 +500,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.insertContent('insert content'); expect(result).toEqual({ success: false, results: [] }); @@ -550,7 +533,7 @@ describe('AIActionsService', () => { const actions = new AIActionsService( mockProvider, mockEditor, - defaultContextProvider, + () => mockEditor.state.doc.textContent, false, undefined, true @@ -582,7 +565,7 @@ describe('AIActionsService', () => { const actions = new AIActionsService( mockProvider, mockEditor, - defaultContextProvider, + () => mockEditor.state.doc.textContent, false, undefined, false @@ -604,7 +587,7 @@ describe('AIActionsService', () => { mockEditor.commands.search = vi.fn().mockReturnValue([{ from: 0, to: 4 }]); // Test with logging disabled - const actions1 = new AIActionsService(mockProvider, mockEditor, createContextProvider({primaryText: 'context'}), false); + const actions1 = new AIActionsService(mockProvider, mockEditor, () => 'context', false); const response1 = JSON.stringify({ success: true, results: [{ originalText: 'test', suggestedText: 'new' }] @@ -628,7 +611,7 @@ describe('AIActionsService', () => { mockProvider.getCompletion = vi.fn().mockResolvedValue(response); mockEditor.commands.search = vi.fn().mockReturnValue([]); - const actions = new AIActionsService(mockProvider, mockEditor, defaultContextProvider, false); + const actions = new AIActionsService(mockProvider, mockEditor, () => mockEditor.state.doc.textContent, false); const result = await actions.replace('replace text'); expect(result.results).toHaveLength(0); diff --git a/packages/ai/src/ai-actions-service.ts b/packages/ai/src/ai-actions-service.ts index f190858425..e41a923599 100644 --- a/packages/ai/src/ai-actions-service.ts +++ b/packages/ai/src/ai-actions-service.ts @@ -1,13 +1,4 @@ -import { - AIProvider, - Editor, - Result, - FoundMatch, - DocumentPosition, - AIMessage, - ContextWindow, - ContextScope, -} from './types'; +import {AIProvider, Editor, Result, FoundMatch, DocumentPosition, AIMessage} from './types'; import {EditorAdapter} from './editor-adapter'; import {validateInput, parseJSON} from './utils'; import { @@ -24,11 +15,13 @@ import { */ export class AIActionsService { private adapter: EditorAdapter; + private capturedContext: string | null = null; + private capturedSelectionBounds: { from: number; to: number } | null = null; constructor( private provider: AIProvider, private editor: Editor | null, - private contextProvider: (scope?: ContextScope) => ContextWindow, + private documentContextProvider: () => string, private enableLogging: boolean = false, private onStreamChunk?: (partialResult: string) => void, private streamPreference?: boolean, @@ -43,20 +36,43 @@ export class AIActionsService { } } - private getContext(scope?: ContextScope): ContextWindow { - if (!this.contextProvider) { - return {scope: 'document', primaryText: ''}; + /** + * Sets a captured context that will be used instead of calling the provider. + * This ensures the context (including selection) is captured before async operations. + */ + public setCapturedContext(context: string | null, selectionBounds?: { from: number; to: number } | null): void { + this.capturedContext = context; + this.capturedSelectionBounds = selectionBounds || null; + } + + /** + * Clears the captured context, reverting to using the provider function. + */ + public clearCapturedContext(): void { + this.capturedContext = null; + this.capturedSelectionBounds = null; + } + + private getDocumentContext(): string { + // If a context was captured synchronously, use it + if (this.capturedContext !== null) { + return this.capturedContext; + } + + // Otherwise, call the provider function + if (!this.documentContextProvider) { + return ''; } try { - return this.contextProvider(scope); + return this.documentContextProvider(); } catch (error) { if (this.enableLogging) { console.error( `Failed to retrieve document context: ${error instanceof Error ? error.message : 'Unknown error'}` ); } - return {scope: 'document', primaryText: ''}; + return ''; } } @@ -72,13 +88,13 @@ export class AIActionsService { throw new Error('Query cannot be empty'); } - const context = this.getContext(); + const documentContext = this.getDocumentContext(); - if (!context.primaryText?.trim()) { + if (!documentContext) { return {success: false, results: []}; } - const prompt = buildFindPrompt(query, context, findAll); + const prompt = buildFindPrompt(query, documentContext, findAll); const response = await this.runCompletion([ {role: 'system', content: SYSTEM_PROMPTS.SEARCH}, {role: 'user', content: prompt}, @@ -89,7 +105,7 @@ export class AIActionsService { if (!result.success || !result.results) { return result; } - result.results = this.adapter.findResults(result.results); + result.results = this.adapter.findResults(result.results, this.capturedSelectionBounds); return result; } @@ -168,14 +184,14 @@ export class AIActionsService { multiple: boolean, operationFn: (adapter: EditorAdapter, position: DocumentPosition, replacement: FoundMatch) => Promise ): Promise { - const context = this.getContext(); + const documentContext = this.getDocumentContext(); - if (!context.primaryText?.trim()) { + if (!documentContext) { return []; } // Get AI query - const prompt = buildReplacePrompt(query, context, multiple); + const prompt = buildReplacePrompt(query, documentContext, multiple); const response = await this.runCompletion([ {role: 'system', content: SYSTEM_PROMPTS.EDIT}, {role: 'user', content: prompt}, @@ -193,7 +209,7 @@ export class AIActionsService { return []; } - const searchResults = this.adapter.findResults(replacements); + const searchResults = this.adapter.findResults(replacements, this.capturedSelectionBounds); const match = searchResults?.[0]; for (const result of searchResults) { try { @@ -364,12 +380,12 @@ export class AIActionsService { * Generates a summary of the document. */ async summarize(query: string): Promise { - const context = this.getContext('document'); + const documentContext = this.getDocumentContext(); - if (!context.primaryText?.trim()) { + if (!documentContext) { return {results: [], success: false}; } - const prompt = buildSummaryPrompt(query, context); + const prompt = buildSummaryPrompt(query, documentContext); const useStreaming = this.streamPreference !== false; let streamedLength = 0; @@ -408,8 +424,8 @@ export class AIActionsService { return {success: false, results: []}; } - const context = this.getContext(); - const prompt = buildInsertContentPrompt(query, context); + const documentContext = this.getDocumentContext(); + const prompt = buildInsertContentPrompt(query, documentContext); const useStreaming = this.streamPreference !== false; let streamingInsertedLength = 0; diff --git a/packages/ai/src/ai-actions.test.ts b/packages/ai/src/ai-actions.test.ts index 9cf61a7353..596d015e85 100644 --- a/packages/ai/src/ai-actions.test.ts +++ b/packages/ai/src/ai-actions.test.ts @@ -1,19 +1,13 @@ -import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import {TextSelection} from 'prosemirror-state'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; import { AIActions } from './ai-actions'; -import {EditorAdapter} from './editor-adapter'; import type { AIProvider, AIActionsOptions, SuperDoc, Editor } from './types'; describe('AIActions', () => { let mockProvider: AIProvider; let mockEditor: Editor; let mockSuperdoc: SuperDoc; - let textSelectionSpy: any; - let scrollSpy: any; beforeEach(() => { - textSelectionSpy = vi.spyOn(TextSelection, 'create').mockReturnValue({} as any); - scrollSpy = vi.spyOn(EditorAdapter.prototype as any, 'scrollToPosition').mockImplementation(() => {}); mockProvider = { async *streamCompletion(messages, options) { yield 'chunk1'; @@ -27,36 +21,15 @@ describe('AIActions', () => { mockEditor = { state: { - selection: { - from: 0, - to: 6, - empty: false, - }, doc: { textContent: 'Sample document text', content: { size: 100 }, - textBetween: vi.fn(() => 'Sample'), - resolve: vi.fn((pos) => ({ - pos, - depth: 0, + resolve: vi.fn((pos) => ({ + pos, parent: { inlineContent: true }, - nodeAfter: null, - nodeBefore: null, - start: vi.fn(() => pos), - end: vi.fn(() => pos), - })), - descendants: vi.fn((cb) => { - cb( - { - isBlock: true, - textContent: 'Sample paragraph text', - type: {name: 'paragraph'}, - attrs: {}, - nodeSize: 20, - }, - 0, - ); - }), + min: vi.fn(() => pos), + max: vi.fn(() => pos) + })) }, tr: { setSelection: vi.fn().mockReturnThis(), @@ -91,11 +64,6 @@ describe('AIActions', () => { } as any; }); - afterEach(() => { - textSelectionSpy?.mockRestore(); - scrollSpy?.mockRestore(); - }); - describe('constructor', () => { it('should initialize with provider config', async () => { const options: AIActionsOptions = { @@ -293,56 +261,6 @@ describe('AIActions', () => { }); }); - describe('context window', () => { - it('should expose selection-scoped context by default', async () => { - const ai = new AIActions(mockSuperdoc, { - user: {displayName: 'AI Bot'}, - provider: mockProvider, - }); - await ai.waitUntilReady(); - - const contextWindow = ai.getContextWindow(); - expect(contextWindow.scope).toBe('selection'); - expect(contextWindow.primaryText).toBeTruthy(); - }); - - it('should strip internal context options before calling providers', async () => { - const ai = new AIActions(mockSuperdoc, { - user: {displayName: 'AI Bot'}, - provider: mockProvider, - }); - await ai.waitUntilReady(); - - const streamSpy = vi.spyOn(mockProvider, 'streamCompletion'); - await ai.streamCompletion('prompt', {contextScope: 'document', contextPaddingBlocks: 2}); - - expect(streamSpy).toHaveBeenCalledWith( - expect.any(Array), - expect.not.objectContaining({ - contextScope: expect.anything(), - }), - ); - }); - - it('should not forward context overrides for non-streaming completions', async () => { - const ai = new AIActions(mockSuperdoc, { - user: {displayName: 'AI Bot'}, - provider: mockProvider, - }); - await ai.waitUntilReady(); - - const completionSpy = vi.spyOn(mockProvider, 'getCompletion'); - await ai.getCompletion('prompt', {contextScope: 'document'}); - - expect(completionSpy).toHaveBeenCalledWith( - expect.any(Array), - expect.not.objectContaining({ - contextScope: expect.anything(), - }), - ); - }); - }); - describe('streamCompletion', () => { it('should stream completion chunks', async () => { const onStreamingStart = vi.fn(); diff --git a/packages/ai/src/ai-actions.ts b/packages/ai/src/ai-actions.ts index 0543ef1a49..9e0f87eb19 100644 --- a/packages/ai/src/ai-actions.ts +++ b/packages/ai/src/ai-actions.ts @@ -1,8 +1,5 @@ import type { CompletionOptions, - ContextScope, - ContextWindow, - ContextWindowConfig, Editor, Result, StreamOptions, @@ -14,8 +11,6 @@ import type { } from './types'; import {AIActionsService} from './ai-actions-service'; import {createAIProvider, isAIProvider} from './providers'; -import {EditorAdapter} from './editor-adapter'; -import {formatContextWindow} from './prompts'; /** * Primary entry point for SuperDoc AI capabilities. Wraps a SuperDoc instance, @@ -50,10 +45,6 @@ export class AIActions { private isReady = false; private initializationPromise: Promise | null = null; private readonly commands: AIActionsService; - private readonly contextWindowConfig: { - paddingBlocks: number; - maxChars: number; - }; public readonly action = { find: async (instruction: string) => { @@ -115,12 +106,6 @@ export class AIActions { provider: aiProvider, }; - const contextWindowDefaults: ContextWindowConfig | undefined = this.config.contextWindow; - this.contextWindowConfig = { - paddingBlocks: Math.max(0, contextWindowDefaults?.paddingBlocks ?? 1), - maxChars: Math.max(200, contextWindowDefaults?.maxChars ?? 2000), - }; - this.callbacks = { onReady, onStreamingStart, @@ -145,7 +130,7 @@ export class AIActions { this.commands = new AIActionsService( this.config.provider, editor, - (scope) => this.getContextWindow({scope}), + () => this.getDocumentContext(), this.config.enableLogging, (partial) => this.callbacks.onStreamingPartialResult?.({partialResult: partial}), streamResults, @@ -183,7 +168,35 @@ export class AIActions { } /** - * Executes an action with full callback lifecycle support + * Gets the current selection bounds if a selection exists. + * @private + * @returns Selection bounds {from, to} or null if no selection + */ + private getSelectionBounds(): { from: number; to: number } | null { + const editor = this.getEditor(); + if (!editor) { + return null; + } + + const state = editor.view?.state || editor.state; + if (!state || !state.selection) { + return null; + } + + const { selection } = state; + if (selection.empty) { + return null; + } + + return { + from: selection.from, + to: selection.to, + }; + } + + /** + * Executes an action with full callback lifecycle support. + * Captures the document context (including selection) synchronously before any async operations. * @private */ private async executeActionWithCallbacks( @@ -193,6 +206,13 @@ export class AIActions { if (!editor) { throw new Error('No active SuperDoc editor available for AI actions'); } + + // Capture context synchronously before any async operations + // This ensures the selection is locked in at the moment the action is called + const capturedContext = this.getDocumentContext(); + const selectionBounds = this.getSelectionBounds(); + this.commands.setCapturedContext(capturedContext, selectionBounds); + try { this.callbacks.onStreamingStart?.(); const result: T = await fn(); @@ -202,6 +222,9 @@ export class AIActions { } catch (error: Error | any) { this.handleError(error as Error); throw error; + } finally { + // Clear the captured context after the action completes + this.commands.clearCapturedContext(); } } @@ -257,11 +280,8 @@ export class AIActions { throw new Error('AIActions is not ready yet. Call waitUntilReady() first.'); } - const context = this.getContextWindow({ - scope: options?.contextScope, - paddingBlocks: options?.contextPaddingBlocks, - }); - const userContent = this.buildPromptWithContext(prompt, context); + const documentContext = this.getDocumentContext(); + const userContent = documentContext ? `${prompt}\n\nDocument context:\n${documentContext}` : prompt; const messages = [ {role: 'system' as const, content: this.config.systemPrompt || ''}, @@ -269,16 +289,11 @@ export class AIActions { ]; let accumulated = ''; - const providerOptions = options ? {...options} : undefined; - if (providerOptions) { - delete (providerOptions as Partial).contextScope; - delete (providerOptions as Partial).contextPaddingBlocks; - } try { this.callbacks.onStreamingStart?.(); - const stream = this.config.provider.streamCompletion(messages, providerOptions); + const stream = this.config.provider.streamCompletion(messages, options); for await (const chunk of stream) { accumulated += chunk; @@ -307,25 +322,16 @@ export class AIActions { throw new Error('AIActions is not ready yet. Call waitUntilReady() first.'); } - const context = this.getContextWindow({ - scope: options?.contextScope, - paddingBlocks: options?.contextPaddingBlocks, - }); - const userContent = this.buildPromptWithContext(prompt, context); + const documentContext = this.getDocumentContext(); + const userContent = documentContext ? `${prompt}\n\nDocument context:\n${documentContext}` : prompt; const messages = [ {role: 'system' as const, content: this.config.systemPrompt || ''}, {role: 'user' as const, content: userContent}, ]; - const providerOptions = options ? {...options} : undefined; - if (providerOptions) { - delete (providerOptions as Partial).contextScope; - delete (providerOptions as Partial).contextPaddingBlocks; - } - try { - return await this.config.provider.getCompletion(messages, providerOptions); + return await this.config.provider.getCompletion(messages, options); } catch (error) { this.handleError(error as Error); throw error; @@ -334,100 +340,35 @@ export class AIActions { /** * Retrieves the current document context for AI processing. - * Combines XML and plain text representations when available. + * Returns selected text if available, otherwise returns the full document. * - * @returns Document context string + * @returns Document context string (selected text if available, otherwise full document) */ public getDocumentContext(): string { - return this.getContextWindow({scope: 'document'}).primaryText; - } - - /** - * Returns a scoped context window summarizing the current selection and neighbors. - */ - public getContextWindow(options?: {scope?: ContextScope; paddingBlocks?: number}): ContextWindow { - const rawWindow = this.buildContextWindow(options); - return this.applyContextConstraints(rawWindow); - } - - private buildContextWindow(options?: {scope?: ContextScope; paddingBlocks?: number}): ContextWindow { const editor = this.getEditor(); - const scope = options?.scope; - if (!editor) { - return { - scope: scope ?? 'document', - primaryText: '', - }; + return ''; } - const adapter = new EditorAdapter(editor); - const paddingBlocks = this.resolvePaddingBlocks(options?.paddingBlocks); - - return adapter.getContextWindow(paddingBlocks, scope); - } - - private resolvePaddingBlocks(padding?: number): number { - if (typeof padding === 'number' && padding >= 0) { - return padding; + // Try to get state from view first (most up-to-date), then fall back to editor.state + const state = editor.view?.state || editor.state; + if (!state || !state.doc) { + return ''; } - return this.contextWindowConfig.paddingBlocks; - } - - private applyContextConstraints(context: ContextWindow): ContextWindow { - const clamp = (value?: string): string | undefined => { - if (!value) { - return value; - } - - const limit = this.contextWindowConfig.maxChars; - if (!limit || value.length <= limit) { - return value; + const { selection, doc } = state; + + // If there's a non-empty selection, return the selected text + if (selection && !selection.empty) { + const selectedText = doc.textBetween(selection.from, selection.to, ' ').trim(); + // Only return selected text if it's not empty (handles edge cases) + if (selectedText) { + return selectedText; } - - return `${value.slice(0, limit)}...`; - }; - - const selection = context.selection - ? { - ...context.selection, - text: clamp(context.selection.text) ?? '', - block: context.selection.block - ? { - ...context.selection.block, - text: clamp(context.selection.block.text) ?? '', - } - : undefined, - surroundingBlocks: (context.selection.surroundingBlocks || []).map((block) => ({ - ...block, - text: clamp(block.text) ?? '', - })), - } - : undefined; - - return { - ...context, - primaryText: clamp(context.primaryText) ?? '', - selection, - }; - } - - private buildPromptWithContext(prompt: string, context: ContextWindow): string { - const formattedContext = this.serializeContextWindow(context); - if (!formattedContext) { - return prompt; - } - - return `${prompt}\n\nContext window:\n${formattedContext}`; - } - - private serializeContextWindow(context: ContextWindow): string { - if (!context.primaryText?.trim()) { - return ''; } - return formatContextWindow(context); + // Otherwise, return the full document content + return doc.textContent?.trim() || ''; } /** diff --git a/packages/ai/src/editor-adapter.test.ts b/packages/ai/src/editor-adapter.test.ts index 41d734185a..6d3885b852 100644 --- a/packages/ai/src/editor-adapter.test.ts +++ b/packages/ai/src/editor-adapter.test.ts @@ -195,18 +195,14 @@ describe('EditorAdapter', () => { await mockAdapter.replaceText(0, 5, 'hello'); expect(mockEditor.commands.setTextSelection).toHaveBeenCalledWith({ from: 0, to: 5 }); - expect(mockEditor.commands.getSelectionMarks).toHaveBeenCalled(); expect(mockEditor.commands.deleteSelection).toHaveBeenCalled(); expect(mockEditor.commands.insertContent).toHaveBeenCalledWith({ - type: 'paragraph', - content: [{ - type: 'text', - text: 'hello', - marks: [ - { type: 'bold', attrs: {} }, - { type: 'textStyle', attrs: { fontSize: '14pt' } } - ] - }] + type: 'text', + text: 'hello', + marks: [ + { type: 'bold', attrs: {} }, + { type: 'textStyle', attrs: { fontSize: '14pt' } } + ] }); }); @@ -215,25 +211,19 @@ describe('EditorAdapter', () => { await mockAdapter.replaceText(0, 5, 'hello'); - expect(mockEditor.commands.setTextSelection).toHaveBeenCalledWith({ from: 0, to: 5 }); - expect(mockEditor.commands.getSelectionMarks).toHaveBeenCalled(); - expect(mockEditor.commands.deleteSelection).toHaveBeenCalled(); expect(mockEditor.commands.insertContent).toHaveBeenCalledWith('hello'); }); }); describe('createTrackedChange', () => { - it('should create tracked change without marks', async () => { + it('should create tracked change with author', async () => { + mockEditor.commands.getSelectionMarks = vi.fn().mockReturnValue([]); - const changeId = await mockAdapter.createTrackedChange(0, 5, 'new'); + const changeId = await mockAdapter.createTrackedChange(0, 5, 'new'); expect(changeId).toMatch(/^tracked-change-/); expect(mockEditor.commands.enableTrackChanges).toHaveBeenCalled(); - expect(mockEditor.commands.setTextSelection).toHaveBeenCalledWith({ from: 0, to: 5 }); - expect(mockEditor.commands.getSelectionMarks).toHaveBeenCalled(); - expect(mockEditor.commands.deleteSelection).toHaveBeenCalled(); - expect(mockEditor.commands.insertContent).toHaveBeenCalledWith('new'); expect(mockEditor.commands.disableTrackChanges).toHaveBeenCalled(); }); @@ -246,17 +236,13 @@ describe('EditorAdapter', () => { await mockAdapter.createTrackedChange(0, 5, 'new'); - expect(mockEditor.commands.enableTrackChanges).toHaveBeenCalled(); - expect(mockEditor.commands.setTextSelection).toHaveBeenCalledWith({ from: 0, to: 5 }); - expect(mockEditor.commands.getSelectionMarks).toHaveBeenCalled(); - expect(mockEditor.commands.deleteSelection).toHaveBeenCalled(); expect(mockEditor.commands.insertContent).toHaveBeenCalledWith({ type: 'text', text: 'new', marks: [{ type: 'italic', attrs: {} }] }); - expect(mockEditor.commands.disableTrackChanges).toHaveBeenCalled(); }); + }); describe('createComment', () => { diff --git a/packages/ai/src/editor-adapter.ts b/packages/ai/src/editor-adapter.ts index a2385dd2f1..3e5a72379d 100644 --- a/packages/ai/src/editor-adapter.ts +++ b/packages/ai/src/editor-adapter.ts @@ -1,20 +1,7 @@ -import type { - ContextBlock, - ContextScope, - ContextWindow, - Editor, - FoundMatch, - MarkType, - SelectionContext, -} from './types'; -import { generateId } from './utils'; +import type { Editor, FoundMatch, MarkType } from './types'; +import {generateId} from "./utils"; import { TextSelection } from 'prosemirror-state'; -type TemplateNode = { - marks: MarkType[]; - length: number; -}; - /** * Adapter for SuperDoc editor operations * Encapsulates all editor-specific API calls @@ -23,7 +10,8 @@ export class EditorAdapter { constructor(private editor: Editor) {} // Search for string occurrences and resolve document positions - findResults(results: FoundMatch[]): FoundMatch[] { + // If selectionBounds is provided, only returns matches within the selected area + findResults(results: FoundMatch[], selectionBounds?: { from: number; to: number } | null): FoundMatch[] { if (!results?.length) { return []; } @@ -33,7 +21,7 @@ export class EditorAdapter { const text = match.originalText; const rawMatches = this.editor.commands?.search?.(text) ?? []; - const positions = rawMatches + let positions = rawMatches .map((match: { from?: number; to?: number}) => { const from = match.from; const to = match.to; @@ -42,7 +30,16 @@ export class EditorAdapter { } return { from, to }; }) - .filter((value: { from: number; to: number } | null) => value !== null); + .filter((value: { from: number; to: number } | null) => value !== null) as { from: number; to: number }[]; + + // Filter positions to only include those within the selection bounds if provided + if (selectionBounds) { + positions = positions.filter((pos) => { + // Check if the match overlaps with or is within the selection bounds + // A match is within bounds if it starts at or after selection.from and ends at or before selection.to + return pos.from >= selectionBounds.from && pos.to <= selectionBounds.to; + }); + } return { ...match, @@ -72,17 +69,13 @@ export class EditorAdapter { this.editor.commands.deleteSelection(); if (marks.length > 0) { this.editor.commands.insertContent({ - type: 'paragraph', - content: [{ - type: 'text', - text: suggestedText, - marks: marks.map((mark: MarkType) => ({ - type: mark.type.name, - attrs: mark.attrs, - })), - }], - }); - + type: 'text', + text: suggestedText, + marks: marks.map((mark: MarkType) => ({ + type: mark.type.name, + attrs: mark.attrs, + })), + }); } else { this.editor.commands.insertContent(suggestedText); } @@ -150,143 +143,4 @@ export class EditorAdapter { })), }); } - - - getContextWindow(paddingBlocks: number = 1, scopeOverride?: ContextScope): ContextWindow { - const state = this.editor?.state; - const doc = state?.doc; - - if (!state || !doc) { - return { - scope: 'document', - primaryText: '', - }; - } - - const selection = state.selection; - const blocks = this.collectBlocks(); - const selectionInfo = this.buildSelectionContext(blocks, paddingBlocks); - const documentText = doc.textContent?.trim() ?? ''; - - const hasSelection = Boolean(selection && !selection.empty && selectionInfo?.text !== undefined); - const derivedPrimaryText = hasSelection - ? (selectionInfo?.text ?? '') - : (selectionInfo?.block?.text ?? documentText); - const primaryText = scopeOverride === 'document' ? documentText : derivedPrimaryText; - - const scope: ContextScope = scopeOverride - ? scopeOverride - : hasSelection - ? 'selection' - : selectionInfo?.block - ? 'block' - : 'document'; - - const metadata: Record = { - documentId: (this.editor.options as any)?.documentId, - }; - - return { - scope, - primaryText: primaryText ?? '', - selection: scopeOverride === 'document' ? undefined : (selectionInfo ?? undefined), - documentStats: { - wordCount: documentText ? documentText.split(/\s+/).filter(Boolean).length : 0, - charCount: documentText.length, - }, - metadata, - }; - } - - private collectBlocks(): ContextBlock[] { - const {doc} = this.editor.state; - const blocks: ContextBlock[] = []; - doc.descendants((node: any, pos: any) => { - if (!node.isBlock) { - return true; - } - - const text = node.textContent?.trim() ?? ''; - - blocks.push({ - type: node.type.name, - text, - from: pos, - to: pos + node.nodeSize, - attrs: node.attrs, - headingLevel: typeof node.attrs?.level === 'number' ? node.attrs.level : undefined, - title: typeof node.attrs?.title === 'string' ? node.attrs.title : undefined, - }); - - return true; - }); - - return blocks; - } - - private buildSelectionContext( - blocks: ContextBlock[], - paddingBlocks: number - ): SelectionContext | null { - if (!this.editor.state?.selection) { - return null; - } - - const selection = this.editor.state.selection; - const doc = this.editor.state.doc; - const text = - selection && !selection.empty ? doc.textBetween(selection.from, selection.to, '\n\n', '\n\n') : ''; - - const blockIndex = this.findBlockIndex(blocks, selection.from); - const block = blockIndex >= 0 ? blocks[blockIndex] : undefined; - const surroundingBlocks = blockIndex >= 0 ? this.getNeighborBlocks(blocks, blockIndex, paddingBlocks) : []; - const activeMarks = (this.editor.commands.getSelectionMarks?.() as MarkType[]) || []; - - const metadata: Record = {}; - if (block?.title) { - metadata.clauseTitle = block.title; - } - if (typeof block?.headingLevel === 'number') { - metadata.headingLevel = block.headingLevel; - } - - if (!text && !block) { - return null; - } - - return { - from: selection.from, - to: selection.to, - text, - normalizedText: text?.trim?.(), - block, - surroundingBlocks, - activeMarks: activeMarks.map((mark) => ({ - type: mark.type.name, - attrs: mark.attrs, - })), - metadata: Object.keys(metadata).length ? metadata : undefined, - }; - } - - private findBlockIndex(blocks: ContextBlock[], position: number): number { - return blocks.findIndex((block) => position >= block.from && position < block.to); - } - - private getNeighborBlocks(blocks: ContextBlock[], index: number, paddingBlocks: number): ContextBlock[] { - if (index < 0) { - return []; - } - - const neighbors: ContextBlock[] = []; - for (let i = Math.max(0, index - paddingBlocks); i <= Math.min(blocks.length - 1, index + paddingBlocks); i++) { - if (i === index) { - continue; - } - - neighbors.push(blocks[i]); - } - - return neighbors; - } } diff --git a/packages/ai/src/prompts.ts b/packages/ai/src/prompts.ts index 348768144a..1a2def7760 100644 --- a/packages/ai/src/prompts.ts +++ b/packages/ai/src/prompts.ts @@ -1,5 +1,3 @@ -import type {ContextBlock, ContextWindow} from './types'; - /** * AI prompt templates for document operations */ @@ -11,50 +9,12 @@ export const SYSTEM_PROMPTS = { CONTENT_GENERATION: 'You are a document content generation assistant. Always respond with valid JSON.', } as const; -const describeBlock = (block: ContextBlock): string => { - const descriptor: string[] = [block.type]; - if (block.title) { - descriptor.push(`"${block.title}"`); - } - if (typeof block.headingLevel === 'number') { - descriptor.push(`(level ${block.headingLevel})`); - } - return `${descriptor.join(' ')}:\n${block.text}`; -}; - -export const formatContextWindow = (context: ContextWindow): string => { - const segments: string[] = [`Scope: ${context.scope}`]; - - if (context.selection?.text) { - segments.push(`Selected text:\n${context.selection.text}`); - } - - if (context.selection?.block?.text) { - segments.push(`Active block:\n${describeBlock(context.selection.block)}`); - } - - if (context.selection?.surroundingBlocks?.length) { - const nearby = context.selection.surroundingBlocks.map((block) => describeBlock(block)).join('\n---\n'); - segments.push(`Surrounding blocks:\n${nearby}`); - } - - if (!context.selection?.text && context.primaryText) { - segments.push(`Primary text:\n${context.primaryText}`); - } - - if (context.metadata?.documentId) { - segments.push(`Document ID: ${context.metadata.documentId}`); - } - - return segments.filter(Boolean).join('\n\n'); -}; - -export const buildFindPrompt = (query: string, context: ContextWindow, findAll: boolean): string => { +export const buildFindPrompt = (query: string, documentContext: string, findAll: boolean): string => { const scope = findAll ? 'ALL occurrences' : 'FIRST occurrence ONLY'; return `apply this query for original text return the EXACT text from the doc no title or added text: ${query}, ${scope} - Context window: - ${formatContextWindow(context)} + Document context: + ${documentContext} Respond with JSON: { @@ -66,7 +26,7 @@ export const buildFindPrompt = (query: string, context: ContextWindow, findAll: }`; }; -export const buildReplacePrompt = (query: string, context: ContextWindow, replaceAll: boolean): string => { +export const buildReplacePrompt = (query: string, documentContext: string, replaceAll: boolean): string => { const scope = replaceAll ? 'ALL occurrences' : 'FIRST occurrence ONLY'; @@ -74,8 +34,8 @@ export const buildReplacePrompt = (query: string, context: ContextWindow, replac const finalQuery = `apply this query: ${query} if find and replace query then Find and replace the EXACT text of ${scope}`; return `${finalQuery} - Context window: - ${formatContextWindow(context)} + Document context: + ${documentContext} Respond with JSON: { @@ -87,10 +47,10 @@ export const buildReplacePrompt = (query: string, context: ContextWindow, replac }`; }; -export const buildSummaryPrompt = (query: string, context: ContextWindow): string => { +export const buildSummaryPrompt = (query: string, documentContext: string): string => { return `${query} - Context window: - ${formatContextWindow(context)} + Document context: + ${documentContext} Respond with JSON: { @@ -102,10 +62,9 @@ export const buildSummaryPrompt = (query: string, context: ContextWindow): strin }`; }; -export const buildInsertContentPrompt = (query: string, context: ContextWindow): string => { +export const buildInsertContentPrompt = (query: string, documentContext?: string): string => { return `${query} - Context window: - ${formatContextWindow(context)} + ${documentContext ? `Current document:\n${documentContext}\n` : ''} Respond with JSON: { "success": boolean, "results": [ { "suggestedText": string, diff --git a/packages/ai/src/types.ts b/packages/ai/src/types.ts index 7812eac4cc..caad28469b 100644 --- a/packages/ai/src/types.ts +++ b/packages/ai/src/types.ts @@ -6,45 +6,6 @@ export type NodeType = Node; export type Editor = InstanceType; export type SuperDocInstance = typeof SuperDoc | SuperDoc; -export type ContextScope = 'selection' | 'block' | 'document'; - -export type ContextBlock = { - type: string; - text: string; - from: number; - to: number; - attrs?: Record; - headingLevel?: number; - title?: string; -}; - -export type SelectionContext = { - from: number; - to: number; - text: string; - normalizedText?: string; - block?: ContextBlock; - surroundingBlocks: ContextBlock[]; - activeMarks?: {type: string; attrs?: Record}[]; - metadata?: Record; -}; - -export type ContextWindow = { - scope: ContextScope; - primaryText: string; - selection?: SelectionContext; - documentStats?: { - wordCount: number; - charCount: number; - }; - metadata?: Record; -}; - -export type ContextWindowConfig = { - paddingBlocks?: number; - maxChars?: number; -}; - /** * Represents a position range in the document */ @@ -100,10 +61,6 @@ export type StreamOptions = { documentId?: string; /** Force streaming (true) or disable it (false). Defaults to true when supported. */ stream?: boolean; - /** Override the context scope used when building prompts */ - contextScope?: ContextScope; - /** Override the number of surrounding blocks included in the context */ - contextPaddingBlocks?: number; } /** @@ -147,8 +104,6 @@ export type AIActionsConfig = { systemPrompt?: string; /** Enable debug logging */ enableLogging?: boolean; - /** Context window configuration */ - contextWindow?: ContextWindowConfig; } /**