mirror of https://github.com/rowboatlabs/rowboat.git (synced 2026-04-25 00:16:29 +02:00)
441 lines · No EOL · 16 KiB · TypeScript
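// RAG ingestion worker: polls MongoDB for file-backed data source jobs,
// extracts text from uploads with an LLM (OpenAI-compatible or Gemini),
// chunks and embeds the text, and upserts the vectors into Qdrant.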
import '../lib/loadenv';
import { RecursiveCharacterTextSplitter } from "@langchain/textsplitters";
import { z } from 'zod';
import { dataSourceDocsCollection, dataSourcesCollection, projectsCollection, usersCollection } from '../lib/mongodb';
import { EmbeddingRecord, DataSourceDoc, DataSource } from "../lib/types/datasource_types";
import { ObjectId, WithId } from 'mongodb';
import { embedMany, generateText } from 'ai';
import { embeddingModel } from '../lib/embedding';
import { qdrantClient } from '../lib/qdrant';
import { PrefixLogger } from "../lib/utils";
import { GoogleGenerativeAI } from "@google/generative-ai";
import { GetObjectCommand } from "@aws-sdk/client-s3";
import { uploadsS3Client } from '../lib/uploads_s3_client';
import fs from 'fs/promises';
import crypto from 'crypto';
import path from 'path';
import { createOpenAI } from '@ai-sdk/openai';
import { USE_BILLING, USE_GEMINI_FILE_PARSING } from '../lib/feature_flags';
import { authorize, BillingError, getCustomerIdForProject, logUsage } from '../lib/billing';

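// File-parsing LLM configuration: FILE_PARSING_PROVIDER_BASE_URL can point
// at any OpenAI-compatible endpoint; the model defaults to gpt-4o.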
const FILE_PARSING_PROVIDER_API_KEY = process.env.FILE_PARSING_PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
const FILE_PARSING_PROVIDER_BASE_URL = process.env.FILE_PARSING_PROVIDER_BASE_URL || undefined;
const FILE_PARSING_MODEL = process.env.FILE_PARSING_MODEL || 'gpt-4o';

const openai = createOpenAI({
    apiKey: FILE_PARSING_PROVIDER_API_KEY,
    baseURL: FILE_PARSING_PROVIDER_BASE_URL,
});

const UPLOADS_DIR = process.env.RAG_UPLOADS_DIR || '/uploads';

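// Chunking: ~1024-character chunks with 20 characters of overlap, splitting
// on paragraph breaks first, then newlines, then sentence boundaries.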
const splitter = new RecursiveCharacterTextSplitter({
    separators: ['\n\n', '\n', '. ', '.', ''],
    chunkSize: 1024,
    chunkOverlap: 20,
});

const second = 1000;
const minute = 60 * second;
const hour = 60 * minute;
const day = 24 * hour;

// Configure Google Gemini API
const genAI = new GoogleGenerativeAI(process.env.GOOGLE_API_KEY || '');

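// Local uploads live under UPLOADS_DIR, keyed by the doc's MongoDB id
// (see runProcessPipeline).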
async function getLocalFileContent(filePath: string): Promise<Buffer> {
    return await fs.readFile(filePath);
}

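/**
 * Streams an object from the uploads S3 bucket and buffers it fully in
 * memory before returning it.
 */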
async function getS3FileContent(s3Key: string): Promise<Buffer> {
    const command = new GetObjectCommand({
        Bucket: process.env.RAG_UPLOADS_S3_BUCKET,
        Key: s3Key,
    });
    const response = await uploadsS3Client.send(command);
    if (!response.Body) {
        throw new Error(`Empty response body for S3 key: ${s3Key}`);
    }
    const chunks: Uint8Array[] = [];
    for await (const chunk of response.Body as any) {
        chunks.push(chunk);
    }
    return Buffer.concat(chunks);
}

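/**
 * Runs fn, retrying on failure up to maxAttempts times before rethrowing.
 * Currently unused in this worker; a flaky call could be wrapped as, e.g.,
 * `await retryable(() => getS3FileContent(s3Key))`.
 */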
async function retryable<T>(fn: () => Promise<T>, maxAttempts: number = 3): Promise<T> {
    let attempts = 0;
    while (true) {
        try {
            return await fn();
        } catch (e) {
            attempts++;
            if (attempts >= maxAttempts) {
                throw e;
            }
        }
    }
}

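/**
 * Ingests a single uploaded file: fetches its bytes (local disk or S3),
 * extracts markdown via an LLM, splits it into chunks, embeds the chunks,
 * upserts the vectors into Qdrant, and marks the doc record "ready".
 * Returns the number of embedding tokens used, which the caller reports
 * to billing.
 */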
async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } }): Promise<number> {
    const logger = _logger
        .child(doc._id.toString())
        .child(doc.name);

    // Get file content
    let fileData: Buffer;
    if (doc.data.type === 'file_local') {
        logger.log("Fetching file from local");
        fileData = await getLocalFileContent(path.join(UPLOADS_DIR, doc._id.toString()));
    } else {
        logger.log("Fetching file from S3");
        fileData = await getS3FileContent(doc.data.s3Key);
    }

    let markdown = "";
    const extractPrompt = "Extract and return only the text content from this document in markdown format. Exclude any formatting instructions or additional commentary.";
    if (!USE_GEMINI_FILE_PARSING) {
        // Use OpenAI to extract text content
        logger.log("Extracting content using OpenAI");
        const { text } = await generateText({
            model: openai(FILE_PARSING_MODEL),
            system: extractPrompt,
            messages: [
                {
                    role: "user",
                    content: [
                        {
                            type: "file",
                            data: fileData.toString('base64'),
                            mimeType: doc.data.mimeType,
                        }
                    ]
                }
            ],
        });
        markdown = text;
    } else {
        // Use Gemini to extract text content
        logger.log("Extracting content using Gemini");
        const model = genAI.getGenerativeModel({ model: "gemini-2.0-flash-001" });

        const result = await model.generateContent([
            {
                inlineData: {
                    data: fileData.toString('base64'),
                    mimeType: doc.data.mimeType
                }
            },
            extractPrompt,
        ]);
        markdown = result.response.text();
    }

    // split into chunks
    logger.log("Splitting into chunks");
    const splits = await splitter.createDocuments([markdown]);

    // generate embeddings
    logger.log("Generating embeddings");
    const { embeddings, usage } = await embedMany({
        model: embeddingModel,
        values: splits.map((split) => split.pageContent)
    });

    // store embeddings in qdrant
    logger.log("Storing embeddings in Qdrant");
    const points: z.infer<typeof EmbeddingRecord>[] = embeddings.map((embedding, i) => ({
        id: crypto.randomUUID(),
        vector: embedding,
        payload: {
            projectId: job.projectId,
            sourceId: job._id.toString(),
            docId: doc._id.toString(),
            content: splits[i].pageContent,
            title: doc.name,
            name: doc.name,
        },
    }));
    await qdrantClient.upsert("embeddings", {
        points,
    });

    // store content in doc record; matching on version guards against
    // racing with a concurrent update to the same doc
    logger.log("Storing content in doc record");
    await dataSourceDocsCollection.updateOne({
        _id: doc._id,
        version: doc.version,
    }, {
        $set: {
            content: markdown,
            status: "ready",
            lastUpdatedAt: new Date().toISOString(),
        }
    });

    return usage.tokens;
}

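/**
 * Removes a deleted doc: drops its embeddings from Qdrant (matched on
 * projectId, sourceId, and docId) and deletes the doc record from MongoDB.
 */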
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
    const logger = _logger
        .child(doc._id.toString())
        .child(doc.name);

    // Delete embeddings from qdrant
    logger.log("Deleting embeddings from Qdrant");
    await qdrantClient.delete("embeddings", {
        filter: {
            must: [
                {
                    key: "projectId",
                    match: {
                        value: job.projectId,
                    }
                },
                {
                    key: "sourceId",
                    match: {
                        value: job._id.toString(),
                    }
                },
                {
                    key: "docId",
                    match: {
                        value: doc._id.toString(),
                    }
                }
            ],
        },
    });

    // Delete doc from db
    logger.log("Deleting doc from db");
    await dataSourceDocsCollection.deleteOne({ _id: doc._id });
}

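// Worker loop. Jobs are claimed atomically via findOneAndUpdate, which bumps
// `attempts` and stamps `lastAttemptAt` so that jobs orphaned by a crashed
// worker become claimable again. Deletion jobs take priority over processing.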
// fetch next job from mongodb
(async () => {
    while (true) {
        const now = Date.now();
        let job: WithId<z.infer<typeof DataSource>> | null = null;

        // first try to find a job that needs deleting
        job = await dataSourcesCollection.findOneAndUpdate({
            status: "deleted",
            "data.type": { $in: ["files_local", "files_s3"] },
            $or: [
                { attempts: { $exists: false } },
                { attempts: { $lte: 3 } }
            ]
        }, { $set: { lastAttemptAt: new Date().toISOString() }, $inc: { attempts: 1 } }, { returnDocument: "after", sort: { createdAt: 1 } });

        if (job === null) {
            // otherwise, claim the next processing job
            job = await dataSourcesCollection.findOneAndUpdate(
                {
                    $and: [
                        { 'data.type': { $in: ["files_local", "files_s3"] } },
                        {
                            $or: [
                                // if the job has never been attempted
                                {
                                    status: "pending",
                                    attempts: 0,
                                },
                                // if the job was attempted but wasn't completed in the last hour
                                {
                                    status: "pending",
                                    lastAttemptAt: { $lt: new Date(now - 1 * hour).toISOString() },
                                },
                                // if the job errored out but hasn't been retried 3 times yet
                                {
                                    status: "error",
                                    attempts: { $lt: 3 },
                                },
                                // if the job errored out and wasn't retried in the last hour
                                {
                                    status: "error",
                                    lastAttemptAt: { $lt: new Date(now - 1 * hour).toISOString() },
                                },
                            ]
                        }
                    ]
                },
                {
                    $set: {
                        status: "pending",
                        lastAttemptAt: new Date().toISOString(),
                    },
                    $inc: {
                        attempts: 1
                    },
                },
                { returnDocument: "after", sort: { createdAt: 1 } }
            );
        }

        if (job === null) {
            // if no job found, sleep for a bit and start again
            await new Promise(resolve => setTimeout(resolve, 5 * second));
            continue;
        }

        const logger = new PrefixLogger(`${job._id.toString()}-${job.version}`);
        logger.log(`Starting job ${job._id}. Type: ${job.data.type}. Status: ${job.status}`);
        let errors = false;

        try {
            if (job.data.type !== 'files_local' && job.data.type !== 'files_s3') {
                throw new Error("Invalid data source type");
            }

            if (job.status === "deleted") {
                // delete all embeddings for this source
                logger.log("Deleting embeddings from Qdrant");
                await qdrantClient.delete("embeddings", {
                    filter: {
                        must: [
                            { key: "projectId", match: { value: job.projectId } },
                            { key: "sourceId", match: { value: job._id.toString() } },
                        ],
                    },
                });

                // delete all docs for this source
                logger.log("Deleting docs from db");
                await dataSourceDocsCollection.deleteMany({
                    sourceId: job._id.toString(),
                });

                // delete the source record from db
                logger.log("Deleting source record from db");
                await dataSourcesCollection.deleteOne({
                    _id: job._id,
                });

                logger.log("Job deleted");
                continue;
            }

            // fetch docs that need updating
            const pendingDocs = await dataSourceDocsCollection.find({
                sourceId: job._id.toString(),
                status: { $in: ["pending", "error"] },
            }).toArray();

            logger.log(`Found ${pendingDocs.length} docs to process`);

            // fetch the billing customer id for this project
            let billingCustomerId: string | null = null;
            if (USE_BILLING) {
                try {
                    billingCustomerId = await getCustomerIdForProject(job.projectId);
                } catch (e) {
                    logger.log("Unable to fetch billing customer id:", e);
                    throw new Error("Unable to fetch billing customer id");
                }
            }

            // for each doc
            for (const doc of pendingDocs) {
                // authorize with billing
                if (USE_BILLING && billingCustomerId) {
                    const authResponse = await authorize(billingCustomerId, {
                        type: "process_rag",
                        data: {},
                    });

                    if ('error' in authResponse) {
                        throw new BillingError(authResponse.error || "Unknown billing error");
                    }
                }

                const ldoc = doc as WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } };
                try {
                    const usedTokens = await runProcessPipeline(logger, job, ldoc);

                    // log usage in billing
                    if (USE_BILLING && billingCustomerId) {
                        await logUsage(billingCustomerId, {
                            type: "rag_tokens",
                            amount: usedTokens,
                        });
                    }
                } catch (e: any) {
                    errors = true;
                    logger.log("Error processing doc:", e);
                    await dataSourceDocsCollection.updateOne({
                        _id: doc._id,
                        version: doc.version,
                    }, {
                        $set: {
                            status: "error",
                            error: e.message,
                        }
                    });
                }
            }

            // fetch docs that need to be deleted
            const deletedDocs = await dataSourceDocsCollection.find({
                sourceId: job._id.toString(),
                status: "deleted",
            }).toArray();

            logger.log(`Found ${deletedDocs.length} docs to delete`);

            for (const doc of deletedDocs) {
                try {
                    await runDeletionPipeline(logger, job, doc);
                } catch (e: any) {
                    errors = true;
                    logger.log("Error deleting doc:", e);
                    await dataSourceDocsCollection.updateOne({
                        _id: doc._id,
                        version: doc.version,
                    }, {
                        $set: {
                            status: "error",
                            error: e.message,
                        }
                    });
                }
            }
        } catch (e) {
            if (e instanceof BillingError) {
                logger.log("Billing error:", e.message);
                await dataSourcesCollection.updateOne({
                    _id: job._id,
                    version: job.version,
                }, {
                    $set: {
                        status: "error",
                        billingError: e.message,
                        lastUpdatedAt: new Date().toISOString(),
                    }
                });
            }
            logger.log("Error processing job; will retry:", e);
            await dataSourcesCollection.updateOne({
                _id: job._id,
                version: job.version,
            }, {
                $set: {
                    status: "error",
                    lastUpdatedAt: new Date().toISOString(),
                }
            });
            continue;
        }

        // mark job as complete
        logger.log("Marking job as completed...");
        await dataSourcesCollection.updateOne({
            _id: job._id,
            version: job.version,
        }, {
            $set: {
                status: errors ? "error" : "ready",
                ...(errors ? { error: "There were some errors processing this job" } : {}),
                lastUpdatedAt: new Date().toISOString(),
            }
        });
    }
})();