diff --git a/apps/rowboat/app/actions/copilot_actions.ts b/apps/rowboat/app/actions/copilot_actions.ts index 2669df58..73e4eaba 100644 --- a/apps/rowboat/app/actions/copilot_actions.ts +++ b/apps/rowboat/app/actions/copilot_actions.ts @@ -15,6 +15,7 @@ import { WithStringId } from "../lib/types/types"; import { getEditAgentInstructionsResponse } from "../lib/copilot/copilot"; import { container } from "@/di/container"; import { IUsageQuotaPolicy } from "@/src/application/policies/usage-quota.policy.interface"; +import { UsageTracker } from "../lib/billing"; const usageQuotaPolicy = container.resolve('usageQuotaPolicy'); @@ -32,8 +33,7 @@ export async function getCopilotResponseStream( // Check billing authorization const authResponse = await authorizeUserAction({ - type: 'copilot_request', - data: {}, + type: 'use_credits', }); if (!authResponse.success) { return { billingError: authResponse.error || 'Billing error' }; @@ -75,8 +75,7 @@ export async function getCopilotAgentInstructions( // Check billing authorization const authResponse = await authorizeUserAction({ - type: 'copilot_request', - data: {}, + type: 'use_credits', }); if (!authResponse.success) { return { billingError: authResponse.error || 'Billing error' }; @@ -93,8 +92,11 @@ export async function getCopilotAgentInstructions( } }; + const usageTracker = new UsageTracker(); + // call copilot api const agent_instructions = await getEditAgentInstructionsResponse( + usageTracker, projectId, request.context, request.messages, @@ -104,8 +106,7 @@ export async function getCopilotAgentInstructions( // log the billing usage if (USE_BILLING) { await logUsage({ - type: 'copilot_requests', - amount: 1, + items: usageTracker.flush(), }); } diff --git a/apps/rowboat/app/api/copilot-stream-response/[streamId]/route.ts b/apps/rowboat/app/api/copilot-stream-response/[streamId]/route.ts index f72d4782..10d1fc66 100644 --- a/apps/rowboat/app/api/copilot-stream-response/[streamId]/route.ts +++ b/apps/rowboat/app/api/copilot-stream-response/[streamId]/route.ts @@ -1,4 +1,4 @@ -import { getCustomerIdForProject, logUsage } from "@/app/lib/billing"; +import { getCustomerIdForProject, logUsage, UsageTracker } from "@/app/lib/billing"; import { USE_BILLING } from "@/app/lib/feature_flags"; import { redisClient } from "@/app/lib/redis"; import { CopilotAPIRequest } from "@/app/lib/types/copilot_types"; @@ -21,6 +21,7 @@ export async function GET(request: Request, props: { params: Promise<{ streamId: billingCustomerId = await getCustomerIdForProject(projectId); } + const usageTracker = new UsageTracker(); const encoder = new TextEncoder(); let messageCount = 0; @@ -29,6 +30,7 @@ export async function GET(request: Request, props: { params: Promise<{ streamId: try { // Iterate over the copilot stream generator for await (const event of streamMultiAgentResponse( + usageTracker, projectId, context, messages, @@ -49,21 +51,20 @@ export async function GET(request: Request, props: { params: Promise<{ streamId: } controller.close(); - - // increment copilot request count in billing + } catch (error) { + console.error('Error processing copilot stream:', error); + controller.error(error); + } finally { + // log copilot usage if (USE_BILLING && billingCustomerId) { try { await logUsage(billingCustomerId, { - type: "copilot_requests", - amount: 1, + items: usageTracker.flush(), }); } catch (error) { console.error("Error logging usage", error); } } - } catch (error) { - console.error('Error processing copilot stream:', error); - controller.error(error); } }, }); diff --git a/apps/rowboat/app/api/widget/v1/chats/[chatId]/turn/route.ts b/apps/rowboat/app/api/widget/v1/chats/[chatId]/turn/route.ts index 1e9a500f..c65d0411 100644 --- a/apps/rowboat/app/api/widget/v1/chats/[chatId]/turn/route.ts +++ b/apps/rowboat/app/api/widget/v1/chats/[chatId]/turn/route.ts @@ -220,10 +220,10 @@ export async function POST( // log billing usage if (USE_BILLING && billingCustomerId) { const agentMessageCount = convertedResponseMessages.filter(m => m.role === 'assistant').length; - await logUsage(billingCustomerId, { - type: 'agent_messages', - amount: agentMessageCount, - }); + // await logUsage(billingCustomerId, { + // type: 'agent_messages', + // amount: agentMessageCount, + // }); } logger.log(`Turn processing completed successfully`); diff --git a/apps/rowboat/app/billing/app.tsx b/apps/rowboat/app/billing/app.tsx index 660ca18e..9195089d 100644 --- a/apps/rowboat/app/billing/app.tsx +++ b/apps/rowboat/app/billing/app.tsx @@ -3,7 +3,7 @@ import { Progress, Badge, Chip } from "@heroui/react"; import { Button } from "@/components/ui/button"; import { Label } from "@/app/lib/components/label"; -import { Customer, UsageResponse, UsageType } from "@/app/lib/types/billing_types"; +import { Customer, UsageResponse } from "@/app/lib/types/billing_types"; import { z } from "zod"; import { tokens } from "@/app/styles/design-tokens"; import { SectionHeading } from "@/components/ui/section-heading"; @@ -47,6 +47,15 @@ export function BillingPage({ customer, usage }: BillingPageProps) { const displayStatus = getDisplayStatus(customer.subscriptionStatus); const planInfo = planDetails[plan]; + // Prepare usage metrics data + const usageData = Object.entries(usage.usage) + .map(([type, credits]) => ({ + type, + credits, + totalUsedCredits: usage.sanctionedCredits - usage.availableCredits + })) + .sort((a, b) => b.credits - a.credits); + async function handleManageSubscription() { const returnUrl = new URL('/billing/callback', window.location.origin); returnUrl.searchParams.set('redirect', window.location.href); @@ -109,48 +118,175 @@ export function BillingPage({ customer, usage }: BillingPageProps) { - {/* Usage Metrics Panel */} + {/* Credits Overview Panel */}
- Usage Metrics + Credits Overview
- {Object.entries(usage.usage).map(([type, { usage: used, total }]) => { - const usageType = type as z.infer; - const percentage = Math.min((used / total) * 100, 100); - const isOverLimit = used > total; +
+
+
+
+
+
+
+
+ + {/* Warning for negative credits */} + {usage.availableCredits < 0 && ( +
+

+ ⚠️ You have exceeded your credit limit. Please upgrade your plan or contact support to avoid service interruptions. +

+
+ )} + + {/* Warning for high credit usage (>80%) */} + {usage.availableCredits >= 0 && ((usage.sanctionedCredits - usage.availableCredits) / usage.sanctionedCredits) > 0.8 && ( +
+

+ ⚠️ You have used more than 80% of your credits. Consider upgrading your plan to avoid interruptions. +

+
+ )} + + {/* Credits Progress Bar */} +
+
+
+ +
+
+
- return ( -
-
-
-
diff --git a/apps/rowboat/app/lib/agent-tools.ts b/apps/rowboat/app/lib/agent-tools.ts index d8659937..b4b82ac0 100644 --- a/apps/rowboat/app/lib/agent-tools.ts +++ b/apps/rowboat/app/lib/agent-tools.ts @@ -16,6 +16,7 @@ import { qdrantClient } from '../lib/qdrant'; import { EmbeddingRecord } from "./types/datasource_types"; import { WorkflowAgent, WorkflowTool } from "./types/workflow_types"; import { PrefixLogger } from "./utils"; +import { UsageTracker } from "./billing"; // Provider configuration const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || ''; @@ -30,6 +31,7 @@ const openai = createOpenAI({ // Helper to handle mock tool responses export async function invokeMockTool( logger: PrefixLogger, + usageTracker: UsageTracker, toolName: string, args: string, description: string, @@ -49,18 +51,28 @@ export async function invokeMockTool( content: `Generate a realistic response for the tool '${toolName}' with these parameters: ${args}. The response should be concise and focused on what the tool would actually return.` }]; - const { text } = await generateText({ + const { text, usage } = await generateText({ model: openai(MODEL), messages, }); logger.log(`generated text: ${text}`); + // track usage + usageTracker.track({ + type: "LLM_USAGE", + modelName: MODEL, + inputTokens: usage.promptTokens, + outputTokens: usage.completionTokens, + context: "agents_runtime.mock_tool", + }); + return text; } // Helper to handle RAG tool calls export async function invokeRagTool( logger: PrefixLogger, + usageTracker: UsageTracker, projectId: string, query: string, sourceIds: string[], @@ -81,11 +93,21 @@ export async function invokeRagTool( logger.log(`k: ${k}`); // Create embedding for question - const { embedding } = await embed({ + const { embedding, usage } = await embed({ model: embeddingModel, value: query, }); + // track usage + + // track usage + usageTracker.track({ + type: "EMBEDDING_MODEL_USAGE", + modelName: embeddingModel.modelId, + tokens: usage.tokens, + context: "agents_runtime.rag_tool.embedding_usage", + }); + // Fetch all data sources for this project const sources = await dataSourcesCollection.find({ projectId: projectId, @@ -154,6 +176,7 @@ export async function invokeRagTool( export async function invokeWebhookTool( logger: PrefixLogger, + usageTracker: UsageTracker, projectId: string, name: string, input: any, @@ -233,6 +256,7 @@ export async function invokeWebhookTool( // Helper to handle MCP tool calls export async function invokeMcpTool( logger: PrefixLogger, + usageTracker: UsageTracker, projectId: string, name: string, input: any, @@ -269,6 +293,7 @@ export async function invokeMcpTool( // Helper to handle composio tool calls export async function invokeComposioTool( logger: PrefixLogger, + usageTracker: UsageTracker, projectId: string, name: string, composioData: z.infer['composioData'] & {}, @@ -299,12 +324,21 @@ export async function invokeComposioTool( connectedAccountId: connectedAccountId, }); logger.log(`composio tool result: ${JSON.stringify(result)}`); + + // track usage + usageTracker.track({ + type: "COMPOSIO_TOOL_USAGE", + toolSlug: slug, + context: "agents_runtime.composio_tool", + }); + return result.data; } // Helper to create RAG tool export function createRagTool( logger: PrefixLogger, + usageTracker: UsageTracker, config: z.infer, projectId: string ): Tool { @@ -321,6 +355,7 @@ export function createRagTool( async execute(input: { query: string }) { const results = await invokeRagTool( logger, + usageTracker, projectId, input.query, config.ragDataSources || [], @@ -337,6 +372,7 @@ export function createRagTool( // Helper to create a mock tool export function createMockTool( logger: PrefixLogger, + usageTracker: UsageTracker, config: z.infer, ): Tool { return tool({ @@ -353,6 +389,7 @@ export function createMockTool( try { const result = await invokeMockTool( logger, + usageTracker, config.name, JSON.stringify(input), config.description, @@ -374,6 +411,7 @@ export function createMockTool( // Helper to create a webhook tool export function createWebhookTool( logger: PrefixLogger, + usageTracker: UsageTracker, config: z.infer, projectId: string, ): Tool { @@ -391,7 +429,7 @@ export function createWebhookTool( }, async execute(input: any) { try { - const result = await invokeWebhookTool(logger, projectId, name, input); + const result = await invokeWebhookTool(logger, usageTracker, projectId, name, input); return JSON.stringify({ result, }); @@ -408,6 +446,7 @@ export function createWebhookTool( // Helper to create an mcp tool export function createMcpTool( logger: PrefixLogger, + usageTracker: UsageTracker, config: z.infer, projectId: string ): Tool { @@ -425,7 +464,7 @@ export function createMcpTool( }, async execute(input: any) { try { - const result = await invokeMcpTool(logger, projectId, name, input, mcpServerName || ''); + const result = await invokeMcpTool(logger, usageTracker, projectId, name, input, mcpServerName || ''); return JSON.stringify({ result, }); @@ -442,6 +481,7 @@ export function createMcpTool( // Helper to create a composio tool export function createComposioTool( logger: PrefixLogger, + usageTracker: UsageTracker, config: z.infer, projectId: string ): Tool { @@ -463,7 +503,7 @@ export function createComposioTool( }, async execute(input: any) { try { - const result = await invokeComposioTool(logger, projectId, name, composioData, input); + const result = await invokeComposioTool(logger, usageTracker, projectId, name, composioData, input); return JSON.stringify({ result, }); @@ -479,6 +519,7 @@ export function createComposioTool( export function createTools( logger: PrefixLogger, + usageTracker: UsageTracker, projectId: string, workflow: { tools: z.infer[] }, toolConfig: Record>, @@ -492,16 +533,16 @@ export function createTools( toolLogger.log(`creating tool: ${toolName} (type: ${config.mockTool ? 'mock' : config.isMcp ? 'mcp' : config.isComposio ? 'composio' : 'webhook'})`); if (config.mockTool) { - tools[toolName] = createMockTool(logger, config); + tools[toolName] = createMockTool(logger, usageTracker, config); toolLogger.log(`✓ created mock tool: ${toolName}`); } else if (config.isMcp) { - tools[toolName] = createMcpTool(logger, config, projectId); + tools[toolName] = createMcpTool(logger, usageTracker, config, projectId); toolLogger.log(`✓ created mcp tool: ${toolName} (server: ${config.mcpServerName || 'unknown'})`); } else if (config.isComposio) { - tools[toolName] = createComposioTool(logger, config, projectId); + tools[toolName] = createComposioTool(logger, usageTracker, config, projectId); toolLogger.log(`✓ created composio tool: ${toolName}`); } else { - tools[toolName] = createWebhookTool(logger, config, projectId); + tools[toolName] = createWebhookTool(logger, usageTracker, config, projectId); toolLogger.log(`✓ created webhook tool: ${toolName} (fallback)`); } } diff --git a/apps/rowboat/app/lib/agents.ts b/apps/rowboat/app/lib/agents.ts index 83f06eb6..a1f78049 100644 --- a/apps/rowboat/app/lib/agents.ts +++ b/apps/rowboat/app/lib/agents.ts @@ -12,6 +12,8 @@ import { ConnectedEntity, sanitizeTextWithMentions, Workflow, WorkflowAgent, Wor import { CHILD_TRANSFER_RELATED_INSTRUCTIONS, CONVERSATION_TYPE_INSTRUCTIONS, PIPELINE_TYPE_INSTRUCTIONS, RAG_INSTRUCTIONS, TASK_TYPE_INSTRUCTIONS } from "./agent_instructions"; import { PrefixLogger } from "./utils"; import { Message, AssistantMessage, AssistantMessageWithToolCalls, ToolMessage } from "./types/types"; +import { UsageTracker } from "./billing"; + // Native handoff support import { createAgentHandoff, getSchemaForAgent, createContextFilterForAgent } from "./agent-handoffs"; import { PipelineStateManager } from "./pipeline-state-manager"; @@ -78,14 +80,6 @@ const openai = createOpenAI({ baseURL: PROVIDER_BASE_URL, }); -const ZUsage = z.object({ - tokens: z.object({ - total: z.number(), - prompt: z.number(), - completion: z.number(), - }), -}); - const ZOutMessage = z.union([ AssistantMessage, AssistantMessageWithToolCalls, @@ -95,6 +89,7 @@ const ZOutMessage = z.union([ // Helper to create an agent function createAgent( logger: PrefixLogger, + usageTracker: UsageTracker, projectId: string, config: z.infer, tools: Record, @@ -145,7 +140,7 @@ ${CHILD_TRANSFER_RELATED_INSTRUCTIONS} // Add RAG tool if needed if (config.ragDataSources?.length) { - const ragTool = createRagTool(logger, config, projectId); + const ragTool = createRagTool(logger, usageTracker, config, projectId); agentTools.push(ragTool); // update instructions to include RAG instructions @@ -269,8 +264,8 @@ function getStartOfTurnAgentName( // Logs an event and then yields it async function* emitEvent( logger: PrefixLogger, - event: z.infer | z.infer, -): AsyncIterable | z.infer> { + event: z.infer, +): AsyncIterable> { logger.log(`-> emitting event: ${JSON.stringify(event)}`); yield event; return; @@ -321,30 +316,6 @@ class AgentTransferCounter { } } -class UsageTracker { - private usage: { - total: number; - prompt: number; - completion: number; - } = { total: 0, prompt: 0, completion: 0 }; - - increment(total: number, prompt: number, completion: number): void { - this.usage.total += total; - this.usage.prompt += prompt; - this.usage.completion += completion; - } - - get(): { total: number, prompt: number, completion: number } { - return this.usage; - } - - asEvent(): z.infer { - return { - tokens: this.usage, - }; - } -} - function ensureSystemMessage(logger: PrefixLogger, messages: z.infer[]) { logger = logger.child(`ensureSystemMessage`); @@ -396,7 +367,7 @@ function mapConfig(workflow: z.infer): { return { agentConfig, toolConfig, promptConfig, pipelineConfig }; } -async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer): AsyncIterable | z.infer> { +async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer): AsyncIterable> { // find the greeting prompt const prompt = workflow.prompts.find(p => p.type === 'greeting')?.prompt || 'How can I help you today?'; logger.log(`greeting turn: ${prompt}`); @@ -408,15 +379,13 @@ async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer, agentConfig: Record>, @@ -447,6 +416,7 @@ function createAgentsWithNativeHandoffs( const { agent, entities } = createAgent( logger, + usageTracker, projectId, config, tools, @@ -560,6 +530,7 @@ function createAgentsWithNativeHandoffs( // Legacy agent creation (existing implementation) function createAgentsLegacy( logger: PrefixLogger, + usageTracker: UsageTracker, projectId: string, workflow: z.infer, agentConfig: Record>, @@ -595,6 +566,7 @@ function createAgentsLegacy( const { agent, entities } = createAgent( logger, + usageTracker, projectId, config, tools, @@ -763,12 +735,13 @@ function maybeInjectGiveUpControlInstructions( // Handle raw model stream events async function* handleRawModelStreamEvent( event: any, + agentConfig: Record>, agentName: string, turnMsgs: z.infer[], usageTracker: UsageTracker, eventLogger: PrefixLogger, getAgentState?: (agentName: string) => AgentState -): AsyncIterable | z.infer> { +): AsyncIterable> { if (event.data.type === 'response_done') { // Count tool calls (excluding transfer_to_* calls) const toolCallCount = event.data.response.output.filter( @@ -809,12 +782,13 @@ async function* handleRawModelStreamEvent( } // update usage information - usageTracker.increment( - event.data.response.usage.totalTokens, - event.data.response.usage.inputTokens, - event.data.response.usage.outputTokens - ); - eventLogger.log(`updated usage information: ${JSON.stringify(usageTracker.get())}`); + usageTracker.track({ + type: "LLM_USAGE", + modelName: agentConfig[agentName]?.model || "unknown", + inputTokens: event.data.response.usage.inputTokens, + outputTokens: event.data.response.usage.outputTokens, + context: "agents_runtime.llm_usage", + }); } } @@ -833,7 +807,7 @@ async function* handleNativeHandoffEvent( originalHandoffs: Record, eventLogger: PrefixLogger, loopLogger: PrefixLogger -): AsyncIterable | z.infer | { newAgentName: string; shouldContinue?: boolean }> { +): AsyncIterable | { newAgentName: string; shouldContinue?: boolean }> { eventLogger.log(`🔄 NATIVE HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`); // skip if its the same agent @@ -943,7 +917,7 @@ async function* handleHandoffEvent( originalHandoffs: Record, eventLogger: PrefixLogger, loopLogger: PrefixLogger -): AsyncIterable | z.infer | { newAgentName: string }> { +): AsyncIterable | { newAgentName: string }> { eventLogger.log(`🔄 HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`); // skip if its the same agent @@ -1010,7 +984,7 @@ async function* handleToolCallResult( event: any, turnMsgs: z.infer[], eventLogger: PrefixLogger -): AsyncIterable | z.infer> { +): AsyncIterable> { const m: z.infer = { role: 'tool', content: event.item.rawItem.output.text, @@ -1039,7 +1013,7 @@ async function* handleMessageOutput( eventLogger: PrefixLogger, loopLogger: PrefixLogger, getAgentState: (agentName: string) => AgentState -): AsyncIterable | z.infer | { newAgentName: string | null; shouldContinue: boolean }> { +): AsyncIterable | { newAgentName: string | null; shouldContinue: boolean }> { // check response visibility - could be an agent or pipeline const agentConfigObj = agentConfig[agentName]; const pipelineConfigObj = pipelineConfig[agentName]; @@ -1243,7 +1217,8 @@ export async function* streamResponse( projectId: string, workflow: z.infer, messages: z.infer[], -): AsyncIterable | z.infer> { + usageTracker: UsageTracker, +): AsyncIterable> { // Divider log for tracking agent loop start console.log('-------------------- AGENT LOOP START --------------------'); // set up logging @@ -1275,11 +1250,11 @@ export async function* streamResponse( logger.log(`initialized stack: ${JSON.stringify(stack)}`); // create tools - const tools = createTools(logger, projectId, workflow, toolConfig); + const tools = createTools(logger, usageTracker, projectId, workflow, toolConfig); // create agents with feature flag support const createAgentsFunction = USE_NATIVE_HANDOFFS ? createAgentsWithNativeHandoffs : createAgentsLegacy; - const { agents, originalInstructions, originalHandoffs } = createAgentsFunction(logger, projectId, workflow, agentConfig, tools, promptConfig, pipelineConfig); + const { agents, originalInstructions, originalHandoffs } = createAgentsFunction(logger, usageTracker, projectId, workflow, agentConfig, tools, promptConfig, pipelineConfig); logger.log(`Using ${USE_NATIVE_HANDOFFS ? 'NATIVE SDK' : 'LEGACY'} handoffs`); @@ -1296,7 +1271,6 @@ export async function* streamResponse( let agentName: string | null = startOfTurnAgentName; // start the turn loop - const usageTracker = new UsageTracker(); const turnMsgs: z.infer[] = [...messages]; // Initialize agent state tracking for tool call completion @@ -1390,7 +1364,7 @@ export async function* streamResponse( switch (event.type) { case 'raw_model_stream_event': - yield* handleRawModelStreamEvent(event, agentName!, turnMsgs, usageTracker, eventLogger, getAgentState); + yield* handleRawModelStreamEvent(event, agentConfig, agentName!, turnMsgs, usageTracker, eventLogger, getAgentState); break; case 'run_item_stream_event': @@ -1523,9 +1497,6 @@ export async function* streamResponse( } } - - // emit usage information - yield* emitEvent(logger, usageTracker.asEvent()); } // this is a sync version of streamResponse @@ -1535,8 +1506,10 @@ export async function getResponse( messages: z.infer[], ): Promise<{ messages: z.infer[], - usage: z.infer, + usage: any, }> { + throw new Error("Not implemented!"); + /* const out: z.infer[] = []; let usage: z.infer = { tokens: { @@ -1554,4 +1527,5 @@ export async function getResponse( } } return { messages: out, usage }; + */ } \ No newline at end of file diff --git a/apps/rowboat/app/lib/billing.ts b/apps/rowboat/app/lib/billing.ts index d6d880ba..e7f8598d 100644 --- a/apps/rowboat/app/lib/billing.ts +++ b/apps/rowboat/app/lib/billing.ts @@ -1,6 +1,6 @@ import { WithStringId } from './types/types'; import { z } from 'zod'; -import { Customer, AuthorizeRequest, AuthorizeResponse, LogUsageRequest, UsageResponse, CustomerPortalSessionResponse, PricesResponse, UpdateSubscriptionPlanRequest, UpdateSubscriptionPlanResponse, ModelsResponse } from './types/billing_types'; +import { Customer, AuthorizeRequest, AuthorizeResponse, LogUsageRequest, UsageResponse, CustomerPortalSessionResponse, PricesResponse, UpdateSubscriptionPlanRequest, UpdateSubscriptionPlanResponse, ModelsResponse, UsageItem } from './types/billing_types'; import { ObjectId } from 'mongodb'; import { projectsCollection, usersCollection } from './mongodb'; import { redirect } from 'next/navigation'; @@ -23,6 +23,20 @@ const GUEST_BILLING_CUSTOMER = { updatedAt: new Date().toISOString(), }; +export class UsageTracker{ + private items: z.infer[] = []; + + track(item: z.infer) { + this.items.push(item); + } + + flush(): z.infer[] { + const items = this.items; + this.items = []; + return items; + } +} + export async function getCustomerIdForProject(projectId: string): Promise { const project = await projectsCollection.findOne({ _id: projectId }); if (!project) { @@ -111,6 +125,7 @@ export async function authorize(customerId: string, request: z.infer) { + console.log(`logging billing usage for customer ${customerId}`, JSON.stringify(request)); const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/log-usage`, { method: 'POST', headers: { diff --git a/apps/rowboat/app/lib/copilot/copilot.ts b/apps/rowboat/app/lib/copilot/copilot.ts index d0276ebf..19e201dd 100644 --- a/apps/rowboat/app/lib/copilot/copilot.ts +++ b/apps/rowboat/app/lib/copilot/copilot.ts @@ -13,6 +13,7 @@ import { COPILOT_MULTI_AGENT_EXAMPLE_1 } from "./example_multi_agent_1"; import { CURRENT_WORKFLOW_PROMPT } from "./current_workflow"; import { USE_COMPOSIO_TOOLS } from "../feature_flags"; import { composio, getTool } from "../composio/composio"; +import { UsageTracker } from "../billing"; const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || ''; const PROVIDER_BASE_URL = process.env.PROVIDER_BASE_URL || undefined; @@ -119,7 +120,7 @@ ${JSON.stringify(simplifiedDataSources)} return prompt; } -async function searchRelevantTools(query: string): Promise { +async function searchRelevantTools(usageTracker: UsageTracker, query: string): Promise { const logger = new PrefixLogger("copilot-search-tools"); console.log("🔧 TOOL CALL: searchRelevantTools", { query }); @@ -142,6 +143,13 @@ async function searchRelevantTools(query: string): Promise { return 'No tools found!'; } + // track composio search tool usage + usageTracker.track({ + type: "COMPOSIO_TOOL_USAGE", + toolSlug: "COMPOSIO_SEARCH_TOOLS", + context: "copilot.search_relevant_tools", + }); + // parse results const result = composioToolSearchResponseSchema.safeParse(searchResult.data); if (!result.success) { @@ -208,6 +216,7 @@ function updateLastUserMessage( } export async function getEditAgentInstructionsResponse( + usageTracker: UsageTracker, projectId: string, context: z.infer | null, messages: z.infer[], @@ -232,7 +241,7 @@ export async function getEditAgentInstructionsResponse( system: COPILOT_INSTRUCTIONS_EDIT_AGENT, messages: messages, })); - const { object } = await generateObject({ + const { object, usage } = await generateObject({ model: openai(COPILOT_MODEL), messages: [ { @@ -246,10 +255,20 @@ export async function getEditAgentInstructionsResponse( }), }); + // log usage + usageTracker.track({ + type: "LLM_USAGE", + modelName: COPILOT_MODEL, + inputTokens: usage.promptTokens, + outputTokens: usage.completionTokens, + context: "copilot.llm_usage", + }); + return object.agent_instructions; } export async function* streamMultiAgentResponse( + usageTracker: UsageTracker, projectId: string, context: z.infer | null, messages: z.infer[], @@ -297,7 +316,7 @@ export async function* streamMultiAgentResponse( }), execute: async ({ query }: { query: string }) => { console.log("🎯 AI TOOL CALL: search_relevant_tools", { query }); - const result = await searchRelevantTools(query); + const result = await searchRelevantTools(usageTracker, query); console.log("✅ AI TOOL CALL COMPLETED: search_relevant_tools", { query, resultLength: result.length @@ -341,6 +360,15 @@ export async function* streamMultiAgentResponse( toolCallId: event.toolCallId, result: event.result, }; + } else if (event.type === "step-finish") { + // log usage + usageTracker.track({ + type: "LLM_USAGE", + modelName: COPILOT_MODEL, + inputTokens: event.usage.promptTokens, + outputTokens: event.usage.completionTokens, + context: "copilot.llm_usage", + }); } } diff --git a/apps/rowboat/app/lib/types/billing_types.ts b/apps/rowboat/app/lib/types/billing_types.ts index f258b65e..76c89577 100644 --- a/apps/rowboat/app/lib/types/billing_types.ts +++ b/apps/rowboat/app/lib/types/billing_types.ts @@ -2,12 +2,52 @@ import { z } from "zod"; export const SubscriptionPlan = z.enum(["free", "starter", "pro"]); -export const UsageType = z.enum([ - "copilot_requests", - "agent_messages", - "rag_tokens", +export const UsageTypeKey = z.enum([ + "LLM_USAGE", + "EMBEDDING_MODEL_USAGE", + "COMPOSIO_TOOL_USAGE", + "FIRECRAWL_SCRAPE_USAGE", ]); +export const LLMUsage = z.object({ + type: z.literal(UsageTypeKey.Enum.LLM_USAGE), + modelName: z.string(), + inputTokens: z.number().positive(), + outputTokens: z.number().positive(), + context: z.string(), +}); + +export const EmbeddingModelUsage = z.object({ + type: z.literal(UsageTypeKey.Enum.EMBEDDING_MODEL_USAGE), + modelName: z.string(), + tokens: z.number().positive(), + context: z.string(), +}); + +export const ComposioToolUsage = z.object({ + type: z.literal(UsageTypeKey.Enum.COMPOSIO_TOOL_USAGE), + toolSlug: z.string(), + context: z.string(), +}); + +export const FirecrawlScrapeUsage = z.object({ + type: z.literal(UsageTypeKey.Enum.FIRECRAWL_SCRAPE_USAGE), + context: z.string(), +}); + +export const UsageItem = z.discriminatedUnion("type", [ + LLMUsage, + EmbeddingModelUsage, + ComposioToolUsage, + FirecrawlScrapeUsage, +]); + +export const LogUsageRequest = z.object({ + items: z.array(UsageItem), +}); + +export const CustomerUsageData = z.record(z.string(), z.number()); + export const Customer = z.object({ _id: z.string(), userId: z.string(), @@ -19,36 +59,23 @@ export const Customer = z.object({ createdAt: z.string().datetime(), updatedAt: z.string().datetime(), subscriptionPlanUpdatedAt: z.string().datetime().optional(), - usage: z.record(UsageType, z.number()).optional(), + usage: CustomerUsageData.optional(), usageUpdatedAt: z.string().datetime().optional(), -}); - -export const LogUsageRequest = z.object({ - type: UsageType, - amount: z.number().int().positive(), -}); + creditsOverride: z.number().optional(), + maxProjectsOverride: z.number().optional(), + agentModelsOverride: z.array(z.string()).optional(), + }); export const AuthorizeRequest = z.discriminatedUnion("type", [ + z.object({ + "type": z.literal("use_credits"), + }), z.object({ "type": z.literal("create_project"), "data": z.object({ "existingProjectCount": z.number(), }), }), - z.object({ - "type": z.literal("enable_hosted_tool_server"), - "data": z.object({ - "existingServerCount": z.number(), - }), - }), - z.object({ - "type": z.literal("process_rag"), - "data": z.object({}), - }), - z.object({ - "type": z.literal("copilot_request"), - "data": z.object({}), - }), z.object({ "type": z.literal("agent_response"), "data": z.object({ @@ -63,10 +90,9 @@ export const AuthorizeResponse = z.object({ }); export const UsageResponse = z.object({ - usage: z.record(UsageType, z.object({ - usage: z.number(), - total: z.number(), - })), + sanctionedCredits: z.number(), + availableCredits: z.number(), + usage: CustomerUsageData, }); export const CustomerPortalSessionRequest = z.object({ diff --git a/apps/rowboat/app/scripts/rag_files_worker.ts b/apps/rowboat/app/scripts/rag_files_worker.ts index 719d8f7b..f310663a 100644 --- a/apps/rowboat/app/scripts/rag_files_worker.ts +++ b/apps/rowboat/app/scripts/rag_files_worker.ts @@ -16,7 +16,7 @@ import crypto from 'crypto'; import path from 'path'; import { createOpenAI } from '@ai-sdk/openai'; import { USE_BILLING, USE_GEMINI_FILE_PARSING } from '../lib/feature_flags'; -import { authorize, getCustomerIdForProject, logUsage } from '../lib/billing'; +import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from '../lib/billing'; import { BillingError } from '@/src/entities/errors/common'; const FILE_PARSING_PROVIDER_API_KEY = process.env.FILE_PARSING_PROVIDER_API_KEY || process.env.OPENAI_API_KEY || ''; @@ -75,7 +75,7 @@ async function retryable(fn: () => Promise, maxAttempts: number = 3): Prom } } -async function runProcessPipeline(_logger: PrefixLogger, job: WithId>, doc: WithId> & { data: { type: "file_local" | "file_s3" } }): Promise { +async function runProcessPipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId>, doc: WithId> & { data: { type: "file_local" | "file_s3" } }) { const logger = _logger .child(doc._id.toString()) .child(doc.name); @@ -95,7 +95,7 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId split.pageContent) }); + usageTracker.track({ + type: "EMBEDDING_MODEL_USAGE", + modelName: embeddingModel.modelId, + tokens: usage.tokens, + context: "rag.files.embedding_usage", + }); // store embeddings in qdrant logger.log("Storing embeddings in Qdrant"); @@ -170,8 +190,6 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId>, doc: WithId>): Promise { @@ -339,8 +357,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId> & { data: { type: "file_local" | "file_s3" } }; + const usageTracker = new UsageTracker(); try { - const usedTokens = await runProcessPipeline(logger, job, ldoc); - - // log usage in billing - if (USE_BILLING && billingCustomerId) { - await logUsage(billingCustomerId, { - type: "rag_tokens", - amount: usedTokens, - }); - } + await runProcessPipeline(logger, usageTracker, job, ldoc); } catch (e: any) { errors = true; logger.log("Error processing doc:", e); @@ -371,6 +381,13 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId>, doc: WithId>): Promise { +async function runProcessPipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId>, doc: WithId>) { const logger = _logger .child(doc._id.toString()) .child(doc.name); @@ -42,6 +42,12 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId split.pageContent) }); + usageTracker.track({ + type: "EMBEDDING_MODEL_USAGE", + modelName: embeddingModel.modelId, + tokens: usage.tokens, + context: "rag.text.embedding_usage", + }); // store embeddings in qdrant logger.log("Storing embeddings in Qdrant"); @@ -73,8 +79,6 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId>, doc: WithId>): Promise { @@ -241,8 +245,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId(fn: () => Promise, maxAttempts: number = 3): Prom } } -async function runScrapePipeline(_logger: PrefixLogger, job: WithId>, doc: WithId>): Promise { +async function runScrapePipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId>, doc: WithId>) { const logger = _logger .child(doc._id.toString()) .child(doc.name); @@ -62,6 +62,10 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId split.pageContent) }); + usageTracker.track({ + type: "EMBEDDING_MODEL_USAGE", + modelName: embeddingModel.modelId, + tokens: usage.tokens, + context: "rag.urls.embedding_usage", + }); // store embeddings in qdrant logger.log("Storing embeddings in Qdrant"); @@ -104,8 +114,6 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId>, doc: WithId>): Promise { @@ -273,8 +281,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId { acc.push(agent.model); return acc; @@ -111,46 +126,50 @@ export class RunConversationTurnUseCase implements IRunConversationTurnUseCase { conversation.workflow.mockTools = data.input.mockTools; } + // init usage tracker + const usageTracker = new UsageTracker(); + // call agents runtime and handle generated messages - const outputMessages: z.infer[] = []; - for await (const event of streamResponse(projectId, conversation.workflow, inputMessages)) { - // handle msg events - if ("role" in event) { - // collect generated message - const msg = { - ...event, - timestamp: new Date().toISOString(), - }; - outputMessages.push(msg); + try { + const outputMessages: z.infer[] = []; + for await (const event of streamResponse(projectId, conversation.workflow, inputMessages, usageTracker)) { + // handle msg events + if ("role" in event) { + // collect generated message + const msg = { + ...event, + timestamp: new Date().toISOString(), + }; + outputMessages.push(msg); - // yield event - yield { - type: "message", - data: msg, - }; - } else { - // save turn data - const turn = await this.conversationsRepository.addTurn(data.conversationId, { - reason: data.reason, - input: data.input, - output: outputMessages, - }); - - // yield event - yield { - type: "done", - turn, - conversationId, + // yield event + yield { + type: "message", + data: msg, + }; } } - } - // Log billing usage - if (USE_BILLING && billingCustomerId) { - await logUsage(billingCustomerId, { - type: "agent_messages", - amount: outputMessages.length, + // save turn data + const turn = await this.conversationsRepository.addTurn(data.conversationId, { + reason: data.reason, + input: data.input, + output: outputMessages, }); + + // yield event + yield { + type: "done", + turn, + conversationId, + } + } finally { + // Log billing usage + if (USE_BILLING && billingCustomerId) { + await logUsage(billingCustomerId, { + items: usageTracker.flush(), + }); + } } } } \ No newline at end of file