diff --git a/apps/rowboat/app/actions/copilot_actions.ts b/apps/rowboat/app/actions/copilot_actions.ts index 91a9c28c..bf70ddee 100644 --- a/apps/rowboat/app/actions/copilot_actions.ts +++ b/apps/rowboat/app/actions/copilot_actions.ts @@ -7,14 +7,16 @@ import { Workflow} from "../lib/types/workflow_types"; import { DataSource } from "../lib/types/datasource_types"; import { z } from 'zod'; -import { check_query_limit } from "../lib/rate_limiting"; -import { QueryLimitError } from "@/src/entities/errors/common"; import { projectAuthCheck } from "./project_actions"; import { redisClient } from "../lib/redis"; import { authorizeUserAction, logUsage } from "./billing_actions"; import { USE_BILLING } from "../lib/feature_flags"; import { WithStringId } from "../lib/types/types"; import { getEditAgentInstructionsResponse } from "../lib/copilot/copilot"; +import { container } from "@/di/container"; +import { IUsageQuotaPolicyService } from "@/src/application/services/usage-quota-policy.service.interface"; + +const usageQuotaPolicyService = container.resolve('usageQuotaPolicyService'); export async function getCopilotResponseStream( projectId: string, @@ -26,9 +28,7 @@ export async function getCopilotResponseStream( streamId: string; } | { billingError: string }> { await projectAuthCheck(projectId); - if (!await check_query_limit(projectId)) { - throw new QueryLimitError(); - } + await usageQuotaPolicyService.assertAndConsume(projectId); // Check billing authorization const authResponse = await authorizeUserAction({ @@ -39,9 +39,7 @@ export async function getCopilotResponseStream( return { billingError: authResponse.error || 'Billing error' }; } - if (!await check_query_limit(projectId)) { - throw new QueryLimitError(); - } + await usageQuotaPolicyService.assertAndConsume(projectId); // prepare request const request: z.infer = { @@ -73,9 +71,7 @@ export async function getCopilotAgentInstructions( agentName: string, ): Promise { await projectAuthCheck(projectId); - if (!await check_query_limit(projectId)) { - throw new QueryLimitError(); - } + await usageQuotaPolicyService.assertAndConsume(projectId); // Check billing authorization const authResponse = await authorizeUserAction({ diff --git a/apps/rowboat/app/api/widget/v1/chats/[chatId]/turn/route.ts b/apps/rowboat/app/api/widget/v1/chats/[chatId]/turn/route.ts index e0101da0..79d6f777 100644 --- a/apps/rowboat/app/api/widget/v1/chats/[chatId]/turn/route.ts +++ b/apps/rowboat/app/api/widget/v1/chats/[chatId]/turn/route.ts @@ -4,12 +4,13 @@ import { projectsCollection, chatsCollection, chatMessagesCollection } from "../ import { z } from "zod"; import { ObjectId, WithId } from "mongodb"; import { authCheck } from "../../../utils"; -import { check_query_limit } from "../../../../../../lib/rate_limiting"; import { PrefixLogger } from "../../../../../../lib/utils"; import { authorize, getCustomerIdForProject, logUsage } from "@/app/lib/billing"; import { USE_BILLING } from "@/app/lib/feature_flags"; import { getResponse } from "@/app/lib/agents"; import { Message, AssistantMessage, AssistantMessageWithToolCalls, ToolMessage } from "@/app/lib/types/types"; +import { IUsageQuotaPolicyService } from "@/src/application/services/usage-quota-policy.service.interface"; +import { container } from "@/di/container"; function convert(messages: z.infer[]): z.infer[] { const result: z.infer[] = []; @@ -123,11 +124,9 @@ export async function POST( billingCustomerId = await getCustomerIdForProject(session.projectId); } - // check query limit - if (!await check_query_limit(session.projectId)) { - logger.log(`Query limit exceeded for project ${session.projectId}`); - return Response.json({ error: "Query limit exceeded" }, { status: 429 }); - } + // assert and consume quota + const usageQuotaPolicyService = container.resolve('usageQuotaPolicyService'); + await usageQuotaPolicyService.assertAndConsume(session.projectId); // parse and validate the request body let body; diff --git a/apps/rowboat/app/lib/rate_limiting.ts b/apps/rowboat/app/lib/rate_limiting.ts deleted file mode 100644 index 99ea4663..00000000 --- a/apps/rowboat/app/lib/rate_limiting.ts +++ /dev/null @@ -1,21 +0,0 @@ -import { redisClient } from "./redis"; - -const MAX_QUERIES_PER_MINUTE = Number(process.env.MAX_QUERIES_PER_MINUTE) || 0; - -export async function check_query_limit(projectId: string): Promise { - // if the limit is 0, we don't want to check the limit - if (MAX_QUERIES_PER_MINUTE === 0) { - return true; - } - - const minutes_since_epoch = Math.floor(Date.now() / 1000 / 60); // 60 second window - const key = `rate_limit:${projectId}:${minutes_since_epoch}`; - - // increment the counter and return the count - const count = await redisClient.incr(key); - if (count === 1) { - await redisClient.expire(key, 70); // Set TTL to clean up automatically - } - - return count <= MAX_QUERIES_PER_MINUTE; -} \ No newline at end of file diff --git a/apps/rowboat/di/container.ts b/apps/rowboat/di/container.ts index 423a44c6..8a8af2da 100644 --- a/apps/rowboat/di/container.ts +++ b/apps/rowboat/di/container.ts @@ -9,6 +9,7 @@ import { CreateCachedTurnUseCase } from "@/src/application/use-cases/conversatio import { FetchCachedTurnUseCase } from "@/src/application/use-cases/conversations/fetch-cached-turn.use-case"; import { CreateCachedTurnController } from "@/src/interface-adapters/controllers/conversations/create-cached-turn.controller"; import { RunTurnController } from "@/src/interface-adapters/controllers/conversations/run-turn.controller"; +import { RedisUsageQuotaPolicyService } from "@/src/infrastructure/services/redis.usage-quota-policy.service"; export const container = createContainer({ injectionMode: InjectionMode.PROXY, @@ -19,6 +20,7 @@ container.register({ // services // --- cacheService: asClass(RedisCacheService).singleton(), + usageQuotaPolicyService: asClass(RedisUsageQuotaPolicyService).singleton(), // conversations // --- diff --git a/apps/rowboat/src/application/services/usage-quota-policy.service.interface.ts b/apps/rowboat/src/application/services/usage-quota-policy.service.interface.ts new file mode 100644 index 00000000..6defb4bb --- /dev/null +++ b/apps/rowboat/src/application/services/usage-quota-policy.service.interface.ts @@ -0,0 +1,4 @@ +export interface IUsageQuotaPolicyService { + // this method will throw a QuotaExceededError if the quota is exceeded + assertAndConsume(projectId: string): Promise; +} \ No newline at end of file diff --git a/apps/rowboat/src/application/use-cases/conversations/create-cached-turn.use-case.ts b/apps/rowboat/src/application/use-cases/conversations/create-cached-turn.use-case.ts index 315c87f4..16476d7f 100644 --- a/apps/rowboat/src/application/use-cases/conversations/create-cached-turn.use-case.ts +++ b/apps/rowboat/src/application/use-cases/conversations/create-cached-turn.use-case.ts @@ -1,12 +1,11 @@ import { BadRequestError, NotAuthorizedError, NotFoundError } from '@/src/entities/errors/common'; -import { check_query_limit } from "@/app/lib/rate_limiting"; -import { QueryLimitError } from "@/src/entities/errors/common"; import { apiKeysCollection, projectMembersCollection } from "@/app/lib/mongodb"; import { IConversationsRepository } from "@/src/application/repositories/conversations.repository.interface"; import { z } from "zod"; import { nanoid } from 'nanoid'; import { ICacheService } from '@/src/application/services/cache.service.interface'; import { CachedTurnRequest, Turn } from '@/src/entities/models/turn'; +import { IUsageQuotaPolicyService } from '../../services/usage-quota-policy.service.interface'; const inputSchema = z.object({ caller: z.enum(["user", "api"]), @@ -23,16 +22,20 @@ export interface ICreateCachedTurnUseCase { export class CreateCachedTurnUseCase implements ICreateCachedTurnUseCase { private readonly cacheService: ICacheService; private readonly conversationsRepository: IConversationsRepository; + private readonly usageQuotaPolicyService: IUsageQuotaPolicyService; constructor({ cacheService, conversationsRepository, + usageQuotaPolicyService, }: { cacheService: ICacheService, conversationsRepository: IConversationsRepository, + usageQuotaPolicyService: IUsageQuotaPolicyService, }) { this.cacheService = cacheService; this.conversationsRepository = conversationsRepository; + this.usageQuotaPolicyService = usageQuotaPolicyService; } async execute(data: z.infer): Promise<{ key: string }> { @@ -45,10 +48,8 @@ export class CreateCachedTurnUseCase implements ICreateCachedTurnUseCase { // extract projectid from conversation const { projectId } = conversation; - // check query limit for project - if (!await check_query_limit(projectId)) { - throw new QueryLimitError('Query limit exceeded'); - } + // assert and consume quota + await this.usageQuotaPolicyService.assertAndConsume(projectId); // if caller is a user, ensure they are a member of project if (data.caller === "user") { diff --git a/apps/rowboat/src/application/use-cases/conversations/create-conversation.use-case.ts b/apps/rowboat/src/application/use-cases/conversations/create-conversation.use-case.ts index aca18d01..ad29e27b 100644 --- a/apps/rowboat/src/application/use-cases/conversations/create-conversation.use-case.ts +++ b/apps/rowboat/src/application/use-cases/conversations/create-conversation.use-case.ts @@ -1,11 +1,10 @@ import { BadRequestError, NotAuthorizedError, NotFoundError } from '@/src/entities/errors/common'; -import { check_query_limit } from "@/app/lib/rate_limiting"; -import { QueryLimitError } from "@/src/entities/errors/common"; import { apiKeysCollection, projectMembersCollection, projectsCollection } from "@/app/lib/mongodb"; import { IConversationsRepository } from "@/src/application/repositories/conversations.repository.interface"; import { z } from "zod"; import { Conversation } from "@/src/entities/models/conversation"; import { Workflow } from "@/app/lib/types/workflow_types"; +import { IUsageQuotaPolicyService } from '../../services/usage-quota-policy.service.interface'; const inputSchema = z.object({ caller: z.enum(["user", "api"]), @@ -22,13 +21,17 @@ export interface ICreateConversationUseCase { export class CreateConversationUseCase implements ICreateConversationUseCase { private readonly conversationsRepository: IConversationsRepository; + private readonly usageQuotaPolicyService: IUsageQuotaPolicyService; constructor({ conversationsRepository, + usageQuotaPolicyService, }: { conversationsRepository: IConversationsRepository, + usageQuotaPolicyService: IUsageQuotaPolicyService, }) { this.conversationsRepository = conversationsRepository; + this.usageQuotaPolicyService = usageQuotaPolicyService; } async execute(data: z.infer): Promise> { @@ -36,10 +39,8 @@ export class CreateConversationUseCase implements ICreateConversationUseCase { let isLiveWorkflow = Boolean(data.isLiveWorkflow); let workflow = data.workflow; - // check query limit for project - if (!await check_query_limit(projectId)) { - throw new QueryLimitError('Query limit exceeded'); - } + // assert and consume quota + await this.usageQuotaPolicyService.assertAndConsume(projectId); // if caller is a user, ensure they are a member of project if (caller === "user") { diff --git a/apps/rowboat/src/application/use-cases/conversations/fetch-cached-turn.use-case.ts b/apps/rowboat/src/application/use-cases/conversations/fetch-cached-turn.use-case.ts index 3218ee56..c17e2ff8 100644 --- a/apps/rowboat/src/application/use-cases/conversations/fetch-cached-turn.use-case.ts +++ b/apps/rowboat/src/application/use-cases/conversations/fetch-cached-turn.use-case.ts @@ -1,11 +1,10 @@ import { BadRequestError, NotAuthorizedError, NotFoundError } from '@/src/entities/errors/common'; -import { check_query_limit } from "@/app/lib/rate_limiting"; -import { QueryLimitError } from "@/src/entities/errors/common"; import { apiKeysCollection, projectMembersCollection } from "@/app/lib/mongodb"; import { IConversationsRepository } from "@/src/application/repositories/conversations.repository.interface"; import { z } from "zod"; import { ICacheService } from '@/src/application/services/cache.service.interface'; import { CachedTurnRequest, Turn } from '@/src/entities/models/turn'; +import { IUsageQuotaPolicyService } from '../../services/usage-quota-policy.service.interface'; const inputSchema = z.object({ caller: z.enum(["user", "api"]), @@ -21,16 +20,20 @@ export interface IFetchCachedTurnUseCase { export class FetchCachedTurnUseCase implements IFetchCachedTurnUseCase { private readonly cacheService: ICacheService; private readonly conversationsRepository: IConversationsRepository; + private readonly usageQuotaPolicyService: IUsageQuotaPolicyService; constructor({ cacheService, conversationsRepository, + usageQuotaPolicyService, }: { cacheService: ICacheService, conversationsRepository: IConversationsRepository, + usageQuotaPolicyService: IUsageQuotaPolicyService, }) { this.cacheService = cacheService; this.conversationsRepository = conversationsRepository; + this.usageQuotaPolicyService = usageQuotaPolicyService; } async execute(data: z.infer): Promise> { @@ -52,10 +55,8 @@ export class FetchCachedTurnUseCase implements IFetchCachedTurnUseCase { // extract projectid from conversation const { projectId } = conversation; - // check query limit for project - if (!await check_query_limit(projectId)) { - throw new QueryLimitError('Query limit exceeded'); - } + // assert and consume quota + await this.usageQuotaPolicyService.assertAndConsume(projectId); // if caller is a user, ensure they are a member of project if (data.caller === "user") { diff --git a/apps/rowboat/src/application/use-cases/conversations/run-conversation-turn.use-case.ts b/apps/rowboat/src/application/use-cases/conversations/run-conversation-turn.use-case.ts index 7cf3d7d2..3d464cbc 100644 --- a/apps/rowboat/src/application/use-cases/conversations/run-conversation-turn.use-case.ts +++ b/apps/rowboat/src/application/use-cases/conversations/run-conversation-turn.use-case.ts @@ -2,13 +2,12 @@ import { Turn, TurnEvent } from "@/src/entities/models/turn"; import { USE_BILLING } from "@/app/lib/feature_flags"; import { authorize, getCustomerIdForProject } from "@/app/lib/billing"; import { BadRequestError, BillingError, NotAuthorizedError, NotFoundError } from '@/src/entities/errors/common'; -import { check_query_limit } from "@/app/lib/rate_limiting"; -import { QueryLimitError } from "@/src/entities/errors/common"; import { apiKeysCollection, projectMembersCollection } from "@/app/lib/mongodb"; import { IConversationsRepository } from "@/src/application/repositories/conversations.repository.interface"; import { streamResponse } from "@/app/lib/agents"; import { z } from "zod"; import { Message } from "@/app/lib/types/types"; +import { IUsageQuotaPolicyService } from '../../services/usage-quota-policy.service.interface'; const inputSchema = z.object({ caller: z.enum(["user", "api"]), @@ -25,13 +24,17 @@ export interface IRunConversationTurnUseCase { export class RunConversationTurnUseCase implements IRunConversationTurnUseCase { private readonly conversationsRepository: IConversationsRepository; + private readonly usageQuotaPolicyService: IUsageQuotaPolicyService; constructor({ conversationsRepository, + usageQuotaPolicyService, }: { conversationsRepository: IConversationsRepository, + usageQuotaPolicyService: IUsageQuotaPolicyService, }) { this.conversationsRepository = conversationsRepository; + this.usageQuotaPolicyService = usageQuotaPolicyService; } async *execute(data: z.infer): AsyncGenerator, void, unknown> { @@ -44,10 +47,8 @@ export class RunConversationTurnUseCase implements IRunConversationTurnUseCase { // extract projectid from conversation const { id: conversationId, projectId } = conversation; - // check query limit for project - if (!await check_query_limit(projectId)) { - throw new QueryLimitError('Query limit exceeded'); - } + // assert and consume quota + await this.usageQuotaPolicyService.assertAndConsume(projectId); // if caller is a user, ensure they are a member of project if (data.caller === "user") { diff --git a/apps/rowboat/src/entities/errors/common.ts b/apps/rowboat/src/entities/errors/common.ts index e13c09bb..f29e18c7 100644 --- a/apps/rowboat/src/entities/errors/common.ts +++ b/apps/rowboat/src/entities/errors/common.ts @@ -4,7 +4,7 @@ export class BillingError extends Error { } } -export class QueryLimitError extends Error { +export class QuotaExceededError extends Error { constructor(message?: string, options?: ErrorOptions) { super(message, options); } diff --git a/apps/rowboat/src/infrastructure/services/redis.usage-quota-policy.service.ts b/apps/rowboat/src/infrastructure/services/redis.usage-quota-policy.service.ts new file mode 100644 index 00000000..d3b4c717 --- /dev/null +++ b/apps/rowboat/src/infrastructure/services/redis.usage-quota-policy.service.ts @@ -0,0 +1,25 @@ +import { IUsageQuotaPolicyService } from "@/src/application/services/usage-quota-policy.service.interface"; +import { redisClient } from "@/app/lib/redis"; +import { QuotaExceededError } from "@/src/entities/errors/common"; + +const MAX_QUERIES_PER_MINUTE = Number(process.env.MAX_QUERIES_PER_MINUTE) || 0; + +export class RedisUsageQuotaPolicyService implements IUsageQuotaPolicyService { + async assertAndConsume(projectId: string): Promise { + if (MAX_QUERIES_PER_MINUTE === 0) { + return; + } + + const minutes_since_epoch = Math.floor(Date.now() / 1000 / 60); // 60 second window + const key = `rate_limit:${projectId}:${minutes_since_epoch}`; + + const count = await redisClient.incr(key); + if (count === 1) { + await redisClient.expire(key, 70); // Set TTL to clean up automatically + } + + if (count > MAX_QUERIES_PER_MINUTE) { + throw new QuotaExceededError(`Quota exceeded for project ${projectId}`); + } + } +} \ No newline at end of file