From 2fda9a7e79430b3cb2dc246982dc7b58c65bbc66 Mon Sep 17 00:00:00 2001 From: Ramnique Singh <30795890+ramnique@users.noreply.github.com> Date: Sun, 18 May 2025 01:37:54 +0530 Subject: [PATCH] add stripe billing --- apps/copilot/app.py | 1 + apps/rowboat/app/actions/actions.ts | 56 +-- apps/rowboat/app/actions/auth_actions.ts | 53 +++ apps/rowboat/app/actions/billing_actions.ts | 95 ++++ apps/rowboat/app/actions/copilot_actions.ts | 50 ++- .../rowboat/app/actions/datasource_actions.ts | 4 + apps/rowboat/app/actions/klavis_actions.ts | 28 +- apps/rowboat/app/actions/project_actions.ts | 61 ++- .../[streamId]/route.ts | 24 + .../api/stream-response/[streamId]/route.ts | 63 ++- .../app/api/v1/[projectId]/chat/route.ts | 34 ++ .../widget/v1/chats/[chatId]/turn/route.ts | 34 ++ apps/rowboat/app/billing/app.tsx | 148 +++++++ apps/rowboat/app/billing/callback/page.tsx | 18 + apps/rowboat/app/billing/layout.tsx | 13 + apps/rowboat/app/billing/page.tsx | 17 + apps/rowboat/app/lib/auth.ts | 116 +++++ apps/rowboat/app/lib/billing.ts | 304 +++++++++++++ .../app/lib/components/user_button.tsx | 13 +- apps/rowboat/app/lib/feature_flags.ts | 1 + apps/rowboat/app/lib/mongodb.ts | 3 +- apps/rowboat/app/lib/types/billing_types.ts | 100 +++++ apps/rowboat/app/lib/types/copilot_types.ts | 1 + .../rowboat/app/lib/types/datasource_types.ts | 1 + apps/rowboat/app/lib/types/types.ts | 9 + apps/rowboat/app/onboarding/app.tsx | 90 ++++ apps/rowboat/app/onboarding/layout.tsx | 13 + apps/rowboat/app/onboarding/page.tsx | 14 + .../app/projects/[projectId]/config/page.tsx | 4 +- .../app/projects/[projectId]/copilot/app.tsx | 139 +++--- .../[projectId]/copilot/use-copilot.tsx | 29 +- .../[projectId]/entities/agent_config.tsx | 197 ++++++--- .../rowboat/app/projects/[projectId]/page.tsx | 4 +- .../playground/components/chat.tsx | 34 +- .../[projectId]/sources/[sourceId]/page.tsx | 2 + .../sources/[sourceId]/source-page.tsx | 78 +++- .../projects/[projectId]/sources/new/page.tsx | 2 + .../app/projects/[projectId]/sources/page.tsx | 2 + .../[projectId]/test/[[...slug]]/page.tsx | 7 +- .../tools/components/HostedServers.tsx | 28 +- .../app/projects/[projectId]/tools/page.tsx | 5 +- .../app/projects/[projectId]/workflow/app.tsx | 29 +- .../projects/[projectId]/workflow/page.tsx | 3 + .../[projectId]/workflow/workflow_editor.tsx | 4 + apps/rowboat/app/projects/layout.tsx | 4 +- .../projects/layout/components/app-layout.tsx | 13 +- .../projects/layout/components/sidebar.tsx | 13 +- apps/rowboat/app/projects/page.tsx | 6 +- .../select/components/create-project.tsx | 411 +++++++++--------- apps/rowboat/app/projects/select/page.tsx | 4 +- apps/rowboat/app/scripts/rag_files_worker.ts | 70 ++- apps/rowboat/app/scripts/rag_text_worker.ts | 64 ++- apps/rowboat/app/scripts/rag_urls_worker.ts | 64 ++- .../common/billing-upgrade-modal.tsx | 215 +++++++++ apps/rowboat/middleware.ts | 16 +- apps/rowboat/package-lock.json | 9 +- apps/rowboat/package.json | 1 + docker-compose.yml | 12 + 58 files changed, 2348 insertions(+), 485 deletions(-) create mode 100644 apps/rowboat/app/actions/auth_actions.ts create mode 100644 apps/rowboat/app/actions/billing_actions.ts create mode 100644 apps/rowboat/app/billing/app.tsx create mode 100644 apps/rowboat/app/billing/callback/page.tsx create mode 100644 apps/rowboat/app/billing/layout.tsx create mode 100644 apps/rowboat/app/billing/page.tsx create mode 100644 apps/rowboat/app/lib/auth.ts create mode 100644 apps/rowboat/app/lib/billing.ts create mode 100644 apps/rowboat/app/lib/types/billing_types.ts create mode 100644 apps/rowboat/app/onboarding/app.tsx create mode 100644 apps/rowboat/app/onboarding/layout.tsx create mode 100644 apps/rowboat/app/onboarding/page.tsx create mode 100644 apps/rowboat/components/common/billing-upgrade-modal.tsx diff --git a/apps/copilot/app.py b/apps/copilot/app.py index 26672d26..9d23990a 100644 --- a/apps/copilot/app.py +++ b/apps/copilot/app.py @@ -22,6 +22,7 @@ class DataSource(BaseModel): populate_by_name = True class ApiRequest(BaseModel): + projectId: str messages: List[UserMessage | AssistantMessage] workflow_schema: str current_workflow_config: str diff --git a/apps/rowboat/app/actions/actions.ts b/apps/rowboat/app/actions/actions.ts index a9596788..e72efe79 100644 --- a/apps/rowboat/app/actions/actions.ts +++ b/apps/rowboat/app/actions/actions.ts @@ -1,35 +1,18 @@ 'use server'; -import { AgenticAPIInitStreamResponse, convertFromAgenticAPIChatMessages } from "../lib/types/agents_api_types"; +import { AgenticAPIInitStreamResponse } from "../lib/types/agents_api_types"; import { AgenticAPIChatRequest } from "../lib/types/agents_api_types"; import { WebpageCrawlResponse } from "../lib/types/tool_types"; import { webpagesCollection } from "../lib/mongodb"; import { z } from 'zod'; import FirecrawlApp, { ScrapeResponse } from '@mendable/firecrawl-js'; -import { apiV1 } from "rowboat-shared"; -import { Claims, getSession } from "@auth0/nextjs-auth0"; -import { getAgenticApiResponse, getAgenticResponseStreamId } from "../lib/utils"; +import { getAgenticResponseStreamId } from "../lib/utils"; import { check_query_limit } from "../lib/rate_limiting"; import { QueryLimitError } from "../lib/client_utils"; import { projectAuthCheck } from "./project_actions"; -import { USE_AUTH } from "../lib/feature_flags"; +import { authorizeUserAction } from "./billing_actions"; const crawler = new FirecrawlApp({ apiKey: process.env.FIRECRAWL_API_KEY || '' }); -export async function authCheck(): Promise { - if (!USE_AUTH) { - return { - email: 'guestuser@rowboatlabs.com', - email_verified: true, - sub: 'guest_user', - }; - } - const { user } = await getSession() || {}; - if (!user) { - throw new Error('User not authenticated'); - } - return user; -} - export async function scrapeWebpage(url: string): Promise> { const page = await webpagesCollection.findOne({ "_id": url, @@ -74,30 +57,25 @@ export async function scrapeWebpage(url: string): Promise): Promise<{ - messages: z.infer[], - state: unknown, - rawRequest: unknown, - rawResponse: unknown, -}> { +export async function getAssistantResponseStreamId(request: z.infer): Promise | { billingError: string }> { await projectAuthCheck(request.projectId); if (!await check_query_limit(request.projectId)) { throw new QueryLimitError(); } - const response = await getAgenticApiResponse(request); - return { - messages: convertFromAgenticAPIChatMessages(response.messages), - state: response.state, - rawRequest: request, - rawResponse: response.rawAPIResponse, - }; -} - -export async function getAssistantResponseStreamId(request: z.infer): Promise> { - await projectAuthCheck(request.projectId); - if (!await check_query_limit(request.projectId)) { - throw new QueryLimitError(); + // Check billing authorization + const agentModels = request.agents.reduce((acc, agent) => { + acc.push(agent.model); + return acc; + }, [] as string[]); + const { success, error } = await authorizeUserAction({ + type: 'agent_response', + data: { + agentModels, + }, + }); + if (!success) { + return { billingError: error || 'Billing error' }; } const response = await getAgenticResponseStreamId(request); diff --git a/apps/rowboat/app/actions/auth_actions.ts b/apps/rowboat/app/actions/auth_actions.ts new file mode 100644 index 00000000..4fa6ecfc --- /dev/null +++ b/apps/rowboat/app/actions/auth_actions.ts @@ -0,0 +1,53 @@ +"use server"; +import { getSession } from "@auth0/nextjs-auth0"; +import { USE_AUTH } from "../lib/feature_flags"; +import { WithStringId, User } from "../lib/types/types"; +import { getUserFromSessionId, GUEST_DB_USER } from "../lib/auth"; +import { z } from "zod"; +import { ObjectId } from "mongodb"; +import { usersCollection } from "../lib/mongodb"; + +export async function authCheck(): Promise>> { + if (!USE_AUTH) { + return GUEST_DB_USER; + } + + const { user } = await getSession() || {}; + if (!user) { + throw new Error('User not authenticated'); + } + + const dbUser = await getUserFromSessionId(user.sub); + if (!dbUser) { + throw new Error('User record not found'); + } + return dbUser; +} + +const EmailOnly = z.object({ + email: z.string().email(), +}); + +export async function updateUserEmail(email: string) { + if (!USE_AUTH) { + return; + } + const user = await authCheck(); + + if (!email.trim()) { + throw new Error('Email is required'); + } + if (!EmailOnly.safeParse({ email }).success) { + throw new Error('Invalid email'); + } + + // update customer email in db + await usersCollection.updateOne({ + _id: new ObjectId(user._id), + }, { + $set: { + email, + updatedAt: new Date().toISOString(), + } + }); +} diff --git a/apps/rowboat/app/actions/billing_actions.ts b/apps/rowboat/app/actions/billing_actions.ts new file mode 100644 index 00000000..eafde06f --- /dev/null +++ b/apps/rowboat/app/actions/billing_actions.ts @@ -0,0 +1,95 @@ +"use server"; +import { + authorize, + logUsage as libLogUsage, + getBillingCustomer, + createCustomerPortalSession, + getPrices as libGetPrices, + updateSubscriptionPlan as libUpdateSubscriptionPlan, + getEligibleModels as libGetEligibleModels +} from "../lib/billing"; +import { authCheck } from "./auth_actions"; +import { USE_BILLING } from "../lib/feature_flags"; +import { + AuthorizeRequest, + AuthorizeResponse, + LogUsageRequest, + Customer, + PricesResponse, + SubscriptionPlan, + UpdateSubscriptionPlanRequest, + ModelsResponse +} from "../lib/types/billing_types"; +import { z } from "zod"; +import { WithStringId } from "../lib/types/types"; + +export async function getCustomer(): Promise>> { + const user = await authCheck(); + if (!user.billingCustomerId) { + throw new Error("Customer not found"); + } + const customer = await getBillingCustomer(user.billingCustomerId); + if (!customer) { + throw new Error("Customer not found"); + } + return customer; +} + +export async function authorizeUserAction(request: z.infer): Promise> { + if (!USE_BILLING) { + return { success: true }; + } + + const customer = await getCustomer(); + const response = await authorize(customer._id, request); + return response; +} + +export async function logUsage(request: z.infer) { + if (!USE_BILLING) { + return; + } + + const customer = await getCustomer(); + await libLogUsage(customer._id, request); + return; +} + +export async function getCustomerPortalUrl(returnUrl: string): Promise { + if (!USE_BILLING) { + throw new Error("Billing is not enabled") + } + + const customer = await getCustomer(); + return await createCustomerPortalSession(customer._id, returnUrl); +} + +export async function getPrices(): Promise> { + if (!USE_BILLING) { + throw new Error("Billing is not enabled"); + } + + const response = await libGetPrices(); + return response; +} + +export async function updateSubscriptionPlan(plan: z.infer, returnUrl: string): Promise { + if (!USE_BILLING) { + throw new Error("Billing is not enabled"); + } + + const customer = await getCustomer(); + const request: z.infer = { plan, returnUrl }; + const url = await libUpdateSubscriptionPlan(customer._id, request); + return url; +} + +export async function getEligibleModels(): Promise | "*"> { + if (!USE_BILLING) { + return "*"; + } + + const customer = await getCustomer(); + const response = await libGetEligibleModels(customer._id); + return response; +} \ No newline at end of file diff --git a/apps/rowboat/app/actions/copilot_actions.ts b/apps/rowboat/app/actions/copilot_actions.ts index 6162d5d4..582c5a2c 100644 --- a/apps/rowboat/app/actions/copilot_actions.ts +++ b/apps/rowboat/app/actions/copilot_actions.ts @@ -17,6 +17,8 @@ import { projectAuthCheck } from "./project_actions"; import { redisClient } from "../lib/redis"; import { fetchProjectMcpTools } from "../lib/project_tools"; import { mergeProjectTools } from "../lib/types/project_types"; +import { authorizeUserAction, logUsage } from "./billing_actions"; +import { USE_BILLING } from "../lib/feature_flags"; export async function getCopilotResponse( projectId: string, @@ -28,12 +30,21 @@ export async function getCopilotResponse( message: z.infer; rawRequest: unknown; rawResponse: unknown; -}> { +} | { billingError: string }> { await projectAuthCheck(projectId); if (!await check_query_limit(projectId)) { throw new QueryLimitError(); } + // Check billing authorization + const authResponse = await authorizeUserAction({ + type: 'copilot_request', + data: {}, + }); + if (!authResponse.success) { + return { billingError: authResponse.error || 'Billing error' }; + } + // Get MCP tools from project and merge with workflow tools const mcpTools = await fetchProjectMcpTools(projectId); @@ -45,6 +56,7 @@ export async function getCopilotResponse( // prepare request const request: z.infer = { + projectId: projectId, messages: messages.map(convertToCopilotApiMessage), workflow_schema: JSON.stringify(zodToJsonSchema(CopilotWorkflow)), current_workflow_config: JSON.stringify(copilotWorkflow), @@ -132,12 +144,25 @@ export async function getCopilotResponseStream( dataSources?: z.infer[] ): Promise<{ streamId: string; -}> { +} | { billingError: string }> { await projectAuthCheck(projectId); if (!await check_query_limit(projectId)) { throw new QueryLimitError(); } + // Check billing authorization + const authResponse = await authorizeUserAction({ + type: 'copilot_request', + data: {}, + }); + if (!authResponse.success) { + return { billingError: authResponse.error || 'Billing error' }; + } + + if (!await check_query_limit(projectId)) { + throw new QueryLimitError(); + } + // Get MCP tools from project and merge with workflow tools const mcpTools = await fetchProjectMcpTools(projectId); @@ -149,6 +174,7 @@ export async function getCopilotResponseStream( // prepare request const request: z.infer = { + projectId: projectId, messages: messages.map(convertToCopilotApiMessage), workflow_schema: JSON.stringify(zodToJsonSchema(CopilotWorkflow)), current_workflow_config: JSON.stringify(copilotWorkflow), @@ -177,12 +203,21 @@ export async function getCopilotAgentInstructions( messages: z.infer[], current_workflow_config: z.infer, agentName: string, -): Promise { +): Promise { await projectAuthCheck(projectId); if (!await check_query_limit(projectId)) { throw new QueryLimitError(); } + // Check billing authorization + const authResponse = await authorizeUserAction({ + type: 'copilot_request', + data: {}, + }); + if (!authResponse.success) { + return { billingError: authResponse.error || 'Billing error' }; + } + // Get MCP tools from project and merge with workflow tools const mcpTools = await fetchProjectMcpTools(projectId); @@ -194,6 +229,7 @@ export async function getCopilotAgentInstructions( // prepare request const request: z.infer = { + projectId: projectId, messages: messages.map(convertToCopilotApiMessage), workflow_schema: JSON.stringify(zodToJsonSchema(CopilotWorkflow)), current_workflow_config: JSON.stringify(copilotWorkflow), @@ -237,6 +273,14 @@ export async function getCopilotAgentInstructions( throw new Error(`Failed to call copilot api: ${copilotResponse.error}`); } + // log the billing usage + if (USE_BILLING) { + await logUsage({ + type: 'copilot_requests', + amount: 1, + }); + } + // return response return agent_instructions; } \ No newline at end of file diff --git a/apps/rowboat/app/actions/datasource_actions.ts b/apps/rowboat/app/actions/datasource_actions.ts index cbaced48..c9d233f6 100644 --- a/apps/rowboat/app/actions/datasource_actions.ts +++ b/apps/rowboat/app/actions/datasource_actions.ts @@ -105,6 +105,7 @@ export async function recrawlWebDataSource(projectId: string, sourceId: string) }, { $set: { status: 'pending', + billingError: undefined, lastUpdatedAt: (new Date()).toISOString(), attempts: 0, }, @@ -124,6 +125,7 @@ export async function deleteDataSource(projectId: string, sourceId: string) { }, { $set: { status: 'deleted', + billingError: undefined, lastUpdatedAt: (new Date()).toISOString(), attempts: 0, }, @@ -189,6 +191,7 @@ export async function addDocsToDataSource({ { $set: { status: 'pending', + billingError: undefined, attempts: 0, lastUpdatedAt: new Date().toISOString(), }, @@ -275,6 +278,7 @@ export async function deleteDocsFromDataSource({ }, { $set: { status: 'pending', + billingError: undefined, attempts: 0, lastUpdatedAt: new Date().toISOString(), }, diff --git a/apps/rowboat/app/actions/klavis_actions.ts b/apps/rowboat/app/actions/klavis_actions.ts index c8ea00b9..8ed4812e 100644 --- a/apps/rowboat/app/actions/klavis_actions.ts +++ b/apps/rowboat/app/actions/klavis_actions.ts @@ -7,6 +7,8 @@ import { projectsCollection } from '../lib/mongodb'; import { fetchMcpTools, toggleMcpTool } from './mcp_actions'; import { fetchMcpToolsForServer } from './mcp_actions'; import { headers } from 'next/headers'; +import { authorizeUserAction } from './billing_actions'; +import { redisClient } from '../lib/redis'; type McpServerType = z.infer; type McpToolType = z.infer; @@ -542,13 +544,34 @@ export async function enableServer( serverName: string, projectId: string, enabled: boolean -): Promise { +): Promise { try { await projectAuthCheck(projectId); console.log('[Klavis API] Toggle server request:', { serverName, projectId, enabled }); if (enabled) { + // get count of enabled hosted mcp servers for this project + const existingInstances = await listActiveServerInstances(projectId); + // billing limit check + const authResponse = await authorizeUserAction({ + type: 'enable_hosted_tool_server', + data: { + existingServerCount: existingInstances.length, + }, + }); + if (!authResponse.success) { + return { billingError: authResponse.error || 'Billing error' }; + } + + // set key in redis to indicate that a server is being enabled on this project + // the key set should only succeed if the key does not already exist + const setResult = await redisClient.set(`klavis_enabling_server:${projectId}`, 'true', { EX: 60 * 60, NX: true }); + console.log('[redis] Set result here:', setResult); + if (setResult !== 'OK') { + throw new Error("A server is already being enabled on this project"); + } + console.log(`[Klavis API] Creating server instance for ${serverName}...`); const result = await createMcpServerInstance(serverName, projectId, "Rowboat"); console.log('[Klavis API] Server instance created:', { @@ -640,6 +663,9 @@ export async function enableServer( console.error(`[Klavis API] Tool enrichment failed for ${serverName}:`, enrichError); } + // remove key from redis + await redisClient.del(`klavis_enabling_server:${projectId}`); + return result; } else { // Get active instances to find the one to delete diff --git a/apps/rowboat/app/actions/project_actions.ts b/apps/rowboat/app/actions/project_actions.ts index 832042e7..3a72737a 100644 --- a/apps/rowboat/app/actions/project_actions.ts +++ b/apps/rowboat/app/actions/project_actions.ts @@ -6,12 +6,13 @@ import { z } from 'zod'; import crypto from 'crypto'; import { revalidatePath } from "next/cache"; import { templates } from "../lib/project_templates"; -import { authCheck } from "./actions"; -import { WithStringId } from "../lib/types/types"; +import { authCheck } from "./auth_actions"; +import { User, WithStringId } from "../lib/types/types"; import { ApiKey } from "../lib/types/project_types"; import { Project } from "../lib/types/project_types"; import { USE_AUTH } from "../lib/feature_flags"; import { deleteMcpServerInstance, listActiveServerInstances } from "./klavis_actions"; +import { authorizeUserAction } from "./billing_actions"; export async function projectAuthCheck(projectId: string) { if (!USE_AUTH) { @@ -20,23 +21,27 @@ export async function projectAuthCheck(projectId: string) { const user = await authCheck(); const membership = await projectMembersCollection.findOne({ projectId, - userId: user.sub, + userId: user._id, }); if (!membership) { throw new Error('User not a member of project'); } } -async function createBaseProject(name: string, user: any) { - // Check project limits - const projectsLimit = Number(process.env.MAX_PROJECTS_PER_USER) || 0; - if (projectsLimit > 0) { - const count = await projectsCollection.countDocuments({ - createdByUserId: user.sub, - }); - if (count >= projectsLimit) { - throw new Error('You have reached your project limit. Please upgrade your plan.'); - } +async function createBaseProject(name: string, user: WithStringId>): Promise<{ id: string } | { billingError: string }> { + // fetch project count for this user + const projectCount = await projectsCollection.countDocuments({ + createdByUserId: user._id, + }); + // billing limit check + const authResponse = await authorizeUserAction({ + type: 'create_project', + data: { + existingProjectCount: projectCount, + }, + }); + if (!authResponse.success) { + return { billingError: authResponse.error || 'Billing error' }; } const projectId = crypto.randomUUID(); @@ -49,7 +54,7 @@ async function createBaseProject(name: string, user: any) { name, createdAt: (new Date()).toISOString(), lastUpdatedAt: (new Date()).toISOString(), - createdByUserId: user.sub, + createdByUserId: user._id, chatClientId, secret, nextWorkflowNumber: 1, @@ -58,7 +63,7 @@ async function createBaseProject(name: string, user: any) { // Add user to project await projectMembersCollection.insertOne({ - userId: user.sub, + userId: user._id, projectId: projectId, createdAt: (new Date()).toISOString(), lastUpdatedAt: (new Date()).toISOString(), @@ -67,15 +72,20 @@ async function createBaseProject(name: string, user: any) { // Add first api key await createApiKey(projectId); - return projectId; + return { id: projectId }; } -export async function createProject(formData: FormData) { +export async function createProject(formData: FormData): Promise<{ id: string } | { billingError: string }> { const user = await authCheck(); const name = formData.get('name') as string; const templateKey = formData.get('template') as string; - const projectId = await createBaseProject(name, user); + const response = await createBaseProject(name, user); + if ('billingError' in response) { + return response; + } + + const projectId = response.id; // Add first workflow version with specified template const { agents, prompts, tools, startAgent } = templates[templateKey]; @@ -90,7 +100,7 @@ export async function createProject(formData: FormData) { name: `Version 1`, }); - redirect(`/projects/${projectId}/workflow`); + return { id: projectId }; } export async function getProjectConfig(projectId: string): Promise>> { @@ -107,7 +117,7 @@ export async function getProjectConfig(projectId: string): Promise[]> { const user = await authCheck(); const memberships = await projectMembersCollection.find({ - userId: user.sub, + userId: user._id, }).toArray(); const projectIds = memberships.map((m) => m.projectId); const projects = await projectsCollection.find({ @@ -271,11 +281,16 @@ export async function deleteProject(projectId: string) { redirect('/projects'); } -export async function createProjectFromPrompt(formData: FormData) { +export async function createProjectFromPrompt(formData: FormData): Promise<{ id: string } | { billingError: string }> { const user = await authCheck(); const name = formData.get('name') as string; - - const projectId = await createBaseProject(name, user); + + const response = await createBaseProject(name, user); + if ('billingError' in response) { + return response; + } + + const projectId = response.id; // Add first workflow version with default template const { agents, prompts, tools, startAgent } = templates['default']; diff --git a/apps/rowboat/app/api/copilot-stream-response/[streamId]/route.ts b/apps/rowboat/app/api/copilot-stream-response/[streamId]/route.ts index 43c4bfa7..5bdd10a5 100644 --- a/apps/rowboat/app/api/copilot-stream-response/[streamId]/route.ts +++ b/apps/rowboat/app/api/copilot-stream-response/[streamId]/route.ts @@ -1,4 +1,7 @@ +import { getCustomerIdForProject, logUsage } from "@/app/lib/billing"; +import { USE_BILLING } from "@/app/lib/feature_flags"; import { redisClient } from "@/app/lib/redis"; +import { CopilotAPIRequest } from "@/app/lib/types/copilot_types"; export async function GET(request: Request, { params }: { params: { streamId: string } }) { // get the payload from redis @@ -7,6 +10,15 @@ export async function GET(request: Request, { params }: { params: { streamId: st return new Response("Stream not found", { status: 404 }); } + // parse the payload + const parsedPayload = CopilotAPIRequest.parse(JSON.parse(payload)); + + // fetch billing customer id + let billingCustomerId: string | null = null; + if (USE_BILLING) { + billingCustomerId = await getCustomerIdForProject(parsedPayload.projectId); + } + // Fetch the upstream SSE stream. const upstreamResponse = await fetch(`${process.env.COPILOT_API_URL}/chat_stream`, { method: 'POST', @@ -36,6 +48,18 @@ export async function GET(request: Request, { params }: { params: { streamId: st controller.enqueue(value); } controller.close(); + + // increment copilot request count in billing + if (USE_BILLING && billingCustomerId) { + try { + await logUsage(billingCustomerId, { + type: "copilot_requests", + amount: 1, + }); + } catch (error) { + console.error("Error logging usage", error); + } + } } catch (error) { controller.error(error); } diff --git a/apps/rowboat/app/api/stream-response/[streamId]/route.ts b/apps/rowboat/app/api/stream-response/[streamId]/route.ts index 06b0fc90..f42cd328 100644 --- a/apps/rowboat/app/api/stream-response/[streamId]/route.ts +++ b/apps/rowboat/app/api/stream-response/[streamId]/route.ts @@ -1,4 +1,8 @@ +import { getCustomerIdForProject, logUsage } from "@/app/lib/billing"; +import { USE_BILLING } from "@/app/lib/feature_flags"; import { redisClient } from "@/app/lib/redis"; +import { AgenticAPIChatMessage, AgenticAPIChatRequest, convertFromAgenticAPIChatMessages } from "@/app/lib/types/agents_api_types"; +import { createParser, type EventSourceMessage } from 'eventsource-parser'; export async function GET(request: Request, { params }: { params: { streamId: string } }) { // get the payload from redis @@ -7,6 +11,15 @@ export async function GET(request: Request, { params }: { params: { streamId: st return new Response("Stream not found", { status: 404 }); } + // parse the payload + const parsedPayload = AgenticAPIChatRequest.parse(JSON.parse(payload)); + + // fetch billing customer id + let billingCustomerId: string | null = null; + if (USE_BILLING) { + billingCustomerId = await getCustomerIdForProject(parsedPayload.projectId); + } + // Fetch the upstream SSE stream. const upstreamResponse = await fetch(`${process.env.AGENTS_API_URL}/chat_stream`, { method: 'POST', @@ -24,19 +37,63 @@ export async function GET(request: Request, { params }: { params: { streamId: st } const reader = upstreamResponse.body.getReader(); + const encoder = new TextEncoder(); const stream = new ReadableStream({ async start(controller) { + let messageCount = 0; + + function emitEvent(event: EventSourceMessage) { + // Re-emit the event in SSE format + let eventString = ''; + if (event.id) eventString += `id: ${event.id}\n`; + if (event.event) eventString += `event: ${event.event}\n`; + if (event.data) eventString += `data: ${event.data}\n`; + eventString += '\n'; + + controller.enqueue(encoder.encode(eventString)); + } + + const parser = createParser({ + onEvent(event: EventSourceMessage) { + if (event.event !== 'message') { + emitEvent(event); + return; + } + + // Parse message + const data = JSON.parse(event.data); + const msg = AgenticAPIChatMessage.parse(data); + const parsedMsg = convertFromAgenticAPIChatMessages([msg])[0]; + + // increment the message count if this is an assistant message + if (parsedMsg.role === 'assistant') { + messageCount++; + } + + // emit the event + emitEvent(event); + } + }); + try { - // Read from the upstream stream continuously. while (true) { const { done, value } = await reader.read(); if (done) break; - // Immediately enqueue each received chunk. - controller.enqueue(value); + + // Feed the chunk to the parser + parser.feed(new TextDecoder().decode(value)); } controller.close(); + + if (USE_BILLING && billingCustomerId) { + await logUsage(billingCustomerId, { + type: "agent_messages", + amount: messageCount, + }) + } } catch (error) { + console.error('Error processing stream:', error); controller.error(error); } }, diff --git a/apps/rowboat/app/api/v1/[projectId]/chat/route.ts b/apps/rowboat/app/api/v1/[projectId]/chat/route.ts index a07beae6..2b1c3029 100644 --- a/apps/rowboat/app/api/v1/[projectId]/chat/route.ts +++ b/apps/rowboat/app/api/v1/[projectId]/chat/route.ts @@ -10,6 +10,8 @@ import { check_query_limit } from "../../../../lib/rate_limiting"; import { PrefixLogger } from "../../../../lib/utils"; import { TestProfile } from "@/app/lib/types/testing_types"; import { fetchProjectMcpTools } from "@/app/lib/project_tools"; +import { authorize, getCustomerIdForProject, logUsage } from "@/app/lib/billing"; +import { USE_BILLING } from "@/app/lib/feature_flags"; // get next turn / agent response export async function POST( @@ -29,6 +31,12 @@ export async function POST( } return await authCheck(projectId, req, async () => { + // fetch billing customer id + let billingCustomerId: string | null = null; + if (USE_BILLING) { + billingCustomerId = await getCustomerIdForProject(projectId); + } + // parse and validate the request body let body; try { @@ -74,6 +82,23 @@ export async function POST( return Response.json({ error: "Workflow not found" }, { status: 404 }); } + // check billing authorization + if (USE_BILLING && billingCustomerId) { + const agentModels = workflow.agents.reduce((acc, agent) => { + acc.push(agent.model); + return acc; + }, [] as string[]); + const response = await authorize(billingCustomerId, { + type: 'agent_response', + data: { + agentModels, + }, + }); + if (!response.success) { + return Response.json({ error: response.error || 'Billing error' }, { status: 402 }); + } + } + // if test profile is provided in the request, use it let testProfile: z.infer | null = null; if (result.data.testProfileId) { @@ -112,6 +137,15 @@ export async function POST( const newMessages = convertFromAgenticApiToApiMessages(agenticMessages); const newState = state; + // log billing usage + if (USE_BILLING && billingCustomerId) { + const agentMessageCount = newMessages.filter(m => m.role === 'assistant').length; + await logUsage(billingCustomerId, { + type: 'agent_messages', + amount: agentMessageCount, + }); + } + const responseBody: z.infer = { messages: newMessages, state: newState, diff --git a/apps/rowboat/app/api/widget/v1/chats/[chatId]/turn/route.ts b/apps/rowboat/app/api/widget/v1/chats/[chatId]/turn/route.ts index 5ea47a3b..620a9e07 100644 --- a/apps/rowboat/app/api/widget/v1/chats/[chatId]/turn/route.ts +++ b/apps/rowboat/app/api/widget/v1/chats/[chatId]/turn/route.ts @@ -12,6 +12,8 @@ import { getAgenticApiResponse } from "../../../../../../lib/utils"; import { check_query_limit } from "../../../../../../lib/rate_limiting"; import { PrefixLogger } from "../../../../../../lib/utils"; import { fetchProjectMcpTools } from "@/app/lib/project_tools"; +import { authorize, getCustomerIdForProject, logUsage } from "@/app/lib/billing"; +import { USE_BILLING } from "@/app/lib/feature_flags"; // get next turn / agent response export async function POST( @@ -24,6 +26,12 @@ export async function POST( logger.log(`Processing turn request for chat ${chatId}`); + // fetch billing customer id + let billingCustomerId: string | null = null; + if (USE_BILLING) { + billingCustomerId = await getCustomerIdForProject(session.projectId); + } + // check query limit if (!await check_query_limit(session.projectId)) { logger.log(`Query limit exceeded for project ${session.projectId}`); @@ -93,6 +101,23 @@ export async function POST( throw new Error("Workflow not found"); } + // check billing authorization + if (USE_BILLING && billingCustomerId) { + const agentModels = workflow.agents.reduce((acc, agent) => { + acc.push(agent.model); + return acc; + }, [] as string[]); + const response = await authorize(billingCustomerId, { + type: 'agent_response', + data: { + agentModels, + }, + }); + if (!response.success) { + return Response.json({ error: response.error || 'Billing error' }, { status: 402 }); + } + } + // get assistant response const { agents, tools, prompts, startAgent } = convertWorkflowToAgenticAPI(workflow, projectTools); const unsavedMessages: z.infer[] = [userMessage]; @@ -132,6 +157,15 @@ export async function POST( await chatMessagesCollection.insertMany(unsavedMessages); await chatsCollection.updateOne({ _id: new ObjectId(chatId) }, { $set: { agenticState: state } }); + // log billing usage + if (USE_BILLING && billingCustomerId) { + const agentMessageCount = convertedMessages.filter(m => m.role === 'assistant').length; + await logUsage(billingCustomerId, { + type: 'agent_messages', + amount: agentMessageCount, + }); + } + logger.log(`Turn processing completed successfully`); const lastMessage = unsavedMessages[unsavedMessages.length - 1] as WithId>; return Response.json({ diff --git a/apps/rowboat/app/billing/app.tsx b/apps/rowboat/app/billing/app.tsx new file mode 100644 index 00000000..b79c3eda --- /dev/null +++ b/apps/rowboat/app/billing/app.tsx @@ -0,0 +1,148 @@ +'use client'; + +import { Progress, Badge } from "@heroui/react"; +import { Button } from "@/components/ui/button"; +import { Label } from "@/app/lib/components/label"; +import { Customer, UsageResponse, UsageType } from "@/app/lib/types/billing_types"; +import { z } from "zod"; +import { tokens } from "@/app/styles/design-tokens"; +import { SectionHeading } from "@/components/ui/section-heading"; +import { HorizontalDivider } from "@/components/ui/horizontal-divider"; +import { WithStringId } from "@/app/lib/types/types"; +import clsx from 'clsx'; +import { getCustomerPortalUrl } from "../actions/billing_actions"; + +const planDetails = { + free: { + name: "Free Plan", + color: "default" + }, + starter: { + name: "Starter Plan", + color: "primary" + }, + pro: { + name: "Pro Plan", + color: "secondary" + } +}; + +interface BillingPageProps { + customer: WithStringId>; + usage: z.infer; +} + +export function BillingPage({ customer, usage }: BillingPageProps) { + const plan = customer.subscriptionPlan || "free"; + const isActive = customer.subscriptionActive || false; + const planInfo = planDetails[plan]; + + async function handleManageSubscription() { + const returnUrl = new URL('/billing/callback', window.location.origin); + returnUrl.searchParams.set('redirect', window.location.href); + const url = await getCustomerPortalUrl(returnUrl.toString()); + window.location.href = url; + } + + return ( +
+
+

+ Billing +

+
+ + {/* Subscription Status Panel */} +
+
+ + Current Plan + +
+ +
+
+
+
+

+ {planInfo.name} +

+ + {isActive ? "Active" : "Inactive"} + +
+
+
+ +
+
+
+
+ + {/* Usage Metrics Panel */} +
+
+ + Usage Metrics + +
+ +
+ {Object.entries(usage.usage).map(([type, { usage: used, total }]) => { + const usageType = type as z.infer; + const percentage = Math.min((used / total) * 100, 100); + const isOverLimit = used > total; + + return ( +
+
+
+
+ {isOverLimit && ( + + Over Limit + + )} +
+ +
+ ); + })} +
+
+
+ ); +} \ No newline at end of file diff --git a/apps/rowboat/app/billing/callback/page.tsx b/apps/rowboat/app/billing/callback/page.tsx new file mode 100644 index 00000000..b5b3bf2a --- /dev/null +++ b/apps/rowboat/app/billing/callback/page.tsx @@ -0,0 +1,18 @@ +import { syncWithStripe } from "@/app/lib/billing"; +import { requireBillingCustomer } from '@/app/lib/billing'; +import { redirect } from "next/navigation"; + +export const dynamic = 'force-dynamic'; + +export default async function Page({ + searchParams, +}: { + searchParams: { + redirect: string; + } +}) { + const customer = await requireBillingCustomer(); + await syncWithStripe(customer._id); + const redirectUrl = searchParams.redirect as string; + redirect(redirectUrl || '/projects'); +} \ No newline at end of file diff --git a/apps/rowboat/app/billing/layout.tsx b/apps/rowboat/app/billing/layout.tsx new file mode 100644 index 00000000..3547f9ea --- /dev/null +++ b/apps/rowboat/app/billing/layout.tsx @@ -0,0 +1,13 @@ +import AppLayout from '../projects/layout/components/app-layout'; + +export default function Layout({ + children, +}: Readonly<{ + children: React.ReactNode; +}>) { + return ( + + {children} + + ); +} \ No newline at end of file diff --git a/apps/rowboat/app/billing/page.tsx b/apps/rowboat/app/billing/page.tsx new file mode 100644 index 00000000..974c73f9 --- /dev/null +++ b/apps/rowboat/app/billing/page.tsx @@ -0,0 +1,17 @@ +import { requireBillingCustomer } from '../lib/billing'; +import { BillingPage } from './app'; +import { getUsage } from '../lib/billing'; +import { redirect } from 'next/navigation'; +import { USE_BILLING } from '../lib/feature_flags'; + +export const dynamic = 'force-dynamic'; + +export default async function Page() { + if (!USE_BILLING) { + redirect('/projects'); + } + + const customer = await requireBillingCustomer(); + const usage = await getUsage(customer._id); + return ; +} \ No newline at end of file diff --git a/apps/rowboat/app/lib/auth.ts b/apps/rowboat/app/lib/auth.ts new file mode 100644 index 00000000..009b8e3c --- /dev/null +++ b/apps/rowboat/app/lib/auth.ts @@ -0,0 +1,116 @@ +import { z } from "zod"; +import { Claims } from "@auth0/nextjs-auth0"; +import { ObjectId } from "mongodb"; +import { usersCollection, projectsCollection, projectMembersCollection } from "./mongodb"; +import { getSession } from "@auth0/nextjs-auth0"; +import { User, WithStringId } from "./types/types"; +import { USE_AUTH } from "./feature_flags"; +import { redirect } from "next/navigation"; + +export const GUEST_SESSION: Claims = { + email: "guest@rowboatlabs.com", + email_verified: true, + sub: "guest_user", +} + +export const GUEST_DB_USER: WithStringId> = { + _id: "guest_user", + auth0Id: "guest_user", + name: "Guest", + email: "guest@rowboatlabs.com", + createdAt: new Date().toISOString(), + updatedAt: new Date().toISOString(), +} + +/** + * This function should be used as an initial check in server page components to ensure + * the user is authenticated. It will: + * 1. Check for a valid user session + * 2. Redirect to login if no session exists + * 3. Return the authenticated user + * + * Usage in server components: + * ```ts + * const user = await requireAuth(); + * ``` + */ +export async function requireAuth(): Promise>> { + if (!USE_AUTH) { + return GUEST_DB_USER; + } + + const { user } = await getSession() || {}; + if (!user) { + redirect('/api/auth/login'); + } + + // fetch db user + let dbUser = await getUserFromSessionId(user.sub); + + // if db user does not exist, create one + if (!dbUser) { + // create user record + const doc = { + _id: new ObjectId(), + auth0Id: user.sub, + createdAt: new Date().toISOString(), + updatedAt: new Date().toISOString(), + email: user.email, + }; + console.log(`creating new user id ${doc._id.toString()} for session id ${user.sub}`); + await usersCollection.insertOne(doc); + + // since auth feature was rolled out later, + // set all project authors to new user id instead + // of user.sub + await updateProjectRefs(user.sub, doc._id.toString()); + + dbUser = { + ...doc, + _id: doc._id.toString(), + }; + } + + const { _id, ...rest } = dbUser; + return { + ...rest, + _id: _id.toString(), + }; +} + +async function updateProjectRefs(sessionUserId: string, dbUserId: string) { + await projectsCollection.updateMany({ + createdByUserId: sessionUserId + }, { + $set: { + createdByUserId: dbUserId, + lastUpdatedAt: new Date().toISOString(), + } + }); + + await projectMembersCollection.updateMany({ + userId: sessionUserId + }, { + $set: { + userId: dbUserId, + } + }); +} + +export async function getUserFromSessionId(sessionUserId: string): Promise> | null> { + if (!USE_AUTH) { + return GUEST_DB_USER; + } + + let dbUser = await usersCollection.findOne({ + auth0Id: sessionUserId + }); + if (!dbUser) { + return null; + } + const { _id, ...rest } = dbUser; + return { + ...rest, + _id: _id.toString(), + }; +} \ No newline at end of file diff --git a/apps/rowboat/app/lib/billing.ts b/apps/rowboat/app/lib/billing.ts new file mode 100644 index 00000000..c23a0706 --- /dev/null +++ b/apps/rowboat/app/lib/billing.ts @@ -0,0 +1,304 @@ +import { WithStringId } from './types/types'; +import { z } from 'zod'; +import { Customer, AuthorizeRequest, AuthorizeResponse, LogUsageRequest, UsageResponse, CustomerPortalSessionResponse, PricesResponse, UpdateSubscriptionPlanRequest, UpdateSubscriptionPlanResponse, ModelsResponse } from './types/billing_types'; +import { ObjectId } from 'mongodb'; +import { projectsCollection, usersCollection } from './mongodb'; +import { getSession } from '@auth0/nextjs-auth0'; +import { redirect } from 'next/navigation'; +import { getUserFromSessionId, requireAuth } from './auth'; +import { USE_BILLING } from './feature_flags'; + +const BILLING_API_URL = process.env.BILLING_API_URL || 'http://billing'; +const BILLING_API_KEY = process.env.BILLING_API_KEY || 'test'; + +const GUEST_BILLING_CUSTOMER = { + _id: "guest-user", + userId: "guest-user", + name: "Guest", + email: "guest@rowboatlabs.com", + stripeCustomerId: "guest", + subscriptionPlan: "free" as const, + subscriptionActive: true, + createdAt: new Date().toISOString(), + updatedAt: new Date().toISOString(), +}; + +export class BillingError extends Error { + constructor(message: string) { + super(message); + this.name = 'BillingError'; + } +} + +export async function getCustomerIdForProject(projectId: string): Promise { + const project = await projectsCollection.findOne({ _id: projectId }); + if (!project) { + throw new Error("Project not found"); + } + const user = await usersCollection.findOne({ _id: new ObjectId(project.createdByUserId) }); + if (!user) { + throw new Error("User not found"); + } + if (!user.billingCustomerId) { + throw new Error("User has no billing customer id"); + } + return user.billingCustomerId; +} + +export async function getBillingCustomer(id: string): Promise> | null> { + const response = await fetch(`${BILLING_API_URL}/api/customers/${id}`, { + method: 'GET', + headers: { + 'Authorization': `Bearer ${BILLING_API_KEY}`, + 'Content-Type': 'application/json' + } + }); + if (!response.ok) { + throw new Error(`Failed to fetch billing customer: ${response.status} ${response.statusText} ${await response.text()}`); + } + const json = await response.json(); + const parseResult = Customer.safeParse(json); + if (!parseResult.success) { + throw new Error(`Failed to parse billing customer: ${JSON.stringify(parseResult.error)}`); + } + return parseResult.data; +} + +async function createBillingCustomer(userId: string, email: string): Promise>> { + const response = await fetch(`${BILLING_API_URL}/api/customers`, { + method: 'POST', + headers: { + 'Authorization': `Bearer ${BILLING_API_KEY}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ userId, email }) + }); + if (!response.ok) { + throw new Error(`Failed to create billing customer: ${response.status} ${response.statusText} ${await response.text()}`); + } + const json = await response.json(); + const parseResult = Customer.safeParse(json); + if (!parseResult.success) { + throw new Error(`Failed to parse billing customer: ${JSON.stringify(parseResult.error)}`); + } + return parseResult.data as z.infer; +} + +export async function syncWithStripe(customerId: string): Promise { + const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/sync-with-stripe`, { + method: 'POST', + headers: { + 'Authorization': `Bearer ${BILLING_API_KEY}`, + 'Content-Type': 'application/json' + } + }); + if (!response.ok) { + throw new Error(`Failed to sync with stripe: ${response.status} ${response.statusText} ${await response.text()}`); + } +} + +export async function authorize(customerId: string, request: z.infer): Promise> { + const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/authorize`, { + method: 'POST', + headers: { + 'Authorization': `Bearer ${BILLING_API_KEY}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify(request) + }); + if (!response.ok) { + throw new Error(`Failed to authorize billing: ${response.status} ${response.statusText} ${await response.text()}`); + } + const json = await response.json(); + const parseResult = AuthorizeResponse.safeParse(json); + if (!parseResult.success) { + throw new Error(`Failed to parse authorize billing response: ${JSON.stringify(parseResult.error)}`); + } + return parseResult.data as z.infer; +} + +export async function logUsage(customerId: string, request: z.infer) { + const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/log-usage`, { + method: 'POST', + headers: { + 'Authorization': `Bearer ${BILLING_API_KEY}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify(request) + }); + if (!response.ok) { + throw new Error(`Failed to log usage: ${response.status} ${response.statusText} ${await response.text()}`); + } +} + +export async function getUsage(customerId: string): Promise> { + const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/usage`, { + method: 'GET', + headers: { + 'Authorization': `Bearer ${BILLING_API_KEY}`, + 'Content-Type': 'application/json' + } + }); + if (!response.ok) { + throw new Error(`Failed to get usage: ${response.status} ${response.statusText} ${await response.text()}`); + } + const json = await response.json(); + const parseResult = UsageResponse.safeParse(json); + if (!parseResult.success) { + throw new Error(`Failed to parse usage response: ${JSON.stringify(parseResult.error)}`); + } + return parseResult.data as z.infer; +} + +export async function createCustomerPortalSession(customerId: string, returnUrl: string): Promise { + const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/customer-portal-session`, { + method: 'POST', + headers: { + 'Authorization': `Bearer ${BILLING_API_KEY}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ returnUrl }) + }); + if (!response.ok) { + throw new Error(`Failed to get customer portal url: ${response.status} ${response.statusText} ${await response.text()}`); + } + const json = await response.json(); + const parseResult = CustomerPortalSessionResponse.safeParse(json); + if (!parseResult.success) { + throw new Error(`Failed to parse customer portal session response: ${JSON.stringify(parseResult.error)}`); + } + return parseResult.data.url; +} + +export async function getPrices(): Promise> { + const response = await fetch(`${BILLING_API_URL}/api/prices`, { + method: 'GET', + headers: { + 'Authorization': `Bearer ${BILLING_API_KEY}`, + 'Content-Type': 'application/json' + } + }); + if (!response.ok) { + throw new Error(`Failed to get prices: ${response.status} ${response.statusText} ${await response.text()}`); + } + const json = await response.json(); + const parseResult = PricesResponse.safeParse(json); + if (!parseResult.success) { + throw new Error(`Failed to parse prices response: ${JSON.stringify(parseResult.error)}`); + } + return parseResult.data as z.infer; +} + +export async function updateSubscriptionPlan(customerId: string, request: z.infer): Promise { + const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/update-subscription-plan`, { + method: 'POST', + headers: { + 'Authorization': `Bearer ${BILLING_API_KEY}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify(request) + }); + if (!response.ok) { + throw new Error(`Failed to update subscription plan: ${response.status} ${response.statusText} ${await response.text()}`); + } + const json = await response.json(); + const parseResult = UpdateSubscriptionPlanResponse.safeParse(json); + if (!parseResult.success) { + throw new Error(`Failed to parse update subscription plan response: ${JSON.stringify(parseResult.error)}`); + } + return parseResult.data.url; +} + +export async function getEligibleModels(customerId: string): Promise> { + const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/models`, { + method: 'GET', + headers: { + 'Authorization': `Bearer ${BILLING_API_KEY}`, + 'Content-Type': 'application/json' + } + }); + if (!response.ok) { + throw new Error(`Failed to get eligible models: ${response.status} ${response.statusText} ${await response.text()}`); + } + const json = await response.json(); + const parseResult = ModelsResponse.safeParse(json); + if (!parseResult.success) { + throw new Error(`Failed to parse eligible models response: ${JSON.stringify(parseResult.error)}`); + } + return parseResult.data as z.infer; +} + +/** + * This function should be used as an initial check in server page components to ensure + * the user has a valid billing customer record. It will: + * 1. Return a guest customer if billing is disabled + * 2. Verify user authentication + * 3. Create/update the user record if needed + * 4. Redirect to onboarding if no billing customer exists + * + * Usage in server components: + * ```ts + * const billingCustomer = await requireBillingCustomer(); + * ``` + */ +export async function requireBillingCustomer(): Promise>> { + const user = await requireAuth(); + + if (!USE_BILLING) { + return { + ...GUEST_BILLING_CUSTOMER, + userId: user._id, + }; + } + + // if user does not have an email, redirect to onboarding + if (!user.email) { + redirect('/onboarding'); + } + + // fetch or create customer + let customer: WithStringId> | null; + if (user.billingCustomerId) { + customer = await getBillingCustomer(user.billingCustomerId); + } else { + customer = await createBillingCustomer(user._id, user.email); + console.log("created billing customer", JSON.stringify({ userId: user._id, customer })); + + // update customer id in db + await usersCollection.updateOne({ + _id: new ObjectId(user._id), + }, { + $set: { + billingCustomerId: customer._id, + updatedAt: new Date().toISOString(), + } + }); + } + if (!customer) { + throw new Error("Failed to fetch or create billing customer"); + } + + return customer; +} + +/** + * This function should be used in server page components to ensure the user has an active + * billing subscription. It will: + * 1. Return a guest customer if billing is disabled + * 2. Verify the user has a valid billing customer record + * 3. Redirect to checkout if the subscription is not active + * + * Usage in server components: + * ```ts + * const billingCustomer = await requireActiveBillingSubscription(); + * ``` + */ +export async function requireActiveBillingSubscription(): Promise>> { + const billingCustomer = await requireBillingCustomer(); + + if (USE_BILLING && !billingCustomer?.subscriptionActive) { + redirect('/billing/checkout'); + } + return billingCustomer; +} + diff --git a/apps/rowboat/app/lib/components/user_button.tsx b/apps/rowboat/app/lib/components/user_button.tsx index a03b5155..ee3846ea 100644 --- a/apps/rowboat/app/lib/components/user_button.tsx +++ b/apps/rowboat/app/lib/components/user_button.tsx @@ -2,8 +2,9 @@ import { useUser } from '@auth0/nextjs-auth0/client'; import { Avatar, Dropdown, DropdownItem, DropdownSection, DropdownTrigger, DropdownMenu } from "@heroui/react"; import { useRouter } from 'next/navigation'; +import Link from 'next/link'; -export function UserButton() { +export function UserButton({ useBilling }: { useBilling?: boolean }) { const router = useRouter(); const { user } = useUser(); if (!user) { @@ -25,9 +26,19 @@ export function UserButton() { if (key === 'logout') { router.push('/api/auth/logout'); } + if (key === 'billing') { + router.push('/billing'); + } }} > + {useBilling ? ( + + Billing + + ) : ( + <> + )} Logout diff --git a/apps/rowboat/app/lib/feature_flags.ts b/apps/rowboat/app/lib/feature_flags.ts index b2f82535..8a530ee2 100644 --- a/apps/rowboat/app/lib/feature_flags.ts +++ b/apps/rowboat/app/lib/feature_flags.ts @@ -5,6 +5,7 @@ export const USE_CHAT_WIDGET = process.env.USE_CHAT_WIDGET === 'true'; export const USE_AUTH = process.env.USE_AUTH === 'true'; export const USE_RAG_S3_UPLOADS = process.env.USE_RAG_S3_UPLOADS === 'true'; export const USE_GEMINI_FILE_PARSING = process.env.USE_GEMINI_FILE_PARSING === 'true'; +export const USE_BILLING = process.env.USE_BILLING === 'true'; // Hardcoded flags export const USE_MULTIPLE_PROJECTS = true; diff --git a/apps/rowboat/app/lib/mongodb.ts b/apps/rowboat/app/lib/mongodb.ts index b6e0ca4e..564a3a44 100644 --- a/apps/rowboat/app/lib/mongodb.ts +++ b/apps/rowboat/app/lib/mongodb.ts @@ -1,5 +1,5 @@ import { MongoClient } from "mongodb"; -import { Webpage } from "./types/types"; +import { User, Webpage } from "./types/types"; import { Workflow } from "./types/workflow_types"; import { ApiKey } from "./types/project_types"; import { ProjectMember } from "./types/project_types"; @@ -31,6 +31,7 @@ export const testResultsCollection = db.collection>(" export const chatsCollection = db.collection>("chats"); export const chatMessagesCollection = db.collection>("chat_messages"); export const twilioConfigsCollection = db.collection>("twilio_configs"); +export const usersCollection = db.collection>("users"); // Create indexes twilioConfigsCollection.createIndexes([ diff --git a/apps/rowboat/app/lib/types/billing_types.ts b/apps/rowboat/app/lib/types/billing_types.ts new file mode 100644 index 00000000..21e1a37e --- /dev/null +++ b/apps/rowboat/app/lib/types/billing_types.ts @@ -0,0 +1,100 @@ +import { z } from "zod"; + +export const SubscriptionPlan = z.enum(["free", "starter", "pro"]); + +export const UsageType = z.enum([ + "copilot_requests", + "agent_messages", + "rag_tokens", +]); + +export const Customer = z.object({ + _id: z.string(), + userId: z.string(), + email: z.string(), + stripeCustomerId: z.string(), + subscriptionPlan: SubscriptionPlan.optional(), + subscriptionActive: z.boolean().optional(), + createdAt: z.string().datetime(), + updatedAt: z.string().datetime(), + subscriptionPlanUpdatedAt: z.string().datetime().optional(), + usage: z.record(UsageType, z.number()).optional(), + usageUpdatedAt: z.string().datetime().optional(), +}); + +export const LogUsageRequest = z.object({ + type: UsageType, + amount: z.number().int().positive(), +}); + +export const AuthorizeRequest = z.discriminatedUnion("type", [ + z.object({ + "type": z.literal("create_project"), + "data": z.object({ + "existingProjectCount": z.number(), + }), + }), + z.object({ + "type": z.literal("enable_hosted_tool_server"), + "data": z.object({ + "existingServerCount": z.number(), + }), + }), + z.object({ + "type": z.literal("process_rag"), + "data": z.object({}), + }), + z.object({ + "type": z.literal("copilot_request"), + "data": z.object({}), + }), + z.object({ + "type": z.literal("agent_response"), + "data": z.object({ + agentModels: z.array(z.string()), + }), + }), +]); + +export const AuthorizeResponse = z.object({ + success: z.boolean(), + error: z.string().optional(), +}); + +export const UsageResponse = z.object({ + usage: z.record(UsageType, z.object({ + usage: z.number(), + total: z.number(), + })), +}); + +export const CustomerPortalSessionRequest = z.object({ + returnUrl: z.string(), +}); + +export const CustomerPortalSessionResponse = z.object({ + url: z.string(), +}); + +export const PricesResponse = z.object({ + prices: z.record(SubscriptionPlan, z.object({ + monthly: z.number(), + })), +}); + +export const UpdateSubscriptionPlanRequest = z.object({ + plan: SubscriptionPlan, + returnUrl: z.string(), +}); + +export const UpdateSubscriptionPlanResponse = z.object({ + url: z.string(), +}); + +export const ModelsResponse = z.object({ + agentModels: z.array(z.object({ + name: z.string(), + eligible: z.boolean(), + plan: SubscriptionPlan, + })), +}); \ No newline at end of file diff --git a/apps/rowboat/app/lib/types/copilot_types.ts b/apps/rowboat/app/lib/types/copilot_types.ts index ea38e00c..22e98f42 100644 --- a/apps/rowboat/app/lib/types/copilot_types.ts +++ b/apps/rowboat/app/lib/types/copilot_types.ts @@ -104,6 +104,7 @@ export const CopilotApiChatContext = z.union([ }), ]); export const CopilotAPIRequest = z.object({ + projectId: z.string(), messages: z.array(CopilotApiMessage), workflow_schema: z.string(), current_workflow_config: z.string(), diff --git a/apps/rowboat/app/lib/types/datasource_types.ts b/apps/rowboat/app/lib/types/datasource_types.ts index fa207daf..9fe72d6f 100644 --- a/apps/rowboat/app/lib/types/datasource_types.ts +++ b/apps/rowboat/app/lib/types/datasource_types.ts @@ -13,6 +13,7 @@ export const DataSource = z.object({ ]).optional(), version: z.number(), error: z.string().optional(), + billingError: z.string().optional(), createdAt: z.string().datetime(), lastUpdatedAt: z.string().datetime().optional(), attempts: z.number(), diff --git a/apps/rowboat/app/lib/types/types.ts b/apps/rowboat/app/lib/types/types.ts index a30c1a04..b3f228a9 100644 --- a/apps/rowboat/app/lib/types/types.ts +++ b/apps/rowboat/app/lib/types/types.ts @@ -76,6 +76,15 @@ export const McpServerResponse = z.object({ error: z.string().nullable(), }); +export const User = z.object({ + auth0Id: z.string(), + billingCustomerId: z.string().optional(), + name: z.string().optional(), + email: z.string().optional(), + createdAt: z.string().datetime(), + updatedAt: z.string().datetime(), +}); + export const PlaygroundChat = z.object({ createdAt: z.string().datetime(), projectId: z.string(), diff --git a/apps/rowboat/app/onboarding/app.tsx b/apps/rowboat/app/onboarding/app.tsx new file mode 100644 index 00000000..69216eba --- /dev/null +++ b/apps/rowboat/app/onboarding/app.tsx @@ -0,0 +1,90 @@ +"use client"; +import { useState } from "react"; +import { Input } from "@/components/ui/input"; +import { FormStatusButton } from "@/app/lib/components/form-status-button"; +import { useRouter } from "next/navigation"; +import { updateUserEmail } from "../actions/auth_actions"; +import { tokens } from "@/app/styles/design-tokens"; +import { SectionHeading } from "@/components/ui/section-heading"; +import { HorizontalDivider } from "@/components/ui/horizontal-divider"; +import clsx from 'clsx'; + +export default function App() { + const router = useRouter(); + const [email, setEmail] = useState(""); + const [submitted, setSubmitted] = useState(false); + const [error, setError] = useState(""); + + async function handleSubmit(e: React.FormEvent) { + e.preventDefault(); + setError(""); + if (!email.trim()) { + setError("Please enter your email."); + return; + } + setSubmitted(true); + + try { + await updateUserEmail(email); + router.push('/projects'); + } catch (error) { + setError("Failed to update email."); + } + } + + return ( +
+
+

+ Complete your profile +

+
+ +
+
+ + Complete your profile + +
+ +
+
+ setEmail(e.target.value)} + placeholder="you@example.com" + required + /> + {error && ( +
+ {error} +
+ )} +
+
+ +
+
+
+
+ ); +} diff --git a/apps/rowboat/app/onboarding/layout.tsx b/apps/rowboat/app/onboarding/layout.tsx new file mode 100644 index 00000000..3547f9ea --- /dev/null +++ b/apps/rowboat/app/onboarding/layout.tsx @@ -0,0 +1,13 @@ +import AppLayout from '../projects/layout/components/app-layout'; + +export default function Layout({ + children, +}: Readonly<{ + children: React.ReactNode; +}>) { + return ( + + {children} + + ); +} \ No newline at end of file diff --git a/apps/rowboat/app/onboarding/page.tsx b/apps/rowboat/app/onboarding/page.tsx new file mode 100644 index 00000000..1a108ee6 --- /dev/null +++ b/apps/rowboat/app/onboarding/page.tsx @@ -0,0 +1,14 @@ +import { redirect } from "next/navigation"; +import App from "./app"; +import { requireAuth } from "../lib/auth"; +import { USE_AUTH } from "../lib/feature_flags"; + +export const dynamic = 'force-dynamic'; + +export default async function Page() { + if (!USE_AUTH) { + redirect('/projects'); + } + await requireAuth(); + return ; +} \ No newline at end of file diff --git a/apps/rowboat/app/projects/[projectId]/config/page.tsx b/apps/rowboat/app/projects/[projectId]/config/page.tsx index a0577fc3..a85caa05 100644 --- a/apps/rowboat/app/projects/[projectId]/config/page.tsx +++ b/apps/rowboat/app/projects/[projectId]/config/page.tsx @@ -1,18 +1,20 @@ import { Metadata } from "next"; import App from "./app"; import { USE_CHAT_WIDGET } from "@/app/lib/feature_flags"; +import { requireActiveBillingSubscription } from '@/app/lib/billing'; export const metadata: Metadata = { title: "Project config", }; -export default function Page({ +export default async function Page({ params, }: { params: { projectId: string; }; }) { + await requireActiveBillingSubscription(); return | null; @@ -61,6 +62,9 @@ const App = forwardRef<{ handleCopyChat: () => void; handleUserMessage: (message streamingResponse, loading: loadingResponse, error: responseError, + clearError: clearResponseError, + billingError, + clearBillingError, start, cancel } = useCopilot({ @@ -108,6 +112,10 @@ const App = forwardRef<{ handleCopyChat: () => void; handleUserMessage: (message useEffect(() => { if (!messages.length || messages.at(-1)?.role !== 'user') return; + if (responseError) { + return; + } + const currentStart = startRef.current; const currentCancel = cancelRef.current; @@ -122,7 +130,7 @@ const App = forwardRef<{ handleCopyChat: () => void; handleUserMessage: (message }); return () => currentCancel(); - }, [messages]); // Only depend on messages + }, [messages, responseError]); const handleCopyChat = useCallback(() => { if (onCopyJson) { @@ -157,7 +165,15 @@ const App = forwardRef<{ handleCopyChat: () => void; handleUserMessage: (message size="sm" color="danger" onClick={() => { - setMessages(prev => [...prev.slice(0, -1)]); // remove last assistant if needed + // remove the last assistant message, if any + setMessages(prev => { + const lastMessage = prev[prev.length - 1]; + if (lastMessage?.role === 'assistant') { + return prev.slice(0, -1); + } + return prev; + }); + clearResponseError(); }} > Retry @@ -191,6 +207,11 @@ const App = forwardRef<{ handleCopyChat: () => void; handleUserMessage: (message /> + ); }); @@ -215,6 +236,7 @@ export const Copilot = forwardRef<{ handleUserMessage: (message: string) => void const [copilotKey, setCopilotKey] = useState(0); const [showCopySuccess, setShowCopySuccess] = useState(false); const [messages, setMessages] = useState[]>([]); + const [billingError, setBillingError] = useState(null); const appRef = useRef<{ handleCopyChat: () => void; handleUserMessage: (message: string) => void }>(null); function handleNewChat() { @@ -242,64 +264,67 @@ export const Copilot = forwardRef<{ handleUserMessage: (message: string) => void }), []); return ( - -
-
- COPILOT + <> + +
+
+ COPILOT +
+ + +
- - - +
- + } + rightActions={ +
+ +
+ } + > +
+
- } - rightActions={ -
- -
- } - > -
- -
- + + ); }); diff --git a/apps/rowboat/app/projects/[projectId]/copilot/use-copilot.tsx b/apps/rowboat/app/projects/[projectId]/copilot/use-copilot.tsx index f1d4e262..6f907d95 100644 --- a/apps/rowboat/app/projects/[projectId]/copilot/use-copilot.tsx +++ b/apps/rowboat/app/projects/[projectId]/copilot/use-copilot.tsx @@ -16,9 +16,12 @@ interface UseCopilotResult { streamingResponse: string; loading: boolean; error: string | null; + clearError: () => void; + billingError: string | null; + clearBillingError: () => void; start: ( messages: z.infer[], - onDone: (finalResponse: string) => void + onDone: (finalResponse: string) => void, ) => void; cancel: () => void; } @@ -27,13 +30,21 @@ export function useCopilot({ projectId, workflow, context, dataSources }: UseCop const [streamingResponse, setStreamingResponse] = useState(''); const [loading, setLoading] = useState(false); const [error, setError] = useState(null); - + const [billingError, setBillingError] = useState(null); const cancelRef = useRef<() => void>(() => { }); const responseRef = useRef(''); + function clearError() { + setError(null); + } + + function clearBillingError() { + setBillingError(null); + } + const start = useCallback(async ( messages: z.infer[], - onDone: (finalResponse: string) => void + onDone: (finalResponse: string) => void, ) => { if (!messages.length || messages.at(-1)?.role !== 'user') return; @@ -44,6 +55,15 @@ export function useCopilot({ projectId, workflow, context, dataSources }: UseCop try { const res = await getCopilotResponseStream(projectId, messages, workflow, context || null, dataSources); + + // Check for billing error + if ('billingError' in res) { + setLoading(false); + setError(res.billingError); + setBillingError(res.billingError); + return; + } + const eventSource = new EventSource(`/api/copilot-stream-response/${res.streamId}`); eventSource.onmessage = (event) => { @@ -84,6 +104,9 @@ export function useCopilot({ projectId, workflow, context, dataSources }: UseCop streamingResponse, loading, error, + clearError, + billingError, + clearBillingError, start, cancel, }; diff --git a/apps/rowboat/app/projects/[projectId]/entities/agent_config.tsx b/apps/rowboat/app/projects/[projectId]/entities/agent_config.tsx index 4cedd33a..0439d14e 100644 --- a/apps/rowboat/app/projects/[projectId]/entities/agent_config.tsx +++ b/apps/rowboat/app/projects/[projectId]/entities/agent_config.tsx @@ -4,10 +4,10 @@ import { AgenticAPITool } from "../../../lib/types/agents_api_types"; import { WorkflowPrompt, WorkflowAgent, Workflow, WorkflowTool } from "../../../lib/types/workflow_types"; import { DataSource } from "../../../lib/types/datasource_types"; import { z } from "zod"; -import { PlusIcon, Sparkles, X as XIcon, ChevronDown, ChevronRight, Trash2, Maximize2, Minimize2 } from "lucide-react"; +import { PlusIcon, Sparkles, X as XIcon, ChevronDown, ChevronRight, Trash2, Maximize2, Minimize2, StarIcon } from "lucide-react"; import { useState, useEffect, useRef } from "react"; import { usePreviewModal } from "../workflow/preview-modal"; -import { Modal, ModalContent, ModalHeader, ModalBody, ModalFooter, Select, SelectItem } from "@heroui/react"; +import { Modal, ModalContent, ModalHeader, ModalBody, ModalFooter, Select, SelectItem, Chip, SelectSection } from "@heroui/react"; import { PreviewModalProvider } from "../workflow/preview-modal"; import { CopilotMessage } from "@/app/lib/types/copilot_types"; import { getCopilotAgentInstructions } from "@/app/actions/copilot_actions"; @@ -23,6 +23,8 @@ import { USE_TRANSFER_CONTROL_OPTIONS } from "@/app/lib/feature_flags"; import { Input } from "@/components/ui/input"; import { Info } from "lucide-react"; import { useCopilot } from "../copilot/use-copilot"; +import { BillingUpgradeModal } from "@/components/common/billing-upgrade-modal"; +import { ModelsResponse } from "@/app/lib/types/billing_types"; // Common section header styles const sectionHeaderStyles = "text-xs font-medium uppercase tracking-wider text-gray-500 dark:text-gray-400"; @@ -47,6 +49,7 @@ export function AgentConfig({ handleClose, useRag, triggerCopilotChat, + eligibleModels, }: { projectId: string, workflow: z.infer, @@ -61,6 +64,7 @@ export function AgentConfig({ handleClose: () => void, useRag: boolean, triggerCopilotChat: (message: string) => void, + eligibleModels: z.infer | "*", }) { const [isAdvancedConfigOpen, setIsAdvancedConfigOpen] = useState(false); const [showGenerateModal, setShowGenerateModal] = useState(false); @@ -72,7 +76,8 @@ export function AgentConfig({ const [activeTab, setActiveTab] = useState('instructions'); const [showRagCta, setShowRagCta] = useState(false); const [previousRagSources, setPreviousRagSources] = useState([]); - + const [billingError, setBillingError] = useState(null); + const { start: startCopilotChat, } = useCopilot({ @@ -490,8 +495,8 @@ export function AgentConfig({ -
- +
@@ -505,17 +510,67 @@ export function AgentConfig({ By default, the model is set to gpt-4.1, assuming your OpenAI API key is set in PROVIDER_API_KEY and PROVIDER_BASE_URL is not set.
-
+
}
- handleUpdate({ ...agent, model: e.target.value as z.infer['model'] })} className="w-full max-w-64" - /> + />} + {eligibleModels !== "*" && + }
@@ -764,6 +819,12 @@ export function AgentConfig({ }} /> + + setBillingError(null)} + errorMessage={billingError || ''} + />
); @@ -789,6 +850,7 @@ function GenerateInstructionsModal({ const [prompt, setPrompt] = useState(""); const [isLoading, setIsLoading] = useState(false); const [error, setError] = useState(null); + const [billingError, setBillingError] = useState(null); const { showPreview } = usePreviewModal(); const textareaRef = useRef(null); @@ -797,6 +859,7 @@ function GenerateInstructionsModal({ setPrompt(""); setIsLoading(false); setError(null); + setBillingError(null); textareaRef.current?.focus(); } }, [isOpen]); @@ -804,6 +867,7 @@ function GenerateInstructionsModal({ const handleGenerate = async () => { setIsLoading(true); setError(null); + setBillingError(null); try { const msgs: z.infer[] = [ { @@ -812,6 +876,12 @@ function GenerateInstructionsModal({ }, ]; const newInstructions = await getCopilotAgentInstructions(projectId, msgs, workflow, agent.name); + if (typeof newInstructions === 'object' && 'billingError' in newInstructions) { + setBillingError(newInstructions.billingError); + setError(newInstructions.billingError); + setIsLoading(false); + return; + } onClose(); @@ -840,59 +910,66 @@ function GenerateInstructionsModal({ }; return ( - - - Generate Instructions - -
- {error && ( -
-

{error}

- { - setError(null); - handleGenerate(); - }} - > - Retry - -
- )} -