add rate-limiting

This commit is contained in:
ramnique 2025-02-04 16:35:12 +05:30
parent 024f6c75cc
commit 200e8d2e38
9 changed files with 188 additions and 2 deletions

View file

@ -3,7 +3,7 @@
import { redirect } from "next/navigation";
import { SimulationData, EmbeddingDoc, GetInformationToolResult, DataSource, PlaygroundChat, AgenticAPIChatRequest, AgenticAPIChatResponse, convertFromAgenticAPIChatMessages, WebpageCrawlResponse, Workflow, WorkflowAgent, CopilotAPIRequest, CopilotAPIResponse, CopilotMessage, CopilotWorkflow, convertToCopilotWorkflow, convertToCopilotApiMessage, convertToCopilotMessage, CopilotAssistantMessage, CopilotChatContext, convertToCopilotApiChatContext, Scenario, ClientToolCallRequestBody, ClientToolCallJwt, ClientToolCallRequest, WithStringId, Project, WorkflowTool, WorkflowPrompt, ApiKey } from "./lib/types";
import { ObjectId, WithId } from "mongodb";
import { generateObject, generateText, tool, embed } from "ai";
import { generateObject, generateText, embed } from "ai";
import { dataSourcesCollection, embeddingsCollection, projectsCollection, webpagesCollection, agentWorkflowsCollection, scenariosCollection, projectMembersCollection, apiKeysCollection } from "@/app/lib/mongodb";
import { z } from 'zod';
import { openai } from "@ai-sdk/openai";
@ -12,12 +12,13 @@ import { embeddingModel } from "./lib/embedding";
import { apiV1 } from "rowboat-shared";
import { zodToJsonSchema } from 'zod-to-json-schema';
import crypto from 'crypto';
import { SignJWT } from "jose";
import { Claims, getSession } from "@auth0/nextjs-auth0";
import { revalidatePath } from "next/cache";
import { callClientToolWebhook, getAgenticApiResponse } from "./lib/utils";
import { templates } from "./lib/project_templates";
import { assert, error } from "node:console";
import { check_query_limit } from "./lib/rate_limiting";
import { QueryLimitError } from "./lib/client_utils";
const crawler = new FirecrawlApp({ apiKey: process.env.FIRECRAWL_API_KEY || '' });
@ -319,6 +320,19 @@ export async function scrapeWebpage(url: string): Promise<z.infer<typeof Webpage
export async function createProject(formData: FormData) {
const user = await authCheck();
// ensure that projects created by this user is less than
// configured limit
const projectsLimit = Number(process.env.MAX_PROJECTS_PER_USER) || 0;
if (projectsLimit > 0) {
const count = await projectsCollection.countDocuments({
createdByUserId: user.sub,
});
if (count >= projectsLimit) {
throw new Error('You have reached your project limit. Please upgrade your plan.');
}
}
const name = formData.get('name') as string;
const templateKey = formData.get('template') as string;
const projectId = crypto.randomUUID();
@ -492,6 +506,9 @@ export async function getAssistantResponse(
rawResponse: unknown,
}> {
await projectAuthCheck(projectId);
if (!await check_query_limit(projectId)) {
throw new QueryLimitError();
}
const response = await getAgenticApiResponse(request);
return {
@ -513,6 +530,9 @@ export async function getCopilotResponse(
rawResponse: unknown,
}> {
await projectAuthCheck(projectId);
if (!await check_query_limit(projectId)) {
throw new QueryLimitError();
}
// prepare request
const request: z.infer<typeof CopilotAPIRequest> = {
@ -643,6 +663,9 @@ export async function getCopilotResponse(
export async function suggestToolResponse(toolId: string, projectId: string, messages: z.infer<typeof apiV1.ChatMessage>[]): Promise<string> {
await projectAuthCheck(projectId);
if (!await check_query_limit(projectId)) {
throw new QueryLimitError();
}
const prompt = `
# Your Specific Task:
@ -891,6 +914,10 @@ export async function simulateUserResponse(
simulationData: z.infer<typeof SimulationData>
): Promise<string> {
await projectAuthCheck(projectId);
if (!await check_query_limit(projectId)) {
throw new QueryLimitError();
}
const articlePrompt = `
# Your Specific Task:

View file

@ -5,6 +5,7 @@ import { ObjectId } from "mongodb";
import { authCheck } from "@/app/api/v1/utils";
import { convertFromApiToAgenticApiMessages, convertFromAgenticApiToApiMessages, AgenticAPIChatRequest, ApiRequest, ApiResponse, convertWorkflowToAgenticAPI } from "@/app/lib/types";
import { getAgenticApiResponse } from "@/app/lib/utils";
import { check_query_limit } from "@/app/lib/rate_limiting";
// get next turn / agent response
export async function POST(
@ -13,6 +14,11 @@ export async function POST(
): Promise<Response> {
const { projectId } = await params;
// check query limit
if (!await check_query_limit(projectId)) {
return Response.json({ error: "Query limit exceeded" }, { status: 429 });
}
return await authCheck(projectId, req, async () => {
// parse and validate the request body
let body;

View file

@ -6,6 +6,7 @@ import { ObjectId, WithId } from "mongodb";
import { authCheck } from "../../../utils";
import { AgenticAPIChatRequest, convertFromAgenticAPIChatMessages, convertToAgenticAPIChatMessages, convertWorkflowToAgenticAPI } from "@/app/lib/types";
import { callClientToolWebhook, getAgenticApiResponse } from "@/app/lib/utils";
import { check_query_limit } from "@/app/lib/rate_limiting";
const chatsCollection = db.collection<z.infer<typeof apiV1.Chat>>("chats");
const chatMessagesCollection = db.collection<z.infer<typeof apiV1.ChatMessage>>("chatMessages");
@ -18,6 +19,11 @@ export async function POST(
return await authCheck(req, async (session) => {
const { chatId } = await params;
// check query limit
if (!await check_query_limit(session.projectId)) {
return Response.json({ error: "Query limit exceeded" }, { status: 429 });
}
// parse and validate the request body
let body;
try {

View file

@ -0,0 +1,6 @@
export class QueryLimitError extends Error {
constructor(message: string = 'Query limit exceeded') {
super(message);
this.name = 'QueryLimitError';
}
}

View file

@ -0,0 +1,21 @@
import { redisClient } from "./redis";
const MAX_QUERIES_PER_MINUTE = Number(process.env.MAX_QUERIES_PER_MINUTE) || 0;
export async function check_query_limit(projectId: string): Promise<boolean> {
// if the limit is 0, we don't want to check the limit
if (MAX_QUERIES_PER_MINUTE === 0) {
return true;
}
const minutes_since_epoch = Math.floor(Date.now() / 1000 / 60); // 60 second window
const key = `rate_limit:${projectId}:${minutes_since_epoch}`;
// increment the counter and return the count
const count = await redisClient.incr(key);
if (count === 1) {
await redisClient.expire(key, 70); // Set TTL to clean up automatically
}
return count <= MAX_QUERIES_PER_MINUTE;
}

View file

@ -0,0 +1,7 @@
import { createClient } from 'redis';
export const redisClient = createClient({
url: process.env.REDIS_URL,
});
redisClient.connect();