mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-27 09:26:23 +02:00
add rate-limiting
This commit is contained in:
parent
024f6c75cc
commit
200e8d2e38
9 changed files with 188 additions and 2 deletions
|
|
@ -3,7 +3,7 @@
|
|||
import { redirect } from "next/navigation";
|
||||
import { SimulationData, EmbeddingDoc, GetInformationToolResult, DataSource, PlaygroundChat, AgenticAPIChatRequest, AgenticAPIChatResponse, convertFromAgenticAPIChatMessages, WebpageCrawlResponse, Workflow, WorkflowAgent, CopilotAPIRequest, CopilotAPIResponse, CopilotMessage, CopilotWorkflow, convertToCopilotWorkflow, convertToCopilotApiMessage, convertToCopilotMessage, CopilotAssistantMessage, CopilotChatContext, convertToCopilotApiChatContext, Scenario, ClientToolCallRequestBody, ClientToolCallJwt, ClientToolCallRequest, WithStringId, Project, WorkflowTool, WorkflowPrompt, ApiKey } from "./lib/types";
|
||||
import { ObjectId, WithId } from "mongodb";
|
||||
import { generateObject, generateText, tool, embed } from "ai";
|
||||
import { generateObject, generateText, embed } from "ai";
|
||||
import { dataSourcesCollection, embeddingsCollection, projectsCollection, webpagesCollection, agentWorkflowsCollection, scenariosCollection, projectMembersCollection, apiKeysCollection } from "@/app/lib/mongodb";
|
||||
import { z } from 'zod';
|
||||
import { openai } from "@ai-sdk/openai";
|
||||
|
|
@ -12,12 +12,13 @@ import { embeddingModel } from "./lib/embedding";
|
|||
import { apiV1 } from "rowboat-shared";
|
||||
import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||
import crypto from 'crypto';
|
||||
import { SignJWT } from "jose";
|
||||
import { Claims, getSession } from "@auth0/nextjs-auth0";
|
||||
import { revalidatePath } from "next/cache";
|
||||
import { callClientToolWebhook, getAgenticApiResponse } from "./lib/utils";
|
||||
import { templates } from "./lib/project_templates";
|
||||
import { assert, error } from "node:console";
|
||||
import { check_query_limit } from "./lib/rate_limiting";
|
||||
import { QueryLimitError } from "./lib/client_utils";
|
||||
|
||||
const crawler = new FirecrawlApp({ apiKey: process.env.FIRECRAWL_API_KEY || '' });
|
||||
|
||||
|
|
@ -319,6 +320,19 @@ export async function scrapeWebpage(url: string): Promise<z.infer<typeof Webpage
|
|||
|
||||
export async function createProject(formData: FormData) {
|
||||
const user = await authCheck();
|
||||
|
||||
// ensure that projects created by this user is less than
|
||||
// configured limit
|
||||
const projectsLimit = Number(process.env.MAX_PROJECTS_PER_USER) || 0;
|
||||
if (projectsLimit > 0) {
|
||||
const count = await projectsCollection.countDocuments({
|
||||
createdByUserId: user.sub,
|
||||
});
|
||||
if (count >= projectsLimit) {
|
||||
throw new Error('You have reached your project limit. Please upgrade your plan.');
|
||||
}
|
||||
}
|
||||
|
||||
const name = formData.get('name') as string;
|
||||
const templateKey = formData.get('template') as string;
|
||||
const projectId = crypto.randomUUID();
|
||||
|
|
@ -492,6 +506,9 @@ export async function getAssistantResponse(
|
|||
rawResponse: unknown,
|
||||
}> {
|
||||
await projectAuthCheck(projectId);
|
||||
if (!await check_query_limit(projectId)) {
|
||||
throw new QueryLimitError();
|
||||
}
|
||||
|
||||
const response = await getAgenticApiResponse(request);
|
||||
return {
|
||||
|
|
@ -513,6 +530,9 @@ export async function getCopilotResponse(
|
|||
rawResponse: unknown,
|
||||
}> {
|
||||
await projectAuthCheck(projectId);
|
||||
if (!await check_query_limit(projectId)) {
|
||||
throw new QueryLimitError();
|
||||
}
|
||||
|
||||
// prepare request
|
||||
const request: z.infer<typeof CopilotAPIRequest> = {
|
||||
|
|
@ -643,6 +663,9 @@ export async function getCopilotResponse(
|
|||
|
||||
export async function suggestToolResponse(toolId: string, projectId: string, messages: z.infer<typeof apiV1.ChatMessage>[]): Promise<string> {
|
||||
await projectAuthCheck(projectId);
|
||||
if (!await check_query_limit(projectId)) {
|
||||
throw new QueryLimitError();
|
||||
}
|
||||
|
||||
const prompt = `
|
||||
# Your Specific Task:
|
||||
|
|
@ -891,6 +914,10 @@ export async function simulateUserResponse(
|
|||
simulationData: z.infer<typeof SimulationData>
|
||||
): Promise<string> {
|
||||
await projectAuthCheck(projectId);
|
||||
if (!await check_query_limit(projectId)) {
|
||||
throw new QueryLimitError();
|
||||
}
|
||||
|
||||
const articlePrompt = `
|
||||
# Your Specific Task:
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import { ObjectId } from "mongodb";
|
|||
import { authCheck } from "@/app/api/v1/utils";
|
||||
import { convertFromApiToAgenticApiMessages, convertFromAgenticApiToApiMessages, AgenticAPIChatRequest, ApiRequest, ApiResponse, convertWorkflowToAgenticAPI } from "@/app/lib/types";
|
||||
import { getAgenticApiResponse } from "@/app/lib/utils";
|
||||
import { check_query_limit } from "@/app/lib/rate_limiting";
|
||||
|
||||
// get next turn / agent response
|
||||
export async function POST(
|
||||
|
|
@ -13,6 +14,11 @@ export async function POST(
|
|||
): Promise<Response> {
|
||||
const { projectId } = await params;
|
||||
|
||||
// check query limit
|
||||
if (!await check_query_limit(projectId)) {
|
||||
return Response.json({ error: "Query limit exceeded" }, { status: 429 });
|
||||
}
|
||||
|
||||
return await authCheck(projectId, req, async () => {
|
||||
// parse and validate the request body
|
||||
let body;
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import { ObjectId, WithId } from "mongodb";
|
|||
import { authCheck } from "../../../utils";
|
||||
import { AgenticAPIChatRequest, convertFromAgenticAPIChatMessages, convertToAgenticAPIChatMessages, convertWorkflowToAgenticAPI } from "@/app/lib/types";
|
||||
import { callClientToolWebhook, getAgenticApiResponse } from "@/app/lib/utils";
|
||||
import { check_query_limit } from "@/app/lib/rate_limiting";
|
||||
|
||||
const chatsCollection = db.collection<z.infer<typeof apiV1.Chat>>("chats");
|
||||
const chatMessagesCollection = db.collection<z.infer<typeof apiV1.ChatMessage>>("chatMessages");
|
||||
|
|
@ -18,6 +19,11 @@ export async function POST(
|
|||
return await authCheck(req, async (session) => {
|
||||
const { chatId } = await params;
|
||||
|
||||
// check query limit
|
||||
if (!await check_query_limit(session.projectId)) {
|
||||
return Response.json({ error: "Query limit exceeded" }, { status: 429 });
|
||||
}
|
||||
|
||||
// parse and validate the request body
|
||||
let body;
|
||||
try {
|
||||
|
|
|
|||
6
apps/rowboat/app/lib/client_utils.ts
Normal file
6
apps/rowboat/app/lib/client_utils.ts
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
export class QueryLimitError extends Error {
|
||||
constructor(message: string = 'Query limit exceeded') {
|
||||
super(message);
|
||||
this.name = 'QueryLimitError';
|
||||
}
|
||||
}
|
||||
21
apps/rowboat/app/lib/rate_limiting.ts
Normal file
21
apps/rowboat/app/lib/rate_limiting.ts
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
import { redisClient } from "./redis";
|
||||
|
||||
const MAX_QUERIES_PER_MINUTE = Number(process.env.MAX_QUERIES_PER_MINUTE) || 0;
|
||||
|
||||
export async function check_query_limit(projectId: string): Promise<boolean> {
|
||||
// if the limit is 0, we don't want to check the limit
|
||||
if (MAX_QUERIES_PER_MINUTE === 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const minutes_since_epoch = Math.floor(Date.now() / 1000 / 60); // 60 second window
|
||||
const key = `rate_limit:${projectId}:${minutes_since_epoch}`;
|
||||
|
||||
// increment the counter and return the count
|
||||
const count = await redisClient.incr(key);
|
||||
if (count === 1) {
|
||||
await redisClient.expire(key, 70); // Set TTL to clean up automatically
|
||||
}
|
||||
|
||||
return count <= MAX_QUERIES_PER_MINUTE;
|
||||
}
|
||||
7
apps/rowboat/app/lib/redis.ts
Normal file
7
apps/rowboat/app/lib/redis.ts
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
import { createClient } from 'redis';
|
||||
|
||||
export const redisClient = createClient({
|
||||
url: process.env.REDIS_URL,
|
||||
});
|
||||
|
||||
redisClient.connect();
|
||||
Loading…
Add table
Add a link
Reference in a new issue