Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"@types/d3": "7.4.3",
"@types/dompurify": "3.0.5",
"@types/lodash": "4.17.7",
"@types/node": "22.19.15",
"@types/node": "22.19.17",
"@types/nodemailer": "6.4.15",
"@types/pg": "8.11.6",
"@types/pg-copy-streams": "1.2.5",
Expand Down
52 changes: 46 additions & 6 deletions src/db.test.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { initializeDatabase } from "@/tests/vitest/dbUtils";
import { beforeEach, describe, expect, test } from "vitest";
import { copyStreamV2, query } from "./db";
import { Readable } from "stream";
import { copyStream, query } from "./db";
import { PassThrough, Readable } from "stream";

initializeDatabase();

describe("copyStreamV2", () => {
describe("copyStream", () => {
const sourceData: Array<{
id: number;
flag: boolean;
Expand Down Expand Up @@ -39,7 +39,7 @@ describe("copyStreamV2", () => {
});

test("can stream data into a table", async () => {
await copyStreamV2<any, any>({
await copyStream<any, any>({
table: "test",
stream: Readable.from(sourceData),
fields: {
Expand All @@ -64,7 +64,7 @@ describe("copyStreamV2", () => {
});

test("can stream a subset of columns into a table", async () => {
await copyStreamV2<any, any>({
await copyStream<any, any>({
table: "test",
stream: Readable.from(sourceData),
fields: {
Expand All @@ -91,7 +91,7 @@ describe("copyStreamV2", () => {
chunks.push(sourceData.slice(10 * i, 10 * (i + 1)));
}

await copyStreamV2<any, any>({
await copyStream<any, any>({
table: "test",
stream: Readable.from(chunks),
fields: {
Expand All @@ -114,4 +114,44 @@ describe("copyStreamV2", () => {
})),
);
});

test("errors in copy stream are handled", async () => {
const result = copyStream<any, any>({
table: "missing_table",
fields: {
id: (record: any) => record.id.toString(),
flag: (record: any) => record.flag.toString(),
content: (record: any) => record.content,
total: (record: any) => record.total.toString(),
created_at: (record: any) => record.createdAt.toISOString(),
},
stream: Readable.from(sourceData),
});

await expect(result).rejects.toThrowError();
});

test("errors in source stream are handled", async () => {
async function* testData() {
yield sourceData[0];
throw new Error("test error");
}

const result = copyStream<any, any>({
table: "missing_table",
fields: {
id: (record: any) => record.id.toString(),
flag: (record: any) => record.flag.toString(),
content: (record: any) => record.content,
total: (record: any) => record.total.toString(),
created_at: (record: any) => record.createdAt.toISOString(),
},
// This extra pass through ensures that we cover the case where the error is a stream not directly passed to the db function.
stream: Readable.from(testData()).compose(
new PassThrough({ objectMode: true }),
),
});

await expect(result).rejects.toThrowError(new Error("test error"));
});
});
16 changes: 1 addition & 15 deletions src/db.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,6 @@ export async function queryStream(
return stream;
}

export async function copyStream(
table: string,
stream: Readable,
): Promise<void> {
const client = await getPool().connect();

try {
const dbStream = client.query(copyFrom(`copy ${table} from stdin`));
await pipeline(stream, dbStream);
} finally {
client.release();
}
}

export async function transaction<T>(
tx: (q: typeof query) => Promise<T>,
): Promise<T> {
Expand Down Expand Up @@ -196,7 +182,7 @@ export async function reconnect() {
_pool = undefined;
}

export async function copyStreamV2<
export async function copyStream<
Record = unknown,
Table extends keyof Database = keyof Database,
>({
Expand Down
115 changes: 113 additions & 2 deletions src/modules/translation/data-access/machineGlossRepository.test.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { initializeDatabase } from "@/tests/vitest/dbUtils";
import { getDb } from "@/db";
import { beforeEach, describe, expect, test } from "vitest";
import { beforeEach, describe, expect, test, vi } from "vitest";
import { machineGlossRepository } from "./machineGlossRepository";
import { Readable } from "stream";
import { languageFactory } from "@/modules/languages/test-utils/languageFactory";
import type { AIGlossChapter } from "./aiGlossImportService";

initializeDatabase();

Expand Down Expand Up @@ -115,11 +116,23 @@ describe("updateAllForLanguage", () => {
gloss: "Should be dropped",
},
];
const chapters: Array<AIGlossChapter> = [
{
bookId: 1,
chapterNumber: 1,
glosses: newGlosses.slice(0, 2),
},
{
bookId: 1,
chapterNumber: 1,
glosses: newGlosses.slice(2),
},
];

await machineGlossRepository.updateAllForLanguage({
languageId: spaLanguage.id,
modelCode: "llm_import",
stream: Readable.from(newGlosses),
stream: Readable.from(chapters),
});

const insertedGlosses = await getDb()
Expand All @@ -138,4 +151,102 @@ describe("updateAllForLanguage", () => {
})),
]);
});

test("tracks progress when streaming AI gloss chapters", async () => {
const { language } = await languageFactory.build({
members: [],
});

const onProgress = vi.fn().mockResolvedValue(undefined);
const chapterStream: Array<AIGlossChapter> = [
{
bookId: 1,
chapterNumber: 1,
glosses: [{ wordId: "0100100101", gloss: "One" }],
},
{
bookId: 1,
chapterNumber: 2,
glosses: [{ wordId: "0100100102", gloss: "Two" }],
},
{
bookId: 2,
chapterNumber: 1,
glosses: [{ wordId: "0100100103", gloss: "Three" }],
},
];

await machineGlossRepository.updateAllForLanguage({
languageId: language.id,
modelCode: "llm_import",
stream: Readable.from(chapterStream),
onProgress,
});

expect(onProgress).toHaveBeenCalledTimes(2);
expect(onProgress).toHaveBeenNthCalledWith(1, 1);
expect(onProgress).toHaveBeenNthCalledWith(2, 2);

const insertedGlosses = await getDb()
.selectFrom("machine_gloss")
.where("language_id", "=", language.id)
.orderBy("id")
.select(["word_id", "gloss"])
.execute();

expect(insertedGlosses).toEqual([
{ word_id: "0100100101", gloss: "One" },
{ word_id: "0100100102", gloss: "Two" },
{ word_id: "0100100103", gloss: "Three" },
]);
});

test("errors in tracks progress don't crash the stream", async () => {
const { language } = await languageFactory.build({
members: [],
});

const onProgress = vi.fn().mockRejectedValue(new Error("test error"));
const chapterStream: Array<AIGlossChapter> = [
{
bookId: 1,
chapterNumber: 1,
glosses: [{ wordId: "0100100101", gloss: "One" }],
},
{
bookId: 1,
chapterNumber: 2,
glosses: [{ wordId: "0100100102", gloss: "Two" }],
},
{
bookId: 2,
chapterNumber: 1,
glosses: [{ wordId: "0100100103", gloss: "Three" }],
},
];

await machineGlossRepository.updateAllForLanguage({
languageId: language.id,
modelCode: "llm_import",
stream: Readable.from(chapterStream),
onProgress,
});

expect(onProgress).toHaveBeenCalledTimes(2);
expect(onProgress).toHaveBeenNthCalledWith(1, 1);
expect(onProgress).toHaveBeenNthCalledWith(2, 2);

const insertedGlosses = await getDb()
.selectFrom("machine_gloss")
.where("language_id", "=", language.id)
.orderBy("id")
.select(["word_id", "gloss"])
.execute();

expect(insertedGlosses).toEqual([
{ word_id: "0100100101", gloss: "One" },
{ word_id: "0100100102", gloss: "Two" },
{ word_id: "0100100103", gloss: "Three" },
]);
});
});
40 changes: 37 additions & 3 deletions src/modules/translation/data-access/machineGlossRepository.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { copyStreamV2, getDb } from "@/db";
import { copyStream, getDb } from "@/db";
import type { AIGloss, AIGlossChapter } from "./aiGlossImportService";
import { Readable, Transform } from "stream";

export interface StreamedMachineGloss {
Expand All @@ -11,10 +12,12 @@ export const machineGlossRepository = {
languageId,
modelCode,
stream,
onProgress,
}: {
languageId: string;
modelCode: string;
stream: Readable;
onProgress?: (bookId: number) => Promise<void>;
}): Promise<void> {
const words = await getDb().selectFrom("word").select("id").execute();
const wordIdSet = buildWordIdsSet(words);
Expand All @@ -30,19 +33,50 @@ export const machineGlossRepository = {
.where("language_id", "=", languageId)
.execute();

await copyStreamV2<StreamedMachineGloss, "machine_gloss">({
const progressTransform = new TrackBookProgressTransform(onProgress);
const filterTransform = new FilterMissingWordsTransform(wordIdSet);

await copyStream<StreamedMachineGloss, "machine_gloss">({
table: "machine_gloss",
stream: stream.pipe(new FilterMissingWordsTransform(wordIdSet)),
fields: {
word_id: (record) => record.wordId,
language_id: () => languageId,
model_id: () => model.id.toString(),
gloss: (record) => record.gloss,
},
stream: stream.compose(progressTransform).compose(filterTransform),
});
},
};

class TrackBookProgressTransform extends Transform {
private currentBookId: number | undefined;

constructor(private onBookIdChange?: (bookId: number) => Promise<void>) {
super({ writableObjectMode: true, readableObjectMode: true });
}

override _transform(
chapter: AIGlossChapter,
_encoding: BufferEncoding,
cb: (error?: Error | null, data?: Array<AIGloss>) => void,
) {
if (chapter.bookId !== this.currentBookId) {
this.currentBookId = chapter.bookId;
// This is intentionally not awaited since we don't want to block the stream
if (this.onBookIdChange) {
this.onBookIdChange(chapter.bookId).catch((err) => {
console.error(
`Unhandled failure in TrackBookProgressTransform.onBookIdChange: ${err}`,
);
});
}
}

cb(null, chapter.glosses);
}
}

export class FilterMissingWordsTransform extends Transform {
constructor(private readonly existingWordIds: ReadonlySet<number>) {
super({
Expand Down
Loading
Loading