diff --git a/packages/pi/src/appservice.test.ts b/packages/pi/src/appservice.test.ts index ce95410..3d8b0c9 100644 --- a/packages/pi/src/appservice.test.ts +++ b/packages/pi/src/appservice.test.ts @@ -58,19 +58,18 @@ describe("PicklePiAgent streaming", () => { type: "m.room.message", }); - expect(client.beeper.streams.create).toHaveBeenCalledTimes(1); - expect(client.beeper.streams.register).toHaveBeenCalledTimes(1); - expect(client.beeper.streams.publish.mock.calls.map(([options]) => delta(options).part.type)).toEqual([ + expect(client.beeper.streams.startMessage).toHaveBeenCalledTimes(1); + expect(client.beeper.streams.publishPart.mock.calls.map(([options]) => options.part.type)).toEqual([ "start", "text-start", "text-delta", "text-end", "finish", ]); - expect(client.messages.edit).toHaveBeenCalledWith(expect.objectContaining({ + expect(client.beeper.streams.finalizeMessage).toHaveBeenCalledWith(expect.objectContaining({ eventId: "$target", roomId: "!room:example", - text: "hello", + body: "hello", })); }); }); @@ -107,9 +106,9 @@ function createClient() { const client = { beeper: { streams: { - create: vi.fn(async () => ({ descriptor: { device_id: "DEVICE", type: "com.beeper.llm", user_id: "@bot:example" } })), - publish: vi.fn(async () => undefined), - register: vi.fn(async () => undefined), + finalizeMessage: vi.fn(async () => ({ eventId: "$target", raw: {}, replacementEventId: "$edit", roomId: "!room:example" })), + publishPart: vi.fn(async () => undefined), + startMessage: vi.fn(async () => ({ descriptor: { device_id: "DEVICE", type: "com.beeper.llm", user_id: "@bot:example" }, eventId: "$target", roomId: "!room:example" })), }, }, close: vi.fn(async () => undefined), @@ -124,11 +123,3 @@ function createClient() { }; return client as unknown as MatrixClient & typeof client; } - -function delta(options: { content?: Record }): Record { - const deltas = options.content?.["com.beeper.llm.deltas"]; - if (!Array.isArray(deltas)) throw new Error("missing com.beeper.llm.deltas"); - const [first] = deltas; - if (!first || typeof first !== "object") throw new Error("missing stream delta"); - return first as Record; -} diff --git a/packages/pi/src/beeper-stream.test.ts b/packages/pi/src/beeper-stream.test.ts index f0df2f7..ee5cadb 100644 --- a/packages/pi/src/beeper-stream.test.ts +++ b/packages/pi/src/beeper-stream.test.ts @@ -3,140 +3,63 @@ import { describe, expect, it, vi } from "vitest"; import { BeeperStreamPublisher } from "./beeper-stream"; describe("Beeper stream publisher", () => { - it("creates a target message and registers it with a Beeper stream", async () => { - const { client, create, register, send } = createClient(); - const publisher = new BeeperStreamPublisher({ client, roomId: "!room:example.com", turnId: "turn_1" }); + it("starts one native stream message and publishes UI parts through native transport", async () => { + const { client, publishPart, startMessage } = createClient(); + const subscribers = [{ deviceId: "DESKTOP", userId: "@alice:example.com" }]; + const publisher = new BeeperStreamPublisher({ + client, + initialMessageMetadata: { model: "test" }, + roomId: "!room:example.com", + subscribers, + turnId: "turn_1", + userId: "@bot:example.com", + }); await expect(publisher.start()).resolves.toEqual({ descriptor: streamDescriptor, eventId: "$target", turnId: "turn_1", }); + await publisher.publish({ id: "text_turn_1", type: "text-start" }); + await publisher.publish({ delta: "hello", id: "text_turn_1", type: "text-delta" }); - expect(create).toHaveBeenCalledWith({ - roomId: "!room:example.com", - streamType: "com.beeper.llm", - }); - expect(send).toHaveBeenCalledWith({ + expect(startMessage).toHaveBeenCalledTimes(1); + expect(startMessage).toHaveBeenCalledWith({ content: { body: "...", "com.beeper.ai": { id: "turn_1", - metadata: { turn_id: "turn_1" }, + metadata: { model: "test", turn_id: "turn_1" }, parts: [], role: "assistant", }, - "com.beeper.stream": streamDescriptor, msgtype: "m.text", }, - messageType: "m.text", - roomId: "!room:example.com", - text: "...", - }); - expect(register).toHaveBeenCalledWith({ - descriptor: streamDescriptor, - eventId: "$target", roomId: "!room:example.com", + streamType: "com.beeper.llm", + subscribers, + userId: "@bot:example.com", }); - }); - - it("reuses an existing target message stream descriptor", async () => { - const { client, create, get, register, send } = createClient(); - const publisher = new BeeperStreamPublisher({ - client, - roomId: "!room:example.com", - targetEventId: "$existing", - turnId: "turn_reuse", - }); - - await expect(publisher.start()).resolves.toEqual({ - descriptor: streamDescriptor, - eventId: "$existing", - turnId: "turn_reuse", - }); - - expect(get).toHaveBeenCalledWith({ eventId: "$existing", roomId: "!room:example.com" }); - expect(create).not.toHaveBeenCalled(); - expect(send).not.toHaveBeenCalled(); - expect(register).not.toHaveBeenCalled(); - }); - - it("publishes callback chunks as monotonic com.beeper.llm.deltas envelopes", async () => { - const { client, publish } = createClient(); - const publisher = new BeeperStreamPublisher({ client, roomId: "!room:example.com", turnId: "turn_2" }); - - await publisher.start(); - await publisher.publish({ id: "text_turn_2", type: "text-start" }); - await publisher.publish({ delta: "hello", id: "text_turn_2", type: "text-delta" }); - - expect(publish).toHaveBeenCalledTimes(3); - expect(publish.mock.calls.map(([options]) => delta(options).seq)).toEqual([1, 2, 3]); - expect(publish.mock.calls.map(([options]) => delta(options).part)).toEqual([ + expect(publishPart.mock.calls.map(([options]) => options.part)).toEqual([ { - messageId: "turn_2", - messageMetadata: { turn_id: "turn_2" }, + messageId: "turn_1", + messageMetadata: { model: "test", turn_id: "turn_1" }, type: "start", }, - { id: "text_turn_2", type: "text-start" }, - { delta: "hello", id: "text_turn_2", type: "text-delta" }, + { id: "text_turn_1", type: "text-start" }, + { delta: "hello", id: "text_turn_1", type: "text-delta" }, ]); - for (const [options] of publish.mock.calls) { + for (const [options] of publishPart.mock.calls) { expect(options).toMatchObject({ - content: { - "com.beeper.llm.deltas": [ - { - "m.relates_to": { event_id: "$target", rel_type: "m.reference" }, - target_event: "$target", - turn_id: "turn_2", - }, - ], - }, eventId: "$target", roomId: "!room:example.com", + turnId: "turn_1", }); } }); - it("registers a local subscriber device with the stream", async () => { - const { client, publish, register } = createClient(); - const subscribers = [{ deviceId: "DESKTOP", userId: "@alice:example.com" }]; - const publisher = new BeeperStreamPublisher({ - client, - roomId: "!room:example.com", - subscribers, - turnId: "turn_direct", - }); - - await publisher.start(); - await publisher.publish({ delta: "hello", id: "text_turn_direct", type: "text-delta" }); - - expect(register).toHaveBeenCalledWith({ - descriptor: streamDescriptor, - eventId: "$target", - roomId: "!room:example.com", - subscribers, - }); - expect(publish.mock.calls.map(([options]) => delta(options).seq)).toEqual([1, 2]); - }); - - it("does not mutate final content or sequence when publish fails", async () => { - const { client, edit, publish } = createClient(); - publish.mockResolvedValueOnce(undefined).mockRejectedValueOnce(new Error("network down")); - const publisher = new BeeperStreamPublisher({ client, roomId: "!room:example.com", turnId: "turn_retry" }); - - await publisher.start(); - await expect(publisher.publish({ delta: "lost", id: "text_turn_retry", type: "text-delta" })).rejects.toThrow("network down"); - await publisher.publish({ delta: "ok", id: "text_turn_retry", type: "text-delta" }); - await publisher.finalize({ body: "ok" }); - - expect(delta(publish.mock.calls[1]![0]).seq).toBe(2); - expect(delta(publish.mock.calls[2]![0]).seq).toBe(2); - expect(delta(publish.mock.calls[3]![0]).seq).toBe(3); - expect(edit.mock.calls[0]![0].content.body).toBe("ok"); - }); - - it("serializes concurrent publishes through one stream target and monotonic sequence", async () => { - const { client, create, publish, register, send } = createClient(); + it("serializes concurrent publishes through one stream target", async () => { + const { client, publishPart, startMessage } = createClient(); const publisher = new BeeperStreamPublisher({ client, roomId: "!room:example.com", turnId: "turn_concurrent" }); await Promise.all([ @@ -145,11 +68,8 @@ describe("Beeper stream publisher", () => { publisher.publish({ delta: "b", id: "text_turn_concurrent", type: "text-delta" }), ]); - expect(create).toHaveBeenCalledTimes(1); - expect(send).toHaveBeenCalledTimes(1); - expect(register).toHaveBeenCalledTimes(1); - expect(publish.mock.calls.map(([options]) => delta(options).seq)).toEqual([1, 2, 3, 4]); - expect(publish.mock.calls.map(([options]) => delta(options).part.type)).toEqual([ + expect(startMessage).toHaveBeenCalledTimes(1); + expect(publishPart.mock.calls.map(([options]) => options.part.type)).toEqual([ "start", "text-start", "text-delta", @@ -158,26 +78,27 @@ describe("Beeper stream publisher", () => { }); it("continues the publish queue after a failed publish", async () => { - const { client, publish } = createClient(); - publish.mockResolvedValueOnce(undefined).mockRejectedValueOnce(new Error("network down")); - const publisher = new BeeperStreamPublisher({ client, roomId: "!room:example.com", turnId: "turn_queue_retry" }); + const { client, publishPart } = createClient(); + publishPart.mockResolvedValueOnce(undefined).mockRejectedValueOnce(new Error("network down")); + const publisher = new BeeperStreamPublisher({ client, roomId: "!room:example.com", turnId: "turn_retry" }); - await expect(publisher.publish({ id: "text_turn_queue_retry", type: "text-start" })).rejects.toThrow("network down"); - await publisher.publish({ delta: "ok", id: "text_turn_queue_retry", type: "text-delta" }); + await expect(publisher.publish({ id: "text_turn_retry", type: "text-start" })).rejects.toThrow("network down"); + await publisher.publish({ delta: "ok", id: "text_turn_retry", type: "text-delta" }); - expect(delta(publish.mock.calls[1]![0]).seq).toBe(2); - expect(delta(publish.mock.calls[2]![0]).seq).toBe(2); + expect(publishPart.mock.calls.map(([options]) => options.part)).toEqual([ + { messageId: "turn_retry", messageMetadata: { turn_id: "turn_retry" }, type: "start" }, + { id: "text_turn_retry", type: "text-start" }, + { delta: "ok", id: "text_turn_retry", type: "text-delta" }, + ]); }); - it("finalizes by publishing finish and editing com.beeper.ai while clearing the stream", async () => { - const { client, edit, publish } = createClient(); + it("finalizes by publishing a terminal part and asking native transport to clear the stream", async () => { + const { client, finalizeMessage, publishPart } = createClient(); const publisher = new BeeperStreamPublisher({ client, roomId: "!room:example.com", turnId: "turn_3" }); - await publisher.start(); await publisher.publish({ id: "text_turn_3", type: "text-start" }); await publisher.publish({ delta: "done", id: "text_turn_3", type: "text-delta" }); - await publisher.publish({ id: "text_turn_3", type: "text-end" }); - await publisher.finalize({ + const result = await publisher.finalize({ body: "done", message: { id: "turn_3", @@ -187,13 +108,13 @@ describe("Beeper stream publisher", () => { }, }); - expect(delta(publish.mock.calls.at(-1)![0]).part).toEqual({ + expect(publishPart.mock.calls.at(-1)![0].part).toEqual({ finishReason: "stop", messageMetadata: { finish_reason: "stop", turn_id: "turn_3" }, type: "finish", }); - expect(delta(publish.mock.calls.at(-1)![0]).seq).toBe(5); - expect(edit).toHaveBeenCalledWith({ + expect(finalizeMessage).toHaveBeenCalledWith({ + body: "done", content: { body: "done", "com.beeper.ai": { @@ -202,18 +123,23 @@ describe("Beeper stream publisher", () => { parts: [{ state: "done", text: "done", type: "text" }], role: "assistant", }, - "com.beeper.stream": null, msgtype: "m.text", }, eventId: "$target", - messageType: "m.text", roomId: "!room:example.com", - text: "done", topLevelContent: { "com.beeper.dont_render_edited": true, - "com.beeper.stream": null, }, }); + expect(result).toEqual({ + eventId: "$target", + raw: { + logicalEventId: "$target", + raw: {}, + replacementEventId: "$edit", + }, + roomId: "!room:example.com", + }); }); it("publishes terminal error and abort parts without finalizing the message", async () => { @@ -224,14 +150,13 @@ describe("Beeper stream publisher", () => { turnId: "turn_error", }); - await errorPublisher.start(); await errorPublisher.error(new Error("tool failed")); - expect(delta(errored.publish.mock.calls.at(-1)![0]).part).toEqual({ + expect(errored.publishPart.mock.calls.at(-1)![0].part).toEqual({ errorText: "tool failed", type: "error", }); - expect(errored.edit).not.toHaveBeenCalled(); + expect(errored.finalizeMessage).not.toHaveBeenCalled(); const aborted = createClient(); const abortPublisher = new BeeperStreamPublisher({ @@ -240,22 +165,20 @@ describe("Beeper stream publisher", () => { turnId: "turn_abort", }); - await abortPublisher.start(); await abortPublisher.abort("user cancelled"); - expect(delta(aborted.publish.mock.calls.at(-1)![0]).part).toEqual({ + expect(aborted.publishPart.mock.calls.at(-1)![0].part).toEqual({ reason: "user cancelled", type: "abort", }); - expect(aborted.edit).not.toHaveBeenCalled(); + expect(aborted.finalizeMessage).not.toHaveBeenCalled(); }); it("compacts oversized final Matrix content without dropping text or tool calls", async () => { - const { client, edit } = createClient(); + const { client, finalizeMessage } = createClient(); const publisher = new BeeperStreamPublisher({ client, roomId: "!room:example.com", turnId: "turn_big" }); const largeOutput = "x".repeat(70 * 1024); - await publisher.start(); await publisher.finalize({ body: "final answer", message: { @@ -274,7 +197,7 @@ describe("Beeper stream publisher", () => { }, }); - const content = edit.mock.calls[0]![0].content; + const content = finalizeMessage.mock.calls[0]![0].content; const ai = content["com.beeper.ai"] as Record; expect(Buffer.byteLength(JSON.stringify(content))).toBeLessThanOrEqual(60 * 1024); expect(content.body).toBe("final answer"); @@ -288,65 +211,16 @@ describe("Beeper stream publisher", () => { ]); }); - it("uses one global text budget when compacting final Matrix content", async () => { - const { client, edit } = createClient(); - const publisher = new BeeperStreamPublisher({ client, roomId: "!room:example.com", turnId: "turn_global_budget" }); - const text = "x".repeat(45 * 1024); - - await publisher.start(); - await publisher.finalize({ - body: text, - message: { - id: "turn_global_budget", - metadata: { turn_id: "turn_global_budget" }, - parts: [ - { state: "done", text, type: "text" }, - { state: "done", text, type: "text" }, - ], - role: "assistant", - }, - }); - - const content = edit.mock.calls[0]![0].content; - const ai = content["com.beeper.ai"] as Record; - expect(Buffer.byteLength(JSON.stringify(content))).toBeLessThanOrEqual(60 * 1024); - expect(`${content.body}${ai.parts.map((part: any) => part.text ?? "").join("")}`).toContain("Matrix event compacted"); - }); - - it("updates an existing fallback tool part when the tool name arrives late", async () => { - const { client, edit } = createClient(); - const publisher = new BeeperStreamPublisher({ client, roomId: "!room:example.com", turnId: "turn_late_tool" }); - - await publisher.start(); - await publisher.publish({ dynamic: true, output: "running", toolCallId: "call_1", type: "tool-output-available" }); - await publisher.publish({ - dynamic: true, - input: { cmd: "date" }, - toolCallId: "call_1", - toolName: "exec", - type: "tool-input-available", - }); - await publisher.finalize({ body: "done" }); - - const ai = edit.mock.calls[0]![0].content["com.beeper.ai"] as Record; - expect(ai.parts[0]).toMatchObject({ - toolCallId: "call_1", - toolName: "exec", - type: "dynamic-tool", - }); - }); - it("preserves abort reasons in final terminal metadata", async () => { - const { client, edit } = createClient(); + const { client, finalizeMessage } = createClient(); const publisher = new BeeperStreamPublisher({ client, roomId: "!room:example.com", turnId: "turn_abort_final" }); - await publisher.start(); await publisher.finalize({ body: "cancelled", terminalPart: { reason: "user cancelled", type: "abort" }, }); - const ai = edit.mock.calls[0]![0].content["com.beeper.ai"] as Record; + const ai = finalizeMessage.mock.calls[0]![0].content["com.beeper.ai"] as Record; expect(ai.metadata.beeper_terminal_state).toEqual({ reason: "user cancelled", type: "abort", @@ -361,50 +235,23 @@ const streamDescriptor = { }; function createClient() { - const create = vi.fn(async () => ({ descriptor: streamDescriptor })); - const register = vi.fn(async () => undefined); - const publish = vi.fn(async () => undefined); - const send = vi.fn(async () => ({ eventId: "$target", raw: {}, roomId: "!room:example.com" })); - const edit = vi.fn(async () => ({ eventId: "$edit", raw: {}, roomId: "!room:example.com" })); - const get = vi.fn(async () => ({ - message: { - attachments: [], - class: "message", - content: { - "com.beeper.stream": streamDescriptor, - }, - edited: false, - encrypted: false, - eventId: "$existing", - kind: "message", - raw: {}, - roomId: "!room:example.com", - text: "...", - type: "m.room.message", - }, + const startMessage = vi.fn(async () => ({ descriptor: streamDescriptor, eventId: "$target", roomId: "!room:example.com" })); + const publishPart = vi.fn(async () => undefined); + const finalizeMessage = vi.fn(async () => ({ + eventId: "$target", + raw: {}, + replacementEventId: "$edit", + roomId: "!room:example.com", })); const client = { beeper: { streams: { - create, - publish, - register, + finalizeMessage, + publishPart, + startMessage, }, }, - messages: { - edit, - get, - send, - }, } as unknown as MatrixClient; - return { client, create, edit, get, publish, register, send }; -} - -function delta(options: { content?: Record }): Record { - const deltas = options.content?.["com.beeper.llm.deltas"]; - if (!Array.isArray(deltas)) throw new Error("missing com.beeper.llm.deltas"); - const [first] = deltas; - if (!first || typeof first !== "object") throw new Error("missing stream delta"); - return first as Record; + return { client, finalizeMessage, publishPart, startMessage }; } diff --git a/packages/pi/src/beeper-stream.ts b/packages/pi/src/beeper-stream.ts index 229b2e1..4881225 100644 --- a/packages/pi/src/beeper-stream.ts +++ b/packages/pi/src/beeper-stream.ts @@ -1,4 +1,4 @@ -import type { MatrixBeeper, MatrixMessages, SentEvent } from "@beeper/pickle"; +import type { MatrixBeeper, SentEvent } from "@beeper/pickle"; import { applyFinalMessagePart, compactFinalContent, @@ -12,7 +12,6 @@ import { createTurnId, type BeeperUIMessageChunk } from "./stream-map"; export interface BeeperStreamPublisherClient { beeper: MatrixBeeper; - messages: Pick; } export interface BeeperStreamSubscriber { @@ -21,13 +20,14 @@ export interface BeeperStreamSubscriber { } export interface CreateBeeperStreamPublisherOptions { + agentId?: string; client: BeeperStreamPublisherClient; initialMessageMetadata?: Record; roomId: string; subscribers?: BeeperStreamSubscriber[]; - targetEventId?: string; threadRoot?: string; turnId?: string; + userId?: string; } export interface BeeperStreamStartResult { @@ -49,24 +49,26 @@ export class BeeperStreamPublisher { readonly roomId: string; readonly turnId: string; #accumulator: BeeperFinalMessageAccumulator; + #agentId: string | undefined; #client: BeeperStreamPublisherClient; #descriptor: Record | undefined; #finalized = false; #initialMessageMetadata: Record; #queue = new SerialQueue(); - #seq = 1; #subscribers: BeeperStreamSubscriber[]; #targetEventId: string | undefined; #threadRoot: string | undefined; + #userId: string | undefined; constructor(options: CreateBeeperStreamPublisherOptions) { + this.#agentId = options.agentId; this.#client = options.client; this.#initialMessageMetadata = options.initialMessageMetadata ?? {}; this.roomId = options.roomId; this.turnId = options.turnId ?? createTurnId(); this.#subscribers = options.subscribers ?? []; - this.#targetEventId = options.targetEventId; this.#threadRoot = options.threadRoot; + this.#userId = options.userId; this.#accumulator = createFinalMessageAccumulator(this.turnId); } @@ -120,30 +122,28 @@ export class BeeperStreamPublisher { aiMessage: finalAIMessage, body: finalText, }); - const replacement = await this.#client.messages.edit({ + const replacement = await this.#client.beeper.streams.finalizeMessage({ + body: finalContent.body || "...", content: { body: finalContent.body || "...", "com.beeper.ai": finalContent.aiMessage, - "com.beeper.stream": null, msgtype: "m.text", }, eventId: targetEventId, - messageType: "m.text", roomId: this.roomId, - text: finalContent.body || "...", topLevelContent: { "com.beeper.dont_render_edited": true, - "com.beeper.stream": null, }, + ...(this.#userId ? { userId: this.#userId } : {}), }); this.#finalized = true; return { - ...replacement, eventId: targetEventId, + roomId: replacement.roomId, raw: { logicalEventId: targetEventId, raw: replacement.raw, - replacementEventId: replacement.eventId, + replacementEventId: replacement.replacementEventId, }, }; }); @@ -153,72 +153,36 @@ export class BeeperStreamPublisher { if (this.#targetEventId && this.#descriptor) { return { descriptor: this.#descriptor, eventId: this.#targetEventId, turnId: this.turnId }; } - if (this.#targetEventId) { - const { message } = await this.#client.messages.get({ eventId: this.#targetEventId, roomId: this.roomId }); - const descriptor = message?.content["com.beeper.stream"]; - if (!isRecord(descriptor)) { - throw new Error(`Target message ${this.#targetEventId} does not contain a Beeper stream descriptor`); - } - this.#descriptor = descriptor; - return { descriptor, eventId: this.#targetEventId, turnId: this.turnId }; - } - const stream = await this.#client.beeper.streams.create({ roomId: this.roomId, streamType: "com.beeper.llm" }); - this.#descriptor = stream.descriptor; - const target = await this.#client.messages.send({ + const target = await this.#client.beeper.streams.startMessage({ content: { body: "...", "com.beeper.ai": { id: this.turnId, metadata: { turn_id: this.turnId, ...this.#initialMessageMetadata }, parts: [], role: "assistant" }, - "com.beeper.stream": stream.descriptor, msgtype: "m.text", }, - messageType: "m.text", - roomId: this.roomId, - text: "...", - ...(this.#threadRoot ? { threadRoot: this.#threadRoot } : {}), - }); - this.#targetEventId = target.eventId; - await this.#client.beeper.streams.register({ - descriptor: stream.descriptor, - eventId: target.eventId, roomId: this.roomId, + streamType: "com.beeper.llm", ...(this.#subscribers.length > 0 ? { subscribers: this.#subscribers } : {}), + ...(this.#threadRoot ? { threadRootEventId: this.#threadRoot } : {}), + ...(this.#userId ? { userId: this.#userId } : {}), }); + this.#descriptor = target.descriptor; + this.#targetEventId = target.eventId; await this.#publishPart(target.eventId, { messageId: this.turnId, messageMetadata: { turn_id: this.turnId, ...this.#initialMessageMetadata }, type: "start" }); - return { descriptor: stream.descriptor, eventId: target.eventId, turnId: this.turnId }; + return { descriptor: target.descriptor, eventId: target.eventId, turnId: this.turnId }; } async #publishPart(targetEventId: string, part: BeeperUIMessageChunk): Promise { - const descriptorType = descriptorTypeOf(this.#descriptor); - const seq = this.#seq; - const content = { - [`${descriptorType}.deltas`]: [ - { - "m.relates_to": { event_id: targetEventId, rel_type: "m.reference" }, - part, - seq, - target_event: targetEventId, - turn_id: this.turnId, - }, - ], - }; - await this.#client.beeper.streams.publish({ - content, + await this.#client.beeper.streams.publishPart({ + ...(this.#agentId ? { agentId: this.#agentId } : {}), eventId: targetEventId, + part, roomId: this.roomId, + turnId: this.turnId, }); - this.#seq = seq + 1; applyFinalMessagePart(this.#accumulator, part); } } -function descriptorTypeOf(descriptor: Record | undefined): string { - return typeof descriptor?.type === "string" ? descriptor.type : "com.beeper.llm"; -} - -function isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null && !Array.isArray(value); -} - function errorText(error: unknown): string { if (error instanceof Error) return error.message; if (typeof error === "string") return error; diff --git a/packages/pickle/native/internal/core/appservice_test.go b/packages/pickle/native/internal/core/appservice_test.go index 9250ebc..e556af3 100644 --- a/packages/pickle/native/internal/core/appservice_test.go +++ b/packages/pickle/native/internal/core/appservice_test.go @@ -3,9 +3,15 @@ package core import ( "context" "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" "testing" + "time" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/beeperstream" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -90,6 +96,196 @@ func TestAppserviceTransactionParsesBeeperStreamSubscribe(t *testing.T) { } } +func TestBeeperStreamClientUsesAppserviceBotDevice(t *testing.T) { + core := New(nil) + mainClient, err := mautrix.NewClient("https://matrix.example/_hungryserv/alice", id.UserID("@bot:example"), "login-token") + if err != nil { + t.Fatal(err) + } + mainClient.StateStore = mautrix.NewMemoryStateStore() + core.client = mainClient + + cli, err := core.beeperStreamClient(MatrixCoreInitOptions{ + Appservice: &MatrixAppserviceInitOptions{ + Homeserver: "https://matrix.example/_hungryserv/alice", + HomeserverDomain: "example", + Registration: MatrixAppserviceRegistration{ + AppToken: "as-token", + SenderLocalpart: "bot", + }, + }, + DeviceID: "PICKLE", + }) + if err != nil { + t.Fatal(err) + } + + if cli.UserID != id.UserID("@bot:example") { + t.Fatalf("unexpected stream user ID: %s", cli.UserID) + } + if cli.DeviceID != id.DeviceID("PICKLE") { + t.Fatalf("unexpected stream device ID: %s", cli.DeviceID) + } + if cli.AccessToken != "as-token" { + t.Fatalf("expected appservice token, got %q", cli.AccessToken) + } + if !cli.SetAppServiceUserID || !cli.SetAppServiceDeviceID { + t.Fatalf("expected appservice user and device query flags") + } + if cli.StateStore != mainClient.StateStore { + t.Fatalf("expected stream client to share state store") + } +} + +func TestCreateBeeperStreamUsesMautrixEncryptionDecision(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"event_id":"$stream"}`)) + })) + t.Cleanup(server.Close) + + core := New(nil) + cli, err := mautrix.NewClient(server.URL, id.UserID("@testbot:example"), "device-token") + if err != nil { + t.Fatal(err) + } + cli.DeviceID = id.DeviceID("PICKLE") + cli.StateStore = mautrix.NewMemoryStateStore() + core.client = cli + core.beeperStream, err = beeperstream.New(cli) + if err != nil { + t.Fatal(err) + } + + req, err := json.Marshal(MatrixStartBeeperStreamMessageOptions{ + RoomID: "!room:example", + StreamType: "com.beeper.llm", + }) + if err != nil { + t.Fatal(err) + } + resp, err := core.handleStartBeeperStreamMessage(context.Background(), req) + if err != nil { + t.Fatal(err) + } + var result struct { + Descriptor event.BeeperStreamInfo `json:"descriptor"` + } + if err = json.Unmarshal(resp, &result); err != nil { + t.Fatal(err) + } + if result.Descriptor.Encryption != nil { + t.Fatal("expected unencrypted beeper stream descriptor for unencrypted room") + } + + if err = cli.StateStore.SetEncryptionEvent(context.Background(), id.RoomID("!room:example"), &event.EncryptionEventContent{ + Algorithm: id.AlgorithmMegolmV1, + }); err != nil { + t.Fatal(err) + } + resp, err = core.handleStartBeeperStreamMessage(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if err = json.Unmarshal(resp, &result); err != nil { + t.Fatal(err) + } + if result.Descriptor.Encryption == nil { + t.Fatal("expected encrypted beeper stream descriptor") + } + if result.Descriptor.Encryption.Algorithm != id.AlgorithmBeeperStreamV1 { + t.Fatalf("unexpected stream encryption algorithm: %s", result.Descriptor.Encryption.Algorithm) + } + if len(result.Descriptor.Encryption.Key) != 32 { + t.Fatalf("unexpected stream encryption key length: %d", len(result.Descriptor.Encryption.Key)) + } +} + +func TestRegisterBeeperStreamInjectsDirectSubscribers(t *testing.T) { + requests := make(chan recordedRequest, 4) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + requests <- recordedRequest{body: string(body), path: r.URL.Path} + w.Header().Set("Content-Type", "application/json") + if strings.Contains(r.URL.Path, "/sendToDevice/") { + _, _ = w.Write([]byte(`{}`)) + } else { + _, _ = w.Write([]byte(`{"event_id":"$stream"}`)) + } + })) + t.Cleanup(server.Close) + + core := New(nil) + cli, err := mautrix.NewClient(server.URL, id.UserID("@testbot:example"), "device-token") + if err != nil { + t.Fatal(err) + } + cli.DeviceID = id.DeviceID("PICKLE") + cli.StateStore = mautrix.NewMemoryStateStore() + core.client = cli + core.beeperStream, err = beeperstream.New(cli) + if err != nil { + t.Fatal(err) + } + + if err = cli.StateStore.SetEncryptionEvent(context.Background(), id.RoomID("!room:example"), &event.EncryptionEventContent{ + Algorithm: id.AlgorithmMegolmV1, + }); err != nil { + t.Fatal(err) + } + startReq, err := json.Marshal(MatrixStartBeeperStreamMessageOptions{ + RoomID: "!room:example", + StreamType: "com.beeper.llm", + Subscribers: []MatrixBeeperStreamSubscriber{{ + DeviceID: "DESKTOP", + UserID: "@alice:example", + }}, + }) + if err != nil { + t.Fatal(err) + } + if _, err = core.handleStartBeeperStreamMessage(context.Background(), startReq); err != nil { + t.Fatal(err) + } + + publishReq, err := json.Marshal(MatrixPublishBeeperStreamMessagePartOptions{ + EventID: "$stream", + Part: OutboundEvent{"type": "text-delta", "delta": "hi"}, + RoomID: "!room:example", + TurnID: "turn-test", + }) + if err != nil { + t.Fatal(err) + } + if _, err = core.handlePublishBeeperStreamMessagePart(context.Background(), publishReq); err != nil { + t.Fatal(err) + } + + deadline := time.After(time.Second) + for { + select { + case req := <-requests: + if !strings.Contains(req.path, "/sendToDevice/") { + continue + } + if !strings.Contains(req.path, "/sendToDevice/m.room.encrypted/") { + t.Fatalf("expected encrypted stream update sendToDevice request, got %s", req.path) + } + if !strings.Contains(req.body, "@alice:example") || !strings.Contains(req.body, "DESKTOP") { + t.Fatalf("expected desktop subscriber in sendToDevice body, got %s", req.body) + } + return + case <-deadline: + t.Fatal("timed out waiting for stream update sendToDevice request") + } + } +} + +type recordedRequest struct { + body string + path string +} + func mustJSON(t *testing.T, value any) json.RawMessage { t.Helper() raw, err := json.Marshal(value) diff --git a/packages/pickle/native/internal/core/core.go b/packages/pickle/native/internal/core/core.go index 6d85d42..6b44715 100644 --- a/packages/pickle/native/internal/core/core.go +++ b/packages/pickle/native/internal/core/core.go @@ -15,32 +15,33 @@ import ( ) type Core struct { - client *mautrix.Client - appservice *matrixAppservice - crypto *cryptohelper.CryptoHelper - cryptoStore crypto.Store - backupKey *backup.MegolmBackupKey - backupVersion id.KeyBackupVersion - beeperStream *beeperstream.Helper - appserviceProcessor *beeperStreamEventProcessor - emit func(OutboundEvent) - host RuntimeHost - nextBatch string - pickleKey []byte - pendingDecryptions []pendingDecryption - skipNextSync bool - emittedTimelineIDs map[id.EventID]struct{} - messageEdits map[id.EventID]*MatrixMessageEvent - reactions map[id.EventID]reactionSnapshot - stores *storeBundle - userID id.UserID - deviceID id.DeviceID - cryptoStatus string - mu sync.Mutex - syncMu sync.Mutex - syncLoopMu sync.Mutex - syncLoopCancel context.CancelFunc - syncLoopDone chan struct{} + client *mautrix.Client + appservice *matrixAppservice + crypto *cryptohelper.CryptoHelper + cryptoStore crypto.Store + backupKey *backup.MegolmBackupKey + backupVersion id.KeyBackupVersion + beeperStream *beeperstream.Helper + beeperStreamMessages map[id.EventID]*beeperStreamMessage + appserviceProcessor *beeperStreamEventProcessor + emit func(OutboundEvent) + host RuntimeHost + nextBatch string + pickleKey []byte + pendingDecryptions []pendingDecryption + skipNextSync bool + emittedTimelineIDs map[id.EventID]struct{} + messageEdits map[id.EventID]*MatrixMessageEvent + reactions map[id.EventID]reactionSnapshot + stores *storeBundle + userID id.UserID + deviceID id.DeviceID + cryptoStatus string + mu sync.Mutex + syncMu sync.Mutex + syncLoopMu sync.Mutex + syncLoopCancel context.CancelFunc + syncLoopDone chan struct{} } type OutboundEvent map[string]any @@ -51,11 +52,12 @@ func New(emit func(OutboundEvent), host ...RuntimeHost) *Core { runtimeHost = host[0] } return &Core{ - emit: emit, - host: runtimeHost, - emittedTimelineIDs: make(map[id.EventID]struct{}), - messageEdits: make(map[id.EventID]*MatrixMessageEvent), - reactions: make(map[id.EventID]reactionSnapshot), + emit: emit, + host: runtimeHost, + beeperStreamMessages: make(map[id.EventID]*beeperStreamMessage), + emittedTimelineIDs: make(map[id.EventID]struct{}), + messageEdits: make(map[id.EventID]*MatrixMessageEvent), + reactions: make(map[id.EventID]reactionSnapshot), } } @@ -130,14 +132,12 @@ func (c *Core) Handle(ctx context.Context, op string, payload []byte) ([]byte, e return c.handleRemoveReaction(ctx, payload) case opSendEphemeralEvent: return c.handleSendEphemeralEvent(ctx, payload) - case opCreateBeeperStream: - return c.handleCreateBeeperStream(ctx, payload) - case opRegisterBeeperStream: - return c.handleRegisterBeeperStream(ctx, payload) - case opPublishBeeperStream: - return c.handlePublishBeeperStream(ctx, payload) - case opUnsubscribeBeeperStream: - return c.handleUnsubscribeBeeperStream(payload) + case opStartBeeperStreamMessage: + return c.handleStartBeeperStreamMessage(ctx, payload) + case opPublishBeeperStreamMessagePart: + return c.handlePublishBeeperStreamMessagePart(ctx, payload) + case opFinalizeBeeperStreamMessage: + return c.handleFinalizeBeeperStreamMessage(ctx, payload) case opSetTyping: return c.handleSetTyping(ctx, payload) case opFetchMessage: diff --git a/packages/pickle/native/internal/core/init.go b/packages/pickle/native/internal/core/init.go index 1afb197..b9cb86a 100644 --- a/packages/pickle/native/internal/core/init.go +++ b/packages/pickle/native/internal/core/init.go @@ -107,7 +107,7 @@ func (c *Core) handleInit(ctx context.Context, payload []byte) ([]byte, error) { return nil, err } c.emitInitStep("crypto_ready", initStarted) - if err := c.setupBeeperStream(); err != nil { + if err := c.setupBeeperStream(req); err != nil { return nil, err } c.emitInitStep("beeper_stream_ready", initStarted) @@ -235,8 +235,8 @@ func (c *Core) emitInitStep(step string, started time.Time) { }) } -func (c *Core) setupBeeperStream() error { - cli, err := c.requireClient() +func (c *Core) setupBeeperStream(req MatrixCoreInitOptions) error { + cli, err := c.beeperStreamClient(req) if err != nil { return err } @@ -253,6 +253,25 @@ func (c *Core) setupBeeperStream() error { return nil } +func (c *Core) beeperStreamClient(req MatrixCoreInitOptions) (*mautrix.Client, error) { + if req.Appservice == nil { + return c.requireClient() + } + botUserID := id.NewUserID(req.Appservice.Registration.SenderLocalpart, req.Appservice.HomeserverDomain) + cli, err := mautrix.NewClient(req.Appservice.Homeserver, botUserID, req.Appservice.Registration.AppToken) + if err != nil { + return nil, err + } + configureHTTPClient(cli, c.host) + cli.DeviceID = id.DeviceID(req.DeviceID) + cli.SetAppServiceUserID = true + cli.SetAppServiceDeviceID = true + if c.client != nil { + cli.StateStore = c.client.StateStore + } + return cli, nil +} + func (c *Core) handleWhoami(ctx context.Context) ([]byte, error) { cli, err := c.requireClient() if err != nil { diff --git a/packages/pickle/native/internal/core/messages.go b/packages/pickle/native/internal/core/messages.go index d6746b6..f7ea1c0 100644 --- a/packages/pickle/native/internal/core/messages.go +++ b/packages/pickle/native/internal/core/messages.go @@ -38,10 +38,6 @@ type MatrixRawMessage struct { Raw any `json:"raw"` } -type MatrixCreateBeeperStreamResult struct { - Descriptor any `json:"descriptor" tstype:"{ [key: string]: unknown }"` -} - func (c *Core) handlePostMessage(ctx context.Context, payload []byte) ([]byte, error) { cli, err := c.requireClient() if err != nil { @@ -72,84 +68,273 @@ func (c *Core) handlePostMessage(ctx context.Context, payload []byte) ([]byte, e return json.Marshal(MatrixRawMessage{EventID: resp.EventID.String(), RoomID: req.RoomID, Raw: resp}) } -type MatrixCreateBeeperStreamOptions struct { +type MatrixBeeperStreamSubscriber struct { + DeviceID string `json:"deviceId"` + UserID string `json:"userId"` +} + +type MatrixStartBeeperStreamMessageOptions struct { + Body string `json:"body,omitempty"` + Content OutboundEvent `json:"content,omitempty" tstype:"{ [key: string]: unknown }"` + RoomID string `json:"roomId"` + StreamType string `json:"streamType,omitempty"` + Subscribers []MatrixBeeperStreamSubscriber `json:"subscribers,omitempty"` + ThreadRootEventID string `json:"threadRootEventId,omitempty"` + UserID string `json:"userId,omitempty"` +} + +type MatrixStartBeeperStreamMessageResult struct { + Descriptor any `json:"descriptor" tstype:"{ [key: string]: unknown }"` + EventID string `json:"eventId"` RoomID string `json:"roomId"` - StreamType string `json:"streamType,omitempty"` } -func (c *Core) handleCreateBeeperStream(ctx context.Context, payload []byte) ([]byte, error) { +type MatrixPublishBeeperStreamMessagePartOptions struct { + AgentID string `json:"agentId,omitempty"` + EventID string `json:"eventId"` + Part OutboundEvent `json:"part" tstype:"{ [key: string]: unknown }"` + RoomID string `json:"roomId"` + TurnID string `json:"turnId"` +} + +type MatrixFinalizeBeeperStreamMessageOptions struct { + Body string `json:"body,omitempty"` + Content OutboundEvent `json:"content,omitempty" tstype:"{ [key: string]: unknown }"` + EventID string `json:"eventId"` + RoomID string `json:"roomId"` + TopLevelContent OutboundEvent `json:"topLevelContent,omitempty" tstype:"{ [key: string]: unknown }"` + UserID string `json:"userId,omitempty"` +} + +type MatrixFinalizeBeeperStreamMessageResult struct { + EventID string `json:"eventId"` + ReplacementEventID string `json:"replacementEventId"` + RoomID string `json:"roomId"` + Raw any `json:"raw"` +} + +type beeperStreamMessage struct { + descriptor *event.BeeperStreamInfo + nextSeq int + roomID id.RoomID +} + +func (c *Core) handleStartBeeperStreamMessage(ctx context.Context, payload []byte) ([]byte, error) { if c.beeperStream == nil { return nil, errors.New("beeper stream helper is not initialized") } - var req MatrixCreateBeeperStreamOptions + var req MatrixStartBeeperStreamMessageOptions if err := json.Unmarshal(payload, &req); err != nil { return nil, err } + if req.RoomID == "" { + return nil, errors.New("missing beeper stream message room ID") + } if req.StreamType == "" { - req.StreamType = "com.beeper.ai.stream_event" + req.StreamType = "com.beeper.llm" } descriptor, err := c.beeperStream.NewDescriptor(ctx, id.RoomID(req.RoomID), req.StreamType) if err != nil { return nil, err } + content := copyOutboundEvent(req.Content) + if content["body"] == nil { + content["body"] = firstNonEmpty(req.Body, "...") + } + if content["msgtype"] == nil { + content["msgtype"] = "m.text" + } + content["com.beeper.stream"] = descriptor + resp, err := c.sendBeeperStreamMessageEvent(ctx, req.RoomID, req.ThreadRootEventID, req.UserID, content) + if err != nil { + return nil, err + } + eventID := id.EventID(resp.EventID.String()) + if err = c.beeperStream.Register(ctx, id.RoomID(req.RoomID), eventID, descriptor); err != nil { + return nil, err + } + c.beeperStreamMessages[eventID] = &beeperStreamMessage{ + descriptor: descriptor.Clone(), + nextSeq: 1, + roomID: id.RoomID(req.RoomID), + } + c.addBeeperStreamSubscribers(ctx, id.RoomID(req.RoomID), eventID, req.Subscribers) c.client.Log.Debug(). Str("stream_type", descriptor.Type). Stringer("room_id", id.RoomID(req.RoomID)). + Stringer("event_id", eventID). Stringer("user_id", descriptor.UserID). Stringer("device_id", descriptor.DeviceID). Bool("encrypted", descriptor.Encryption != nil). - Msg("Created beeper stream descriptor") - return json.Marshal(MatrixCreateBeeperStreamResult{Descriptor: descriptor}) -} - -type MatrixBeeperStreamOptions struct { - Content map[string]any `json:"content,omitempty"` - EventID string `json:"eventId"` - RoomID string `json:"roomId"` -} - -type MatrixRegisterBeeperStreamOptions struct { - Descriptor json.RawMessage `json:"descriptor" tstype:"{ [key: string]: unknown }"` - EventID string `json:"eventId"` - RoomID string `json:"roomId"` - Subscribers []MatrixBeeperStreamSubscriber `json:"subscribers,omitempty"` + Int("direct_subscribers", len(req.Subscribers)). + Msg("Started beeper stream message") + return json.Marshal(MatrixStartBeeperStreamMessageResult{ + Descriptor: descriptor, + EventID: eventID.String(), + RoomID: req.RoomID, + }) } -type MatrixBeeperStreamSubscriber struct { - DeviceID string `json:"deviceId"` - UserID string `json:"userId"` +func (c *Core) sendBeeperStreamMessageEvent(ctx context.Context, roomID, threadRootEventID, userID string, content OutboundEvent) (*mautrix.RespSendEvent, error) { + if threadRootEventID != "" && content["m.relates_to"] == nil { + content["m.relates_to"] = (&event.RelatesTo{}).SetThread(id.EventID(threadRootEventID), "") + } + if userID != "" { + intent, err := c.requireAppserviceIntent(userID) + if err != nil { + return nil, err + } + if err = c.appservice.ensureJoined(ctx, intent, id.RoomID(roomID)); err != nil { + return nil, err + } + return intent.SendMessageEvent(ctx, id.RoomID(roomID), event.EventMessage, content) + } + cli, err := c.requireClient() + if err != nil { + return nil, err + } + return retryMatrix(ctx, func() (*mautrix.RespSendEvent, error) { + if err := c.prepareOutboundMegolm(ctx, cli, id.RoomID(roomID)); err != nil { + return nil, err + } + return cli.SendMessageEvent(ctx, id.RoomID(roomID), event.EventMessage, content) + }) } -func (c *Core) handleRegisterBeeperStream(ctx context.Context, payload []byte) ([]byte, error) { +func (c *Core) handlePublishBeeperStreamMessagePart(ctx context.Context, payload []byte) ([]byte, error) { if c.beeperStream == nil { return nil, errors.New("beeper stream helper is not initialized") } - var req MatrixRegisterBeeperStreamOptions + var req MatrixPublishBeeperStreamMessagePartOptions if err := json.Unmarshal(payload, &req); err != nil { return nil, err } - if req.RoomID == "" || req.EventID == "" || len(req.Descriptor) == 0 { - return nil, errors.New("missing beeper stream registration fields") + stream := c.beeperStreamMessages[id.EventID(req.EventID)] + if stream == nil { + return nil, fmt.Errorf("beeper stream message %s is not registered", req.EventID) } - var descriptor event.BeeperStreamInfo - if err := json.Unmarshal(req.Descriptor, &descriptor); err != nil { - return nil, err + if req.RoomID != "" && stream.roomID != id.RoomID(req.RoomID) { + return nil, fmt.Errorf("beeper stream message %s is registered in %s, not %s", req.EventID, stream.roomID, req.RoomID) + } + if req.TurnID == "" { + return nil, errors.New("missing beeper stream message turn ID") } - if err := c.beeperStream.Register(ctx, id.RoomID(req.RoomID), id.EventID(req.EventID), &descriptor); err != nil { + streamType := stream.descriptor.Type + if streamType == "" { + streamType = "com.beeper.llm" + } + seq := stream.nextSeq + delta := map[string]any{ + "m.relates_to": &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(req.EventID)}, + "part": req.Part, + "seq": seq, + "turn_id": req.TurnID, + } + if req.AgentID != "" { + delta["agent_id"] = req.AgentID + } + content := map[string]any{ + streamType + ".deltas": []any{delta}, + } + if err := c.beeperStream.Publish(ctx, stream.roomID, id.EventID(req.EventID), content); err != nil { return nil, err } - c.addBeeperStreamSubscribers(ctx, id.RoomID(req.RoomID), id.EventID(req.EventID), req.Subscribers) - c.client.Log.Debug(). - Str("stream_type", descriptor.Type). - Stringer("room_id", id.RoomID(req.RoomID)). - Stringer("event_id", id.EventID(req.EventID)). - Stringer("user_id", descriptor.UserID). - Stringer("device_id", descriptor.DeviceID). - Int("direct_subscribers", len(req.Subscribers)). - Msg("Registered beeper stream") + stream.nextSeq = seq + 1 return c.empty() } +func (c *Core) handleFinalizeBeeperStreamMessage(ctx context.Context, payload []byte) ([]byte, error) { + var req MatrixFinalizeBeeperStreamMessageOptions + if err := json.Unmarshal(payload, &req); err != nil { + return nil, err + } + if req.RoomID == "" || req.EventID == "" { + return nil, errors.New("missing beeper stream finalize fields") + } + content := copyOutboundEvent(req.Content) + if content["body"] == nil { + content["body"] = firstNonEmpty(req.Body, "...") + } + if content["msgtype"] == nil { + content["msgtype"] = "m.text" + } + content["com.beeper.stream"] = nil + topLevel := copyOutboundEvent(req.TopLevelContent) + topLevel["com.beeper.stream"] = nil + replacement, err := c.sendBeeperStreamReplacementEvent(ctx, req.RoomID, req.EventID, req.UserID, content, topLevel) + if err != nil { + return nil, err + } + targetEventID := id.EventID(req.EventID) + if c.beeperStream != nil { + c.beeperStream.Unregister(id.RoomID(req.RoomID), targetEventID) + c.beeperStream.Unsubscribe(id.RoomID(req.RoomID), targetEventID) + } + delete(c.beeperStreamMessages, targetEventID) + return json.Marshal(MatrixFinalizeBeeperStreamMessageResult{ + EventID: req.EventID, + ReplacementEventID: replacement.EventID.String(), + RoomID: req.RoomID, + Raw: replacement, + }) +} + +func (c *Core) sendBeeperStreamReplacementEvent(ctx context.Context, roomID, eventID, userID string, newContent, topLevel OutboundEvent) (*mautrix.RespSendEvent, error) { + content := copyOutboundEvent(topLevel) + content["body"] = "" + content["msgtype"] = firstString(newContent["msgtype"], "m.text") + content["m.new_content"] = newContent + content["m.relates_to"] = map[string]any{ + "event_id": eventID, + "rel_type": "m.replace", + } + if userID != "" { + intent, err := c.requireAppserviceIntent(userID) + if err != nil { + return nil, err + } + if err = c.appservice.ensureJoined(ctx, intent, id.RoomID(roomID)); err != nil { + return nil, err + } + return intent.SendMessageEvent(ctx, id.RoomID(roomID), event.EventMessage, content) + } + cli, err := c.requireClient() + if err != nil { + return nil, err + } + resp, err := retryMatrix(ctx, func() (*mautrix.RespSendEvent, error) { + if err := c.prepareOutboundMegolm(ctx, cli, id.RoomID(roomID)); err != nil { + return nil, err + } + return cli.SendMessageEvent(ctx, id.RoomID(roomID), event.EventMessage, content) + }) + if err != nil { + return nil, err + } + now := time.Now().UnixMilli() + isMe := true + isEdited := true + replaces := eventID + c.rememberEdit(&MatrixMessageEvent{ + MatrixRawEvent: MatrixRawEvent{ + Content: newContent, + EventID: eventID, + IsMe: &isMe, + OriginServerTS: &now, + Raw: resp, + RoomID: roomID, + Sender: c.userID.String(), + Type: event.EventMessage.Type, + }, + Body: firstString(newContent["body"], ""), + IsEdited: &isEdited, + Msgtype: firstString(newContent["msgtype"], "m.text"), + Relation: &MatrixRelation{EventID: eventID, Type: string(event.RelReplace)}, + Replaces: &replaces, + }) + return resp, nil +} + func (c *Core) addBeeperStreamSubscribers(ctx context.Context, roomID id.RoomID, eventID id.EventID, subscribers []MatrixBeeperStreamSubscriber) { if c.beeperStream == nil || c.client == nil || len(subscribers) == 0 { return @@ -180,50 +365,11 @@ func (c *Core) addBeeperStreamSubscribers(ctx context.Context, roomID id.RoomID, }) } -func (c *Core) handlePublishBeeperStream(ctx context.Context, payload []byte) ([]byte, error) { - if c.beeperStream == nil { - return nil, errors.New("beeper stream helper is not initialized") - } - var req MatrixBeeperStreamOptions - if err := json.Unmarshal(payload, &req); err != nil { - return nil, err - } - if err := c.beeperStream.Publish(ctx, id.RoomID(req.RoomID), id.EventID(req.EventID), req.Content); err != nil { - return nil, err - } - trace := beeperStreamUpdateTrace(req.Content) - c.client.Log.Debug(). - Int("delta_count", trace.DeltaCount). - Interface("first_seq", trace.FirstSeq). - Str("first_part_type", trace.FirstPartType). - Str("first_target_event", trace.FirstTargetEvent). - Str("first_turn_id", trace.FirstTurnID). - Int("keys", len(req.Content)). - Stringer("room_id", id.RoomID(req.RoomID)). - Stringer("event_id", id.EventID(req.EventID)). - Msg("Published beeper stream update") - return c.empty() -} - -func (c *Core) handleUnsubscribeBeeperStream(payload []byte) ([]byte, error) { - if c.beeperStream == nil { - return c.empty() - } - var req MatrixBeeperStreamOptions - if err := json.Unmarshal(payload, &req); err != nil { - return nil, err - } - c.beeperStream.Unregister(id.RoomID(req.RoomID), id.EventID(req.EventID)) - c.beeperStream.Unsubscribe(id.RoomID(req.RoomID), id.EventID(req.EventID)) - return c.empty() -} - type beeperStreamUpdateTraceData struct { - DeltaCount int - FirstPartType string - FirstSeq any - FirstTargetEvent string - FirstTurnID string + DeltaCount int + FirstPartType string + FirstSeq any + FirstTurnID string } func beeperStreamUpdateTrace(content map[string]any) beeperStreamUpdateTraceData { @@ -258,9 +404,6 @@ func beeperStreamUpdateTrace(content map[string]any) beeperStreamUpdateTraceData if turnID, ok := delta["turn_id"].(string); ok { trace.FirstTurnID = turnID } - if targetEvent, ok := delta["target_event"].(string); ok { - trace.FirstTargetEvent = targetEvent - } if part, ok := delta["part"].(map[string]any); ok { if partType, ok := part["type"].(string); ok { trace.FirstPartType = partType @@ -277,10 +420,33 @@ func (trace *beeperStreamUpdateTraceData) merge(next beeperStreamUpdateTraceData } trace.FirstSeq = next.FirstSeq trace.FirstPartType = next.FirstPartType - trace.FirstTargetEvent = next.FirstTargetEvent trace.FirstTurnID = next.FirstTurnID } +func copyOutboundEvent(input OutboundEvent) OutboundEvent { + output := make(OutboundEvent, len(input)) + for key, value := range input { + output[key] = value + } + return output +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if value != "" { + return value + } + } + return "" +} + +func firstString(value any, fallback string) string { + if str, ok := value.(string); ok && str != "" { + return str + } + return fallback +} + type MatrixEditMessageOptions struct { RoomID string `json:"roomId"` MessageID string `json:"messageId"` diff --git a/packages/pickle/native/internal/core/operations.go b/packages/pickle/native/internal/core/operations.go index 635df51..5e9c608 100644 --- a/packages/pickle/native/internal/core/operations.go +++ b/packages/pickle/native/internal/core/operations.go @@ -63,14 +63,12 @@ const ( opRemoveReaction = "remove_reaction" // ts:operation sendEphemeralEvent send_ephemeral_event MatrixSendEphemeralEventOptions MatrixRawMessage opSendEphemeralEvent = "send_ephemeral_event" - // ts:operation createBeeperStream create_beeper_stream MatrixCreateBeeperStreamOptions MatrixCreateBeeperStreamResult - opCreateBeeperStream = "create_beeper_stream" - // ts:operation registerBeeperStream register_beeper_stream MatrixRegisterBeeperStreamOptions void - opRegisterBeeperStream = "register_beeper_stream" - // ts:operation publishBeeperStream publish_beeper_stream MatrixBeeperStreamOptions void - opPublishBeeperStream = "publish_beeper_stream" - // ts:operation unsubscribeBeeperStream unsubscribe_beeper_stream MatrixBeeperStreamOptions void - opUnsubscribeBeeperStream = "unsubscribe_beeper_stream" + // ts:operation startBeeperStreamMessage start_beeper_stream_message MatrixStartBeeperStreamMessageOptions MatrixStartBeeperStreamMessageResult + opStartBeeperStreamMessage = "start_beeper_stream_message" + // ts:operation publishBeeperStreamMessagePart publish_beeper_stream_message_part MatrixPublishBeeperStreamMessagePartOptions void + opPublishBeeperStreamMessagePart = "publish_beeper_stream_message_part" + // ts:operation finalizeBeeperStreamMessage finalize_beeper_stream_message MatrixFinalizeBeeperStreamMessageOptions MatrixFinalizeBeeperStreamMessageResult + opFinalizeBeeperStreamMessage = "finalize_beeper_stream_message" // ts:operation setTyping set_typing MatrixTypingOptions void opSetTyping = "set_typing" // ts:operation fetchMessage fetch_message MatrixFetchMessageOptions MatrixFetchMessageResult diff --git a/packages/pickle/src/client-types.ts b/packages/pickle/src/client-types.ts index c5691a6..d89bd6b 100644 --- a/packages/pickle/src/client-types.ts +++ b/packages/pickle/src/client-types.ts @@ -3,7 +3,6 @@ import type { AccountDataOptions, AccountDataResult, BanUserOptions, - CreateBeeperStreamOptions, CreateRoomOptions, CreateRoomResult, DownloadEncryptedMediaOptions, @@ -39,10 +38,8 @@ import type { OpenDMResult, OwnAvatarUrlResult, OwnDisplayNameResult, - PublishBeeperStreamOptions, ReactionOptions, RedactMessageOptions, - RegisterBeeperStreamOptions, ResolveRoomAliasOptions, ResolveRoomAliasResult, RawRequestOptions, @@ -81,6 +78,11 @@ import type { MatrixAppserviceRoomUserOptions, MatrixAppserviceSendMessageOptions, MatrixAppserviceUserOptions, + MatrixFinalizeBeeperStreamMessageOptions, + MatrixFinalizeBeeperStreamMessageResult, + MatrixPublishBeeperStreamMessagePartOptions, + MatrixStartBeeperStreamMessageOptions, + MatrixStartBeeperStreamMessageResult, } from "./runtime-types"; export interface MatrixClient { @@ -146,9 +148,9 @@ export interface MatrixBeeper { send(options: SendBeeperEphemeralOptions): Promise; }; streams: { - create(options: CreateBeeperStreamOptions): Promise<{ descriptor: Record }>; - publish(options: PublishBeeperStreamOptions): Promise; - register(options: RegisterBeeperStreamOptions): Promise; + finalizeMessage(options: MatrixFinalizeBeeperStreamMessageOptions): Promise; + publishPart(options: MatrixPublishBeeperStreamMessagePartOptions): Promise; + startMessage(options: MatrixStartBeeperStreamMessageOptions): Promise; }; } diff --git a/packages/pickle/src/client.test.ts b/packages/pickle/src/client.test.ts index 9df1489..0d946c9 100644 --- a/packages/pickle/src/client.test.ts +++ b/packages/pickle/src/client.test.ts @@ -268,9 +268,9 @@ describe("createMatrixClient", () => { event: { content: { "com.beeper.llm.deltas": [{ + "m.relates_to": { event_id: "$message", rel_type: "m.reference" }, part: { delta: "hi", id: "text", type: "text-delta" }, seq: 1, - target_event: "$message", turn_id: "turn_1", }], event_id: "$message", @@ -848,18 +848,18 @@ describe("createMatrixClient", () => { it("streams with Beeper stream events on Beeper homeservers", async () => { const calls = installRuntime({ - create_beeper_stream: { + finalize_beeper_stream_message: { eventId: "$message", raw: {}, replacementEventId: "$edit", roomId: "!room:example.com" }, + init: { deviceId: "DEVICE", userId: "@bot:example.com" }, + publish_beeper_stream_message_part: {}, + start_beeper_stream_message: { descriptor: { device_id: "DEVICE", type: "com.beeper.llm", user_id: "@bot:example.com", }, + eventId: "$message", + roomId: "!room:example.com", }, - edit_message: { eventId: "$edit", raw: {}, roomId: "!room:example.com" }, - init: { deviceId: "DEVICE", userId: "@bot:example.com" }, - post_message: { eventId: "$message", raw: {}, roomId: "!room:example.com" }, - publish_beeper_stream: {}, - register_beeper_stream: {}, }); const client = createMatrixClient({ homeserver: "https://matrix.beeper.com", @@ -880,52 +880,37 @@ describe("createMatrixClient", () => { }); expect(calls.map((call) => call.operation)).toEqual([ "init", - "create_beeper_stream", - "post_message", - "register_beeper_stream", - "publish_beeper_stream", - "publish_beeper_stream", - "publish_beeper_stream", - "publish_beeper_stream", - "publish_beeper_stream", - "publish_beeper_stream", - "edit_message", + "start_beeper_stream_message", + "publish_beeper_stream_message_part", + "publish_beeper_stream_message_part", + "publish_beeper_stream_message_part", + "publish_beeper_stream_message_part", + "publish_beeper_stream_message_part", + "publish_beeper_stream_message_part", + "finalize_beeper_stream_message", ]); - expect(calls[2]?.payload).toMatchObject({ - body: "...", + expect(calls[1]?.payload).toMatchObject({ content: { "com.beeper.ai": { id: expect.any(String), parts: [], role: "assistant", }, - "com.beeper.stream": { - type: "com.beeper.llm", - }, }, roomId: "!room:example.com", + streamType: "com.beeper.llm", threadRootEventId: "$thread", }); - expect(calls[3]?.payload).toMatchObject({ - eventId: "$message", - roomId: "!room:example.com", - }); - expect(calls[4]?.payload).toMatchObject({ - content: { - "com.beeper.llm.deltas": [{ - part: { - messageMetadata: expect.any(Object), - type: "start", - }, - seq: 1, - target_event: "$message", - turn_id: expect.any(String), - }], + expect(calls[2]?.payload).toMatchObject({ + part: { + messageMetadata: expect.any(Object), + type: "start", }, eventId: "$message", roomId: "!room:example.com", + turnId: expect.any(String), }); - expect(calls[10]?.payload).toMatchObject({ + expect(calls[8]?.payload).toMatchObject({ body: "hello", content: { "com.beeper.ai": { @@ -933,31 +918,29 @@ describe("createMatrixClient", () => { parts: [{ text: "hello", type: "text" }], role: "assistant", }, - "com.beeper.stream": null, }, - messageId: "$message", + eventId: "$message", roomId: "!room:example.com", topLevelContent: { "com.beeper.dont_render_edited": true, - "com.beeper.stream": null, }, }); }); it("uses explicit Beeper capability for non-Beeper hostnames", async () => { const calls = installRuntime({ - create_beeper_stream: { + finalize_beeper_stream_message: { eventId: "$message", raw: {}, replacementEventId: "$edit", roomId: "!room:example.com" }, + init: { deviceId: "DEVICE", userId: "@bot:example.com" }, + publish_beeper_stream_message_part: {}, + start_beeper_stream_message: { descriptor: { device_id: "DEVICE", type: "com.beeper.llm", user_id: "@bot:example.com", }, + eventId: "$message", + roomId: "!room:example.com", }, - edit_message: { eventId: "$edit", raw: {}, roomId: "!room:example.com" }, - init: { deviceId: "DEVICE", userId: "@bot:example.com" }, - post_message: { eventId: "$message", raw: {}, roomId: "!room:example.com" }, - publish_beeper_stream: {}, - register_beeper_stream: {}, }); const client = createMatrixClient({ beeper: true, @@ -970,24 +953,24 @@ describe("createMatrixClient", () => { stream: chunks("hello"), }); - expect(calls.map((call) => call.operation)).toContain("create_beeper_stream"); - expect(calls.map((call) => call.operation)).toContain("publish_beeper_stream"); + expect(calls.map((call) => call.operation)).toContain("start_beeper_stream_message"); + expect(calls.map((call) => call.operation)).toContain("publish_beeper_stream_message_part"); }); it("keeps accumulated UI message parts in the Beeper final edit", async () => { const calls = installRuntime({ - create_beeper_stream: { + finalize_beeper_stream_message: { eventId: "$message", raw: {}, replacementEventId: "$edit", roomId: "!room:example.com" }, + init: { deviceId: "DEVICE", userId: "@bot:example.com" }, + publish_beeper_stream_message_part: {}, + start_beeper_stream_message: { descriptor: { device_id: "DEVICE", type: "com.beeper.llm", user_id: "@bot:example.com", }, + eventId: "$message", + roomId: "!room:example.com", }, - edit_message: { eventId: "$edit", raw: {}, roomId: "!room:example.com" }, - init: { deviceId: "DEVICE", userId: "@bot:example.com" }, - post_message: { eventId: "$message", raw: {}, roomId: "!room:example.com" }, - publish_beeper_stream: {}, - register_beeper_stream: {}, }); const client = createMatrixClient({ homeserver: "https://matrix.beeper.com", @@ -1009,7 +992,7 @@ describe("createMatrixClient", () => { ), }); - const edit = calls.find((call) => call.operation === "edit_message")?.payload; + const edit = calls.find((call) => call.operation === "finalize_beeper_stream_message")?.payload; expect(edit).toMatchObject({ body: "hello", content: { @@ -1031,18 +1014,18 @@ describe("createMatrixClient", () => { it("lets callers override the Beeper final AI message", async () => { const calls = installRuntime({ - create_beeper_stream: { + finalize_beeper_stream_message: { eventId: "$message", raw: {}, replacementEventId: "$edit", roomId: "!room:example.com" }, + init: { deviceId: "DEVICE", userId: "@bot:example.com" }, + publish_beeper_stream_message_part: {}, + start_beeper_stream_message: { descriptor: { device_id: "DEVICE", type: "com.beeper.llm", user_id: "@bot:example.com", }, + eventId: "$message", + roomId: "!room:example.com", }, - edit_message: { eventId: "$edit", raw: {}, roomId: "!room:example.com" }, - init: { deviceId: "DEVICE", userId: "@bot:example.com" }, - post_message: { eventId: "$message", raw: {}, roomId: "!room:example.com" }, - publish_beeper_stream: {}, - register_beeper_stream: {}, }); const client = createMatrixClient({ homeserver: "https://matrix.beeper.com", @@ -1061,7 +1044,7 @@ describe("createMatrixClient", () => { stream: chunks("ignored"), }); - const edit = calls.find((call) => call.operation === "edit_message")?.payload; + const edit = calls.find((call) => call.operation === "finalize_beeper_stream_message")?.payload; expect(edit).toMatchObject({ body: "override", content: { diff --git a/packages/pickle/src/client.ts b/packages/pickle/src/client.ts index 831ae55..9f8e76b 100644 --- a/packages/pickle/src/client.ts +++ b/packages/pickle/src/client.ts @@ -96,9 +96,9 @@ class DefaultMatrixClient implements MatrixClient { }))), }, streams: { - create: (opts) => this.#withCore((core) => core.createBeeperStream(opts)), - publish: (opts) => this.#withCore((core) => core.publishBeeperStream(opts)), - register: (opts) => this.#withCore((core) => core.registerBeeperStream(opts)), + finalizeMessage: (opts) => this.#withCore((core) => core.finalizeBeeperStreamMessage(stripUndefined(opts))), + publishPart: (opts) => this.#withCore((core) => core.publishBeeperStreamMessagePart(stripUndefined(opts))), + startMessage: (opts) => this.#withCore((core) => core.startBeeperStreamMessage(stripUndefined(opts))), }, }; this.crypto = { diff --git a/packages/pickle/src/generated-runtime-operations.ts b/packages/pickle/src/generated-runtime-operations.ts index 4d6b18d..34872e6 100644 --- a/packages/pickle/src/generated-runtime-operations.ts +++ b/packages/pickle/src/generated-runtime-operations.ts @@ -15,10 +15,7 @@ import type { MatrixAppserviceTransactionOptions, MatrixAppserviceUserOptions, MatrixBanUserOptions, - MatrixBeeperStreamOptions, MatrixCoreInitOptions, - MatrixCreateBeeperStreamOptions, - MatrixCreateBeeperStreamResult, MatrixCreateRoomOptions, MatrixCreateRoomResult, MatrixCryptoStatus, @@ -37,6 +34,8 @@ import type { MatrixFetchRoomStateEventOptions, MatrixFetchRoomStateOptions, MatrixFetchRoomStateResult, + MatrixFinalizeBeeperStreamMessageOptions, + MatrixFinalizeBeeperStreamMessageResult, MatrixGetAccountDataOptions, MatrixGetRoomAccountDataOptions, MatrixGetUserOptions, @@ -55,11 +54,11 @@ import type { MatrixOpenDMResult, MatrixOwnAvatarURLResult, MatrixOwnDisplayNameResult, + MatrixPublishBeeperStreamMessagePartOptions, MatrixRawMessage, MatrixRawRequestOptions, MatrixRawRequestResult, MatrixReactionOptions, - MatrixRegisterBeeperStreamOptions, MatrixResolveRoomAliasOptions, MatrixResolveRoomAliasResult, MatrixRoomInfo, @@ -76,6 +75,8 @@ import type { MatrixSetOwnAvatarURLOptions, MatrixSetOwnDisplayNameOptions, MatrixSetRoomAccountDataOptions, + MatrixStartBeeperStreamMessageOptions, + MatrixStartBeeperStreamMessageResult, MatrixSyncOnceOptions, MatrixSyncStartOptions, MatrixTypingOptions, @@ -119,10 +120,9 @@ export interface MatrixCoreOperations { addReaction(options: MatrixReactionOptions): Promise; removeReaction(options: MatrixReactionOptions): Promise; sendEphemeralEvent(options: MatrixSendEphemeralEventOptions): Promise; - createBeeperStream(options: MatrixCreateBeeperStreamOptions): Promise; - registerBeeperStream(options: MatrixRegisterBeeperStreamOptions): Promise; - publishBeeperStream(options: MatrixBeeperStreamOptions): Promise; - unsubscribeBeeperStream(options: MatrixBeeperStreamOptions): Promise; + startBeeperStreamMessage(options: MatrixStartBeeperStreamMessageOptions): Promise; + publishBeeperStreamMessagePart(options: MatrixPublishBeeperStreamMessagePartOptions): Promise; + finalizeBeeperStreamMessage(options: MatrixFinalizeBeeperStreamMessageOptions): Promise; setTyping(options: MatrixTypingOptions): Promise; fetchMessage(options: MatrixFetchMessageOptions): Promise; fetchMessages(options: MatrixFetchMessagesOptions): Promise; @@ -284,20 +284,16 @@ export abstract class MatrixCoreOperationCaller implements MatrixCoreOperations return this.call("send_ephemeral_event", options); } - createBeeperStream(options: MatrixCreateBeeperStreamOptions): Promise { - return this.call("create_beeper_stream", options); + startBeeperStreamMessage(options: MatrixStartBeeperStreamMessageOptions): Promise { + return this.call("start_beeper_stream_message", options); } - registerBeeperStream(options: MatrixRegisterBeeperStreamOptions): Promise { - return this.call("register_beeper_stream", options); + publishBeeperStreamMessagePart(options: MatrixPublishBeeperStreamMessagePartOptions): Promise { + return this.call("publish_beeper_stream_message_part", options); } - publishBeeperStream(options: MatrixBeeperStreamOptions): Promise { - return this.call("publish_beeper_stream", options); - } - - unsubscribeBeeperStream(options: MatrixBeeperStreamOptions): Promise { - return this.call("unsubscribe_beeper_stream", options); + finalizeBeeperStreamMessage(options: MatrixFinalizeBeeperStreamMessageOptions): Promise { + return this.call("finalize_beeper_stream_message", options); } setTyping(options: MatrixTypingOptions): Promise { diff --git a/packages/pickle/src/generated-runtime-types.ts b/packages/pickle/src/generated-runtime-types.ts index 2b6c6a2..d963d02 100644 --- a/packages/pickle/src/generated-runtime-types.ts +++ b/packages/pickle/src/generated-runtime-types.ts @@ -201,19 +201,20 @@ export interface MatrixRawMessage { roomId: string; raw: unknown; } -export interface MatrixCreateBeeperStreamResult { - descriptor: { [key: string]: unknown }; +export interface MatrixBeeperStreamSubscriber { + deviceId: string; + userId: string; } -export interface MatrixCreateBeeperStreamOptions { +export interface MatrixStartBeeperStreamMessageOptions { + body?: string; + content?: { [key: string]: unknown }; roomId: string; streamType?: string; + subscribers?: MatrixBeeperStreamSubscriber[]; + threadRootEventId?: string; + userId?: string; } -export interface MatrixBeeperStreamOptions { - content?: { [key: string]: unknown}; - eventId: string; - roomId: string; -} -export interface MatrixRegisterBeeperStreamOptions { +export interface MatrixStartBeeperStreamMessageResult { descriptor: { [key: string]: unknown }; eventId: string; roomId: string; @@ -223,6 +224,27 @@ export interface MatrixBeeperStreamSubscriber { deviceId: string; userId: string; } +export interface MatrixPublishBeeperStreamMessagePartOptions { + agentId?: string; + eventId: string; + part: { [key: string]: unknown }; + roomId: string; + turnId: string; +} +export interface MatrixFinalizeBeeperStreamMessageOptions { + body?: string; + content?: { [key: string]: unknown }; + eventId: string; + roomId: string; + topLevelContent?: { [key: string]: unknown }; + userId?: string; +} +export interface MatrixFinalizeBeeperStreamMessageResult { + eventId: string; + replacementEventId: string; + roomId: string; + raw: unknown; +} export interface MatrixEditMessageOptions { roomId: string; messageId: string; diff --git a/packages/pickle/src/index.ts b/packages/pickle/src/index.ts index 2dd1df5..8201d32 100644 --- a/packages/pickle/src/index.ts +++ b/packages/pickle/src/index.ts @@ -41,7 +41,6 @@ export type { AccountDataOptions, AccountDataResult, BanUserOptions, - CreateBeeperStreamOptions, CreateRoomOptions, CreateRoomResult, DownloadEncryptedMediaOptions, @@ -68,7 +67,6 @@ export type { ListThreadsResult, MarkReadOptions, MatrixAttachment, - MatrixBeeperStreamDescriptor, MatrixAccount, MatrixBaseEvent, MatrixClientEvent, @@ -100,10 +98,8 @@ export type { OpenDMResult, OwnAvatarUrlResult, OwnDisplayNameResult, - PublishBeeperStreamOptions, ReactionOptions, RedactMessageOptions, - RegisterBeeperStreamOptions, ResolveRoomAliasOptions, ResolveRoomAliasResult, RawRequestOptions, diff --git a/packages/pickle/src/runtime-types.ts b/packages/pickle/src/runtime-types.ts index f73e818..39bb3a2 100644 --- a/packages/pickle/src/runtime-types.ts +++ b/packages/pickle/src/runtime-types.ts @@ -27,11 +27,8 @@ export type { MatrixAppserviceUserOptions, MatrixApplySyncResponseOptions, MatrixBanUserOptions, - MatrixBeeperStreamOptions, MatrixCoreInitOptions, MatrixCryptoStatus, - MatrixCreateBeeperStreamOptions, - MatrixCreateBeeperStreamResult, MatrixCreateRoomOptions, MatrixCreateRoomResult, MatrixDeleteMessageOptions, @@ -42,6 +39,8 @@ export type { MatrixEditMessageOptions, MatrixEncryptedFile, MatrixEncryptedFileKey, + MatrixFinalizeBeeperStreamMessageOptions, + MatrixFinalizeBeeperStreamMessageResult, MatrixFetchMessageOptions, MatrixFetchMessageResult, MatrixFetchMessagesOptions, @@ -74,13 +73,13 @@ export type { MatrixOpenDMResult, MatrixOwnAvatarURLResult, MatrixOwnDisplayNameResult, + MatrixPublishBeeperStreamMessagePartOptions, MatrixRawEvent, MatrixRawMessage, MatrixRawRequestOptions, MatrixRawRequestResult, MatrixReactionEvent, MatrixReactionOptions, - MatrixRegisterBeeperStreamOptions, MatrixResolveRoomAliasOptions, MatrixResolveRoomAliasResult, MatrixRoomInfo, @@ -100,6 +99,8 @@ export type { MatrixSetOwnDisplayNameOptions, MatrixSetAccountDataOptions, MatrixSetRoomAccountDataOptions, + MatrixStartBeeperStreamMessageOptions, + MatrixStartBeeperStreamMessageResult, MatrixSyncOnceOptions, MatrixSyncEvent, MatrixSyncStartOptions, diff --git a/packages/pickle/src/streams/beeper.ts b/packages/pickle/src/streams/beeper.ts index 6dd89c5..fa6ee29 100644 --- a/packages/pickle/src/streams/beeper.ts +++ b/packages/pickle/src/streams/beeper.ts @@ -1,6 +1,6 @@ -import type { MatrixBeeper, MatrixMessages } from "../client-types"; +import type { MatrixBeeper } from "../client-types"; import { stripUndefined } from "../object"; -import type { SendMatrixStreamOptions, SendMessageOptions, SentEvent } from "../types"; +import type { SendMatrixStreamOptions, SentEvent } from "../types"; import { applyFinalMessagePart, compactFinalContent, @@ -13,59 +13,33 @@ import { streamChunkText } from "./edits"; export async function sendBeeperStream( client: { beeper: MatrixBeeper; - messages: MatrixMessages; }, opts: SendMatrixStreamOptions ): Promise { - const stream = await client.beeper.streams.create({ - roomId: opts.roomId, - streamType: "com.beeper.llm", - }); const turnId = `turn_${Date.now().toString(36)}_${Math.random().toString(36).slice(2, 10)}`; - const targetOptions: SendMessageOptions = { + const target = await client.beeper.streams.startMessage(stripUndefined({ content: { body: "...", "com.beeper.ai": { id: turnId, metadata: { turn_id: turnId }, parts: [], role: "assistant" }, - "com.beeper.stream": stream.descriptor, msgtype: "m.text", }, - messageType: "m.text" as const, - roomId: opts.roomId, - text: "...", - ...(opts.threadRoot === undefined ? {} : { threadRoot: opts.threadRoot }), - }; - const target = await client.messages.send(targetOptions); - await client.beeper.streams.register({ - descriptor: stream.descriptor, - eventId: target.eventId, roomId: opts.roomId, - }); + streamType: "com.beeper.llm", + threadRootEventId: opts.threadRoot, + })); const textId = `text_${turnId}`; const accumulator = createFinalMessageAccumulator(turnId); - let seq = 1; let textOpen = false; let sawFinish = false; - const pendingPublishes = new Set>(); - const publishPart = (part: Record) => { - const publish = publishBeeperStreamPart(client.beeper, opts.roomId, target.eventId, stream.descriptor, turnId, seq++, part) - .catch((error) => { - console.warn("[pickle] failed to publish beeper stream part", error); - }) - .finally(() => { - pendingPublishes.delete(publish); - }); - pendingPublishes.add(publish); - }; - const waitForPublishes = async () => { - while (pendingPublishes.size) await Promise.all([...pendingPublishes]); - }; + const publishPart = (part: Record) => + client.beeper.streams.publishPart({ eventId: target.eventId, part, roomId: opts.roomId, turnId }); const startPart = { messageId: turnId, messageMetadata: { turn_id: turnId }, type: "start", }; applyFinalMessagePart(accumulator, startPart); - publishPart(startPart); + await publishPart(startPart); for await (const chunk of opts.stream) { const normalizedChunks = normalizeRichStreamChunk(chunk); if (normalizedChunks.length > 0) { @@ -73,7 +47,7 @@ export async function sendBeeperStream( const type = typeof normalizedChunk.type === "string" ? normalizedChunk.type : ""; if (type === "finish" || type === "error" || type === "abort") sawFinish = true; applyFinalMessagePart(accumulator, normalizedChunk); - publishPart(normalizedChunk); + await publishPart(normalizedChunk); } continue; } @@ -81,7 +55,7 @@ export async function sendBeeperStream( const type = typeof chunk.type === "string" ? chunk.type : ""; if (type === "finish" || type === "error" || type === "abort") sawFinish = true; applyFinalMessagePart(accumulator, chunk); - publishPart(chunk); + await publishPart(chunk); continue; } const text = streamChunkText(chunk); @@ -90,10 +64,10 @@ export async function sendBeeperStream( const textStartPart = { id: textId, type: "text-start", - }; - applyFinalMessagePart(accumulator, textStartPart); - publishPart(textStartPart); - textOpen = true; + }; + applyFinalMessagePart(accumulator, textStartPart); + await publishPart(textStartPart); + textOpen = true; } const textDeltaPart = { delta: text, @@ -101,7 +75,7 @@ export async function sendBeeperStream( type: "text-delta", }; applyFinalMessagePart(accumulator, textDeltaPart); - publishPart(textDeltaPart); + await publishPart(textDeltaPart); } if (textOpen) { const textEndPart = { @@ -109,7 +83,7 @@ export async function sendBeeperStream( type: "text-end", }; applyFinalMessagePart(accumulator, textEndPart); - publishPart(textEndPart); + await publishPart(textEndPart); } if (!sawFinish) { const finishPart = { @@ -118,35 +92,31 @@ export async function sendBeeperStream( type: "finish", }; applyFinalMessagePart(accumulator, finishPart); - publishPart(finishPart); + await publishPart(finishPart); } - await waitForPublishes(); const finalAIMessage = opts.finalAIMessage ?? finalizeAccumulatedAIMessage(accumulator); const finalText = opts.finalText ?? getFinalMessageText(finalAIMessage); const finalContent = compactFinalContent({ aiMessage: finalAIMessage, body: finalText }); - const replacement = await client.messages.edit({ + const replacement = await client.beeper.streams.finalizeMessage({ + body: finalContent.body || "...", content: { body: finalContent.body || "...", "com.beeper.ai": finalContent.aiMessage, - "com.beeper.stream": null, msgtype: "m.text", }, eventId: target.eventId, - messageType: "m.text", roomId: opts.roomId, - text: finalContent.body || "...", topLevelContent: { "com.beeper.dont_render_edited": true, - "com.beeper.stream": null, }, }); return { - ...replacement, eventId: target.eventId, + roomId: replacement.roomId, raw: { logicalEventId: target.eventId, raw: replacement.raw, - replacementEventId: replacement.eventId, + replacementEventId: replacement.replacementEventId, }, }; } @@ -445,28 +415,3 @@ const NATIVE_STREAM_PART_TYPES = new Set([ "tool-output-denied", "tool-output-error", ]); - -async function publishBeeperStreamPart( - beeper: MatrixBeeper, - roomId: string, - eventId: string, - descriptor: Record, - turnId: string, - seq: number, - part: Record -): Promise { - const descriptorType = typeof descriptor.type === "string" ? descriptor.type : "com.beeper.llm"; - await beeper.streams.publish({ - content: { - [`${descriptorType}.deltas`]: [{ - "m.relates_to": { event_id: eventId, rel_type: "m.reference" }, - part, - seq, - target_event: eventId, - turn_id: turnId, - }], - }, - eventId, - roomId, - }); -} diff --git a/packages/pickle/src/types.ts b/packages/pickle/src/types.ts index 738ec88..7ce12da 100644 --- a/packages/pickle/src/types.ts +++ b/packages/pickle/src/types.ts @@ -31,10 +31,6 @@ export interface MatrixClientOptions { wasmUrl?: string | URL; } -export interface MatrixBeeperStreamDescriptor { - descriptor: Record; -} - export type MatrixStream = AsyncIterable>; export interface SendMatrixStreamOptions { @@ -48,24 +44,6 @@ export interface SendMatrixStreamOptions { updateIntervalMs?: number; } -export interface CreateBeeperStreamOptions { - roomId: string; - streamType?: string; -} - -export interface RegisterBeeperStreamOptions { - descriptor: Record; - eventId: string; - roomId: string; - subscribers?: BeeperStreamSubscriber[]; -} - -export interface PublishBeeperStreamOptions { - content?: Record; - eventId: string; - roomId: string; -} - export interface BeeperStreamSubscriber { deviceId: string; userId: string;