billing + credits updates (#202)

This commit is contained in:
Ramnique Singh 2025-08-14 19:59:38 +05:30 committed by GitHub
parent 852e02e49e
commit eccfb4748f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 497 additions and 229 deletions

View file

@ -15,6 +15,7 @@ import { WithStringId } from "../lib/types/types";
import { getEditAgentInstructionsResponse } from "../lib/copilot/copilot"; import { getEditAgentInstructionsResponse } from "../lib/copilot/copilot";
import { container } from "@/di/container"; import { container } from "@/di/container";
import { IUsageQuotaPolicy } from "@/src/application/policies/usage-quota.policy.interface"; import { IUsageQuotaPolicy } from "@/src/application/policies/usage-quota.policy.interface";
import { UsageTracker } from "../lib/billing";
const usageQuotaPolicy = container.resolve<IUsageQuotaPolicy>('usageQuotaPolicy'); const usageQuotaPolicy = container.resolve<IUsageQuotaPolicy>('usageQuotaPolicy');
@ -32,8 +33,7 @@ export async function getCopilotResponseStream(
// Check billing authorization // Check billing authorization
const authResponse = await authorizeUserAction({ const authResponse = await authorizeUserAction({
type: 'copilot_request', type: 'use_credits',
data: {},
}); });
if (!authResponse.success) { if (!authResponse.success) {
return { billingError: authResponse.error || 'Billing error' }; return { billingError: authResponse.error || 'Billing error' };
@ -75,8 +75,7 @@ export async function getCopilotAgentInstructions(
// Check billing authorization // Check billing authorization
const authResponse = await authorizeUserAction({ const authResponse = await authorizeUserAction({
type: 'copilot_request', type: 'use_credits',
data: {},
}); });
if (!authResponse.success) { if (!authResponse.success) {
return { billingError: authResponse.error || 'Billing error' }; return { billingError: authResponse.error || 'Billing error' };
@ -93,8 +92,11 @@ export async function getCopilotAgentInstructions(
} }
}; };
const usageTracker = new UsageTracker();
// call copilot api // call copilot api
const agent_instructions = await getEditAgentInstructionsResponse( const agent_instructions = await getEditAgentInstructionsResponse(
usageTracker,
projectId, projectId,
request.context, request.context,
request.messages, request.messages,
@ -104,8 +106,7 @@ export async function getCopilotAgentInstructions(
// log the billing usage // log the billing usage
if (USE_BILLING) { if (USE_BILLING) {
await logUsage({ await logUsage({
type: 'copilot_requests', items: usageTracker.flush(),
amount: 1,
}); });
} }

View file

@ -1,4 +1,4 @@
import { getCustomerIdForProject, logUsage } from "@/app/lib/billing"; import { getCustomerIdForProject, logUsage, UsageTracker } from "@/app/lib/billing";
import { USE_BILLING } from "@/app/lib/feature_flags"; import { USE_BILLING } from "@/app/lib/feature_flags";
import { redisClient } from "@/app/lib/redis"; import { redisClient } from "@/app/lib/redis";
import { CopilotAPIRequest } from "@/app/lib/types/copilot_types"; import { CopilotAPIRequest } from "@/app/lib/types/copilot_types";
@ -21,6 +21,7 @@ export async function GET(request: Request, props: { params: Promise<{ streamId:
billingCustomerId = await getCustomerIdForProject(projectId); billingCustomerId = await getCustomerIdForProject(projectId);
} }
const usageTracker = new UsageTracker();
const encoder = new TextEncoder(); const encoder = new TextEncoder();
let messageCount = 0; let messageCount = 0;
@ -29,6 +30,7 @@ export async function GET(request: Request, props: { params: Promise<{ streamId:
try { try {
// Iterate over the copilot stream generator // Iterate over the copilot stream generator
for await (const event of streamMultiAgentResponse( for await (const event of streamMultiAgentResponse(
usageTracker,
projectId, projectId,
context, context,
messages, messages,
@ -49,21 +51,20 @@ export async function GET(request: Request, props: { params: Promise<{ streamId:
} }
controller.close(); controller.close();
} catch (error) {
// increment copilot request count in billing console.error('Error processing copilot stream:', error);
controller.error(error);
} finally {
// log copilot usage
if (USE_BILLING && billingCustomerId) { if (USE_BILLING && billingCustomerId) {
try { try {
await logUsage(billingCustomerId, { await logUsage(billingCustomerId, {
type: "copilot_requests", items: usageTracker.flush(),
amount: 1,
}); });
} catch (error) { } catch (error) {
console.error("Error logging usage", error); console.error("Error logging usage", error);
} }
} }
} catch (error) {
console.error('Error processing copilot stream:', error);
controller.error(error);
} }
}, },
}); });

View file

@ -220,10 +220,10 @@ export async function POST(
// log billing usage // log billing usage
if (USE_BILLING && billingCustomerId) { if (USE_BILLING && billingCustomerId) {
const agentMessageCount = convertedResponseMessages.filter(m => m.role === 'assistant').length; const agentMessageCount = convertedResponseMessages.filter(m => m.role === 'assistant').length;
await logUsage(billingCustomerId, { // await logUsage(billingCustomerId, {
type: 'agent_messages', // type: 'agent_messages',
amount: agentMessageCount, // amount: agentMessageCount,
}); // });
} }
logger.log(`Turn processing completed successfully`); logger.log(`Turn processing completed successfully`);

View file

@ -3,7 +3,7 @@
import { Progress, Badge, Chip } from "@heroui/react"; import { Progress, Badge, Chip } from "@heroui/react";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { Label } from "@/app/lib/components/label"; import { Label } from "@/app/lib/components/label";
import { Customer, UsageResponse, UsageType } from "@/app/lib/types/billing_types"; import { Customer, UsageResponse } from "@/app/lib/types/billing_types";
import { z } from "zod"; import { z } from "zod";
import { tokens } from "@/app/styles/design-tokens"; import { tokens } from "@/app/styles/design-tokens";
import { SectionHeading } from "@/components/ui/section-heading"; import { SectionHeading } from "@/components/ui/section-heading";
@ -47,6 +47,15 @@ export function BillingPage({ customer, usage }: BillingPageProps) {
const displayStatus = getDisplayStatus(customer.subscriptionStatus); const displayStatus = getDisplayStatus(customer.subscriptionStatus);
const planInfo = planDetails[plan]; const planInfo = planDetails[plan];
// Prepare usage metrics data
const usageData = Object.entries(usage.usage)
.map(([type, credits]) => ({
type,
credits,
totalUsedCredits: usage.sanctionedCredits - usage.availableCredits
}))
.sort((a, b) => b.credits - a.credits);
async function handleManageSubscription() { async function handleManageSubscription() {
const returnUrl = new URL('/billing/callback', window.location.origin); const returnUrl = new URL('/billing/callback', window.location.origin);
returnUrl.searchParams.set('redirect', window.location.href); returnUrl.searchParams.set('redirect', window.location.href);
@ -109,48 +118,175 @@ export function BillingPage({ customer, usage }: BillingPageProps) {
</div> </div>
</section> </section>
{/* Usage Metrics Panel */} {/* Credits Overview Panel */}
<section className="card"> <section className="card">
<div className="px-4 pt-4 pb-6"> <div className="px-4 pt-4 pb-6">
<SectionHeading> <SectionHeading>
Usage Metrics Credits Overview
</SectionHeading> </SectionHeading>
</div> </div>
<HorizontalDivider /> <HorizontalDivider />
<div className="p-6 space-y-6"> <div className="p-6 space-y-6">
{Object.entries(usage.usage).map(([type, { usage: used, total }]) => { <div className="grid grid-cols-1 md:grid-cols-3 gap-6">
const usageType = type as z.infer<typeof UsageType>; <div className="space-y-2">
const percentage = Math.min((used / total) * 100, 100); <Label label="Sanctioned Credits" />
const isOverLimit = used > total; <p className={clsx(
tokens.typography.sizes.lg,
tokens.typography.weights.semibold,
tokens.colors.light.text.primary,
tokens.colors.dark.text.primary
)}>
{usage.sanctionedCredits.toLocaleString()}
</p>
<p className={clsx(
tokens.typography.sizes.sm,
tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary
)}>
Total credits allocated to your plan
</p>
</div>
<div className="space-y-2">
<Label label="Used Credits" />
<p className={clsx(
tokens.typography.sizes.lg,
tokens.typography.weights.semibold,
tokens.colors.light.text.primary,
tokens.colors.dark.text.primary
)}>
{(usage.sanctionedCredits - usage.availableCredits).toLocaleString()}
</p>
<p className={clsx(
tokens.typography.sizes.sm,
tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary
)}>
Credits consumed so far
</p>
</div>
<div className="space-y-2">
<Label label="Available Credits" />
<p className={clsx(
tokens.typography.sizes.lg,
tokens.typography.weights.semibold,
usage.availableCredits < 0 ? "text-red-500" : clsx(
tokens.colors.light.text.primary,
tokens.colors.dark.text.primary
)
)}>
{usage.availableCredits.toLocaleString()}
</p>
<p className={clsx(
tokens.typography.sizes.sm,
tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary
)}>
Credits remaining for use
</p>
</div>
</div>
return ( {/* Warning for negative credits */}
<div key={type} className="space-y-2"> {usage.availableCredits < 0 && (
<div className="flex justify-between items-center"> <div className="p-4 bg-red-50 dark:bg-red-900/20 border border-red-200 dark:border-red-800 rounded-lg">
<div className="space-y-1"> <p className={clsx(
<Label label={type.replace(/_/g, ' ')} /> tokens.typography.sizes.sm,
<p className={clsx( "text-red-700 dark:text-red-300"
)}>
You have exceeded your credit limit. Please upgrade your plan or contact support to avoid service interruptions.
</p>
</div>
)}
{/* Warning for high credit usage (>80%) */}
{usage.availableCredits >= 0 && ((usage.sanctionedCredits - usage.availableCredits) / usage.sanctionedCredits) > 0.8 && (
<div className="p-4 bg-yellow-50 dark:bg-yellow-900/20 border border-yellow-200 dark:border-yellow-800 rounded-lg">
<p className={clsx(
tokens.typography.sizes.sm,
"text-yellow-700 dark:text-yellow-300"
)}>
You have used more than 80% of your credits. Consider upgrading your plan to avoid interruptions.
</p>
</div>
)}
{/* Credits Progress Bar */}
<div className="space-y-2">
<div className="flex justify-between items-center">
<Label label="Credits Usage" />
<span className={clsx(
tokens.typography.sizes.sm,
tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary
)}>
{Math.round(((usage.sanctionedCredits - usage.availableCredits) / usage.sanctionedCredits) * 100)}%
</span>
</div>
<Progress
size="lg"
value={((usage.sanctionedCredits - usage.availableCredits) / usage.sanctionedCredits) * 100}
color={usage.availableCredits < 0 ? "danger" : "primary"}
className="h-4"
aria-label="Credits usage"
/>
</div>
</div>
</section>
{/* Usage Metrics Panel */}
<section className="card">
<div className="px-4 pt-4 pb-6">
<SectionHeading>
Usage data
</SectionHeading>
</div>
<HorizontalDivider />
<div className="p-6 space-y-6">
{usageData.length === 0 ? (
<div className="text-center py-8">
<p className={clsx(
tokens.typography.sizes.sm,
tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary
)}>
No usage data yet
</p>
</div>
) : (
usageData.map(({ type, credits, totalUsedCredits }) => {
const percentage = totalUsedCredits > 0 ? (credits / totalUsedCredits) * 100 : 0;
return (
<div key={type} className="space-y-2">
<div className="flex justify-between items-center">
<div className="space-y-1">
<Label label={type.replace(/_/g, ' ')} />
<p className={clsx(
tokens.typography.sizes.sm,
tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary
)}>
{credits.toLocaleString()} credits
</p>
</div>
<span className={clsx(
tokens.typography.sizes.sm, tokens.typography.sizes.sm,
tokens.colors.light.text.secondary, tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary tokens.colors.dark.text.secondary
)}> )}>
{used.toLocaleString()} / {total.toLocaleString()} {Math.round(percentage)}%
</p> </span>
</div> </div>
{isOverLimit && ( <Progress
<Badge color="danger" variant="flat"> value={percentage}
Over Limit color="default"
</Badge> className="h-2"
)} aria-label={`${type} credits usage`}
/>
</div> </div>
<Progress );
value={percentage} })
color={isOverLimit ? "danger" : "primary"} )}
className="h-2"
aria-label={`${type} usage`}
/>
</div>
);
})}
</div> </div>
</section> </section>
</div> </div>

View file

@ -16,6 +16,7 @@ import { qdrantClient } from '../lib/qdrant';
import { EmbeddingRecord } from "./types/datasource_types"; import { EmbeddingRecord } from "./types/datasource_types";
import { WorkflowAgent, WorkflowTool } from "./types/workflow_types"; import { WorkflowAgent, WorkflowTool } from "./types/workflow_types";
import { PrefixLogger } from "./utils"; import { PrefixLogger } from "./utils";
import { UsageTracker } from "./billing";
// Provider configuration // Provider configuration
const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || ''; const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
@ -30,6 +31,7 @@ const openai = createOpenAI({
// Helper to handle mock tool responses // Helper to handle mock tool responses
export async function invokeMockTool( export async function invokeMockTool(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
toolName: string, toolName: string,
args: string, args: string,
description: string, description: string,
@ -49,18 +51,28 @@ export async function invokeMockTool(
content: `Generate a realistic response for the tool '${toolName}' with these parameters: ${args}. The response should be concise and focused on what the tool would actually return.` content: `Generate a realistic response for the tool '${toolName}' with these parameters: ${args}. The response should be concise and focused on what the tool would actually return.`
}]; }];
const { text } = await generateText({ const { text, usage } = await generateText({
model: openai(MODEL), model: openai(MODEL),
messages, messages,
}); });
logger.log(`generated text: ${text}`); logger.log(`generated text: ${text}`);
// track usage
usageTracker.track({
type: "LLM_USAGE",
modelName: MODEL,
inputTokens: usage.promptTokens,
outputTokens: usage.completionTokens,
context: "agents_runtime.mock_tool",
});
return text; return text;
} }
// Helper to handle RAG tool calls // Helper to handle RAG tool calls
export async function invokeRagTool( export async function invokeRagTool(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string, projectId: string,
query: string, query: string,
sourceIds: string[], sourceIds: string[],
@ -81,11 +93,21 @@ export async function invokeRagTool(
logger.log(`k: ${k}`); logger.log(`k: ${k}`);
// Create embedding for question // Create embedding for question
const { embedding } = await embed({ const { embedding, usage } = await embed({
model: embeddingModel, model: embeddingModel,
value: query, value: query,
}); });
// track usage
// track usage
usageTracker.track({
type: "EMBEDDING_MODEL_USAGE",
modelName: embeddingModel.modelId,
tokens: usage.tokens,
context: "agents_runtime.rag_tool.embedding_usage",
});
// Fetch all data sources for this project // Fetch all data sources for this project
const sources = await dataSourcesCollection.find({ const sources = await dataSourcesCollection.find({
projectId: projectId, projectId: projectId,
@ -154,6 +176,7 @@ export async function invokeRagTool(
export async function invokeWebhookTool( export async function invokeWebhookTool(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string, projectId: string,
name: string, name: string,
input: any, input: any,
@ -233,6 +256,7 @@ export async function invokeWebhookTool(
// Helper to handle MCP tool calls // Helper to handle MCP tool calls
export async function invokeMcpTool( export async function invokeMcpTool(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string, projectId: string,
name: string, name: string,
input: any, input: any,
@ -269,6 +293,7 @@ export async function invokeMcpTool(
// Helper to handle composio tool calls // Helper to handle composio tool calls
export async function invokeComposioTool( export async function invokeComposioTool(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string, projectId: string,
name: string, name: string,
composioData: z.infer<typeof WorkflowTool>['composioData'] & {}, composioData: z.infer<typeof WorkflowTool>['composioData'] & {},
@ -299,12 +324,21 @@ export async function invokeComposioTool(
connectedAccountId: connectedAccountId, connectedAccountId: connectedAccountId,
}); });
logger.log(`composio tool result: ${JSON.stringify(result)}`); logger.log(`composio tool result: ${JSON.stringify(result)}`);
// track usage
usageTracker.track({
type: "COMPOSIO_TOOL_USAGE",
toolSlug: slug,
context: "agents_runtime.composio_tool",
});
return result.data; return result.data;
} }
// Helper to create RAG tool // Helper to create RAG tool
export function createRagTool( export function createRagTool(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
config: z.infer<typeof WorkflowAgent>, config: z.infer<typeof WorkflowAgent>,
projectId: string projectId: string
): Tool { ): Tool {
@ -321,6 +355,7 @@ export function createRagTool(
async execute(input: { query: string }) { async execute(input: { query: string }) {
const results = await invokeRagTool( const results = await invokeRagTool(
logger, logger,
usageTracker,
projectId, projectId,
input.query, input.query,
config.ragDataSources || [], config.ragDataSources || [],
@ -337,6 +372,7 @@ export function createRagTool(
// Helper to create a mock tool // Helper to create a mock tool
export function createMockTool( export function createMockTool(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
config: z.infer<typeof WorkflowTool>, config: z.infer<typeof WorkflowTool>,
): Tool { ): Tool {
return tool({ return tool({
@ -353,6 +389,7 @@ export function createMockTool(
try { try {
const result = await invokeMockTool( const result = await invokeMockTool(
logger, logger,
usageTracker,
config.name, config.name,
JSON.stringify(input), JSON.stringify(input),
config.description, config.description,
@ -374,6 +411,7 @@ export function createMockTool(
// Helper to create a webhook tool // Helper to create a webhook tool
export function createWebhookTool( export function createWebhookTool(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
config: z.infer<typeof WorkflowTool>, config: z.infer<typeof WorkflowTool>,
projectId: string, projectId: string,
): Tool { ): Tool {
@ -391,7 +429,7 @@ export function createWebhookTool(
}, },
async execute(input: any) { async execute(input: any) {
try { try {
const result = await invokeWebhookTool(logger, projectId, name, input); const result = await invokeWebhookTool(logger, usageTracker, projectId, name, input);
return JSON.stringify({ return JSON.stringify({
result, result,
}); });
@ -408,6 +446,7 @@ export function createWebhookTool(
// Helper to create an mcp tool // Helper to create an mcp tool
export function createMcpTool( export function createMcpTool(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
config: z.infer<typeof WorkflowTool>, config: z.infer<typeof WorkflowTool>,
projectId: string projectId: string
): Tool { ): Tool {
@ -425,7 +464,7 @@ export function createMcpTool(
}, },
async execute(input: any) { async execute(input: any) {
try { try {
const result = await invokeMcpTool(logger, projectId, name, input, mcpServerName || ''); const result = await invokeMcpTool(logger, usageTracker, projectId, name, input, mcpServerName || '');
return JSON.stringify({ return JSON.stringify({
result, result,
}); });
@ -442,6 +481,7 @@ export function createMcpTool(
// Helper to create a composio tool // Helper to create a composio tool
export function createComposioTool( export function createComposioTool(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
config: z.infer<typeof WorkflowTool>, config: z.infer<typeof WorkflowTool>,
projectId: string projectId: string
): Tool { ): Tool {
@ -463,7 +503,7 @@ export function createComposioTool(
}, },
async execute(input: any) { async execute(input: any) {
try { try {
const result = await invokeComposioTool(logger, projectId, name, composioData, input); const result = await invokeComposioTool(logger, usageTracker, projectId, name, composioData, input);
return JSON.stringify({ return JSON.stringify({
result, result,
}); });
@ -479,6 +519,7 @@ export function createComposioTool(
export function createTools( export function createTools(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string, projectId: string,
workflow: { tools: z.infer<typeof WorkflowTool>[] }, workflow: { tools: z.infer<typeof WorkflowTool>[] },
toolConfig: Record<string, z.infer<typeof WorkflowTool>>, toolConfig: Record<string, z.infer<typeof WorkflowTool>>,
@ -492,16 +533,16 @@ export function createTools(
toolLogger.log(`creating tool: ${toolName} (type: ${config.mockTool ? 'mock' : config.isMcp ? 'mcp' : config.isComposio ? 'composio' : 'webhook'})`); toolLogger.log(`creating tool: ${toolName} (type: ${config.mockTool ? 'mock' : config.isMcp ? 'mcp' : config.isComposio ? 'composio' : 'webhook'})`);
if (config.mockTool) { if (config.mockTool) {
tools[toolName] = createMockTool(logger, config); tools[toolName] = createMockTool(logger, usageTracker, config);
toolLogger.log(`✓ created mock tool: ${toolName}`); toolLogger.log(`✓ created mock tool: ${toolName}`);
} else if (config.isMcp) { } else if (config.isMcp) {
tools[toolName] = createMcpTool(logger, config, projectId); tools[toolName] = createMcpTool(logger, usageTracker, config, projectId);
toolLogger.log(`✓ created mcp tool: ${toolName} (server: ${config.mcpServerName || 'unknown'})`); toolLogger.log(`✓ created mcp tool: ${toolName} (server: ${config.mcpServerName || 'unknown'})`);
} else if (config.isComposio) { } else if (config.isComposio) {
tools[toolName] = createComposioTool(logger, config, projectId); tools[toolName] = createComposioTool(logger, usageTracker, config, projectId);
toolLogger.log(`✓ created composio tool: ${toolName}`); toolLogger.log(`✓ created composio tool: ${toolName}`);
} else { } else {
tools[toolName] = createWebhookTool(logger, config, projectId); tools[toolName] = createWebhookTool(logger, usageTracker, config, projectId);
toolLogger.log(`✓ created webhook tool: ${toolName} (fallback)`); toolLogger.log(`✓ created webhook tool: ${toolName} (fallback)`);
} }
} }

View file

@ -12,6 +12,8 @@ import { ConnectedEntity, sanitizeTextWithMentions, Workflow, WorkflowAgent, Wor
import { CHILD_TRANSFER_RELATED_INSTRUCTIONS, CONVERSATION_TYPE_INSTRUCTIONS, PIPELINE_TYPE_INSTRUCTIONS, RAG_INSTRUCTIONS, TASK_TYPE_INSTRUCTIONS } from "./agent_instructions"; import { CHILD_TRANSFER_RELATED_INSTRUCTIONS, CONVERSATION_TYPE_INSTRUCTIONS, PIPELINE_TYPE_INSTRUCTIONS, RAG_INSTRUCTIONS, TASK_TYPE_INSTRUCTIONS } from "./agent_instructions";
import { PrefixLogger } from "./utils"; import { PrefixLogger } from "./utils";
import { Message, AssistantMessage, AssistantMessageWithToolCalls, ToolMessage } from "./types/types"; import { Message, AssistantMessage, AssistantMessageWithToolCalls, ToolMessage } from "./types/types";
import { UsageTracker } from "./billing";
// Native handoff support // Native handoff support
import { createAgentHandoff, getSchemaForAgent, createContextFilterForAgent } from "./agent-handoffs"; import { createAgentHandoff, getSchemaForAgent, createContextFilterForAgent } from "./agent-handoffs";
import { PipelineStateManager } from "./pipeline-state-manager"; import { PipelineStateManager } from "./pipeline-state-manager";
@ -78,14 +80,6 @@ const openai = createOpenAI({
baseURL: PROVIDER_BASE_URL, baseURL: PROVIDER_BASE_URL,
}); });
const ZUsage = z.object({
tokens: z.object({
total: z.number(),
prompt: z.number(),
completion: z.number(),
}),
});
const ZOutMessage = z.union([ const ZOutMessage = z.union([
AssistantMessage, AssistantMessage,
AssistantMessageWithToolCalls, AssistantMessageWithToolCalls,
@ -95,6 +89,7 @@ const ZOutMessage = z.union([
// Helper to create an agent // Helper to create an agent
function createAgent( function createAgent(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string, projectId: string,
config: z.infer<typeof WorkflowAgent>, config: z.infer<typeof WorkflowAgent>,
tools: Record<string, Tool>, tools: Record<string, Tool>,
@ -145,7 +140,7 @@ ${CHILD_TRANSFER_RELATED_INSTRUCTIONS}
// Add RAG tool if needed // Add RAG tool if needed
if (config.ragDataSources?.length) { if (config.ragDataSources?.length) {
const ragTool = createRagTool(logger, config, projectId); const ragTool = createRagTool(logger, usageTracker, config, projectId);
agentTools.push(ragTool); agentTools.push(ragTool);
// update instructions to include RAG instructions // update instructions to include RAG instructions
@ -269,8 +264,8 @@ function getStartOfTurnAgentName(
// Logs an event and then yields it // Logs an event and then yields it
async function* emitEvent( async function* emitEvent(
logger: PrefixLogger, logger: PrefixLogger,
event: z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>, event: z.infer<typeof ZOutMessage>,
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> { ): AsyncIterable<z.infer<typeof ZOutMessage>> {
logger.log(`-> emitting event: ${JSON.stringify(event)}`); logger.log(`-> emitting event: ${JSON.stringify(event)}`);
yield event; yield event;
return; return;
@ -321,30 +316,6 @@ class AgentTransferCounter {
} }
} }
class UsageTracker {
private usage: {
total: number;
prompt: number;
completion: number;
} = { total: 0, prompt: 0, completion: 0 };
increment(total: number, prompt: number, completion: number): void {
this.usage.total += total;
this.usage.prompt += prompt;
this.usage.completion += completion;
}
get(): { total: number, prompt: number, completion: number } {
return this.usage;
}
asEvent(): z.infer<typeof ZUsage> {
return {
tokens: this.usage,
};
}
}
function ensureSystemMessage(logger: PrefixLogger, messages: z.infer<typeof Message>[]) { function ensureSystemMessage(logger: PrefixLogger, messages: z.infer<typeof Message>[]) {
logger = logger.child(`ensureSystemMessage`); logger = logger.child(`ensureSystemMessage`);
@ -396,7 +367,7 @@ function mapConfig(workflow: z.infer<typeof Workflow>): {
return { agentConfig, toolConfig, promptConfig, pipelineConfig }; return { agentConfig, toolConfig, promptConfig, pipelineConfig };
} }
async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer<typeof Workflow>): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> { async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer<typeof Workflow>): AsyncIterable<z.infer<typeof ZOutMessage>> {
// find the greeting prompt // find the greeting prompt
const prompt = workflow.prompts.find(p => p.type === 'greeting')?.prompt || 'How can I help you today?'; const prompt = workflow.prompts.find(p => p.type === 'greeting')?.prompt || 'How can I help you today?';
logger.log(`greeting turn: ${prompt}`); logger.log(`greeting turn: ${prompt}`);
@ -408,15 +379,13 @@ async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer<typeof
agentName: workflow.startAgent, agentName: workflow.startAgent,
responseType: 'external', responseType: 'external',
}); });
// emit final usage information
yield* emitEvent(logger, new UsageTracker().asEvent());
} }
// Enhanced agent creation with native handoff support // Enhanced agent creation with native handoff support
function createAgentsWithNativeHandoffs( function createAgentsWithNativeHandoffs(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string, projectId: string,
workflow: z.infer<typeof Workflow>, workflow: z.infer<typeof Workflow>,
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>, agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
@ -447,6 +416,7 @@ function createAgentsWithNativeHandoffs(
const { agent, entities } = createAgent( const { agent, entities } = createAgent(
logger, logger,
usageTracker,
projectId, projectId,
config, config,
tools, tools,
@ -560,6 +530,7 @@ function createAgentsWithNativeHandoffs(
// Legacy agent creation (existing implementation) // Legacy agent creation (existing implementation)
function createAgentsLegacy( function createAgentsLegacy(
logger: PrefixLogger, logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string, projectId: string,
workflow: z.infer<typeof Workflow>, workflow: z.infer<typeof Workflow>,
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>, agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
@ -595,6 +566,7 @@ function createAgentsLegacy(
const { agent, entities } = createAgent( const { agent, entities } = createAgent(
logger, logger,
usageTracker,
projectId, projectId,
config, config,
tools, tools,
@ -763,12 +735,13 @@ function maybeInjectGiveUpControlInstructions(
// Handle raw model stream events // Handle raw model stream events
async function* handleRawModelStreamEvent( async function* handleRawModelStreamEvent(
event: any, event: any,
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
agentName: string, agentName: string,
turnMsgs: z.infer<typeof Message>[], turnMsgs: z.infer<typeof Message>[],
usageTracker: UsageTracker, usageTracker: UsageTracker,
eventLogger: PrefixLogger, eventLogger: PrefixLogger,
getAgentState?: (agentName: string) => AgentState getAgentState?: (agentName: string) => AgentState
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> { ): AsyncIterable<z.infer<typeof ZOutMessage>> {
if (event.data.type === 'response_done') { if (event.data.type === 'response_done') {
// Count tool calls (excluding transfer_to_* calls) // Count tool calls (excluding transfer_to_* calls)
const toolCallCount = event.data.response.output.filter( const toolCallCount = event.data.response.output.filter(
@ -809,12 +782,13 @@ async function* handleRawModelStreamEvent(
} }
// update usage information // update usage information
usageTracker.increment( usageTracker.track({
event.data.response.usage.totalTokens, type: "LLM_USAGE",
event.data.response.usage.inputTokens, modelName: agentConfig[agentName]?.model || "unknown",
event.data.response.usage.outputTokens inputTokens: event.data.response.usage.inputTokens,
); outputTokens: event.data.response.usage.outputTokens,
eventLogger.log(`updated usage information: ${JSON.stringify(usageTracker.get())}`); context: "agents_runtime.llm_usage",
});
} }
} }
@ -833,7 +807,7 @@ async function* handleNativeHandoffEvent(
originalHandoffs: Record<string, any[]>, originalHandoffs: Record<string, any[]>,
eventLogger: PrefixLogger, eventLogger: PrefixLogger,
loopLogger: PrefixLogger loopLogger: PrefixLogger
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage> | { newAgentName: string; shouldContinue?: boolean }> { ): AsyncIterable<z.infer<typeof ZOutMessage> | { newAgentName: string; shouldContinue?: boolean }> {
eventLogger.log(`🔄 NATIVE HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`); eventLogger.log(`🔄 NATIVE HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`);
// skip if its the same agent // skip if its the same agent
@ -943,7 +917,7 @@ async function* handleHandoffEvent(
originalHandoffs: Record<string, Agent[]>, originalHandoffs: Record<string, Agent[]>,
eventLogger: PrefixLogger, eventLogger: PrefixLogger,
loopLogger: PrefixLogger loopLogger: PrefixLogger
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage> | { newAgentName: string }> { ): AsyncIterable<z.infer<typeof ZOutMessage> | { newAgentName: string }> {
eventLogger.log(`🔄 HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`); eventLogger.log(`🔄 HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`);
// skip if its the same agent // skip if its the same agent
@ -1010,7 +984,7 @@ async function* handleToolCallResult(
event: any, event: any,
turnMsgs: z.infer<typeof Message>[], turnMsgs: z.infer<typeof Message>[],
eventLogger: PrefixLogger eventLogger: PrefixLogger
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> { ): AsyncIterable<z.infer<typeof ZOutMessage>> {
const m: z.infer<typeof Message> = { const m: z.infer<typeof Message> = {
role: 'tool', role: 'tool',
content: event.item.rawItem.output.text, content: event.item.rawItem.output.text,
@ -1039,7 +1013,7 @@ async function* handleMessageOutput(
eventLogger: PrefixLogger, eventLogger: PrefixLogger,
loopLogger: PrefixLogger, loopLogger: PrefixLogger,
getAgentState: (agentName: string) => AgentState getAgentState: (agentName: string) => AgentState
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage> | { newAgentName: string | null; shouldContinue: boolean }> { ): AsyncIterable<z.infer<typeof ZOutMessage> | { newAgentName: string | null; shouldContinue: boolean }> {
// check response visibility - could be an agent or pipeline // check response visibility - could be an agent or pipeline
const agentConfigObj = agentConfig[agentName]; const agentConfigObj = agentConfig[agentName];
const pipelineConfigObj = pipelineConfig[agentName]; const pipelineConfigObj = pipelineConfig[agentName];
@ -1243,7 +1217,8 @@ export async function* streamResponse(
projectId: string, projectId: string,
workflow: z.infer<typeof Workflow>, workflow: z.infer<typeof Workflow>,
messages: z.infer<typeof Message>[], messages: z.infer<typeof Message>[],
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> { usageTracker: UsageTracker,
): AsyncIterable<z.infer<typeof ZOutMessage>> {
// Divider log for tracking agent loop start // Divider log for tracking agent loop start
console.log('-------------------- AGENT LOOP START --------------------'); console.log('-------------------- AGENT LOOP START --------------------');
// set up logging // set up logging
@ -1275,11 +1250,11 @@ export async function* streamResponse(
logger.log(`initialized stack: ${JSON.stringify(stack)}`); logger.log(`initialized stack: ${JSON.stringify(stack)}`);
// create tools // create tools
const tools = createTools(logger, projectId, workflow, toolConfig); const tools = createTools(logger, usageTracker, projectId, workflow, toolConfig);
// create agents with feature flag support // create agents with feature flag support
const createAgentsFunction = USE_NATIVE_HANDOFFS ? createAgentsWithNativeHandoffs : createAgentsLegacy; const createAgentsFunction = USE_NATIVE_HANDOFFS ? createAgentsWithNativeHandoffs : createAgentsLegacy;
const { agents, originalInstructions, originalHandoffs } = createAgentsFunction(logger, projectId, workflow, agentConfig, tools, promptConfig, pipelineConfig); const { agents, originalInstructions, originalHandoffs } = createAgentsFunction(logger, usageTracker, projectId, workflow, agentConfig, tools, promptConfig, pipelineConfig);
logger.log(`Using ${USE_NATIVE_HANDOFFS ? 'NATIVE SDK' : 'LEGACY'} handoffs`); logger.log(`Using ${USE_NATIVE_HANDOFFS ? 'NATIVE SDK' : 'LEGACY'} handoffs`);
@ -1296,7 +1271,6 @@ export async function* streamResponse(
let agentName: string | null = startOfTurnAgentName; let agentName: string | null = startOfTurnAgentName;
// start the turn loop // start the turn loop
const usageTracker = new UsageTracker();
const turnMsgs: z.infer<typeof Message>[] = [...messages]; const turnMsgs: z.infer<typeof Message>[] = [...messages];
// Initialize agent state tracking for tool call completion // Initialize agent state tracking for tool call completion
@ -1390,7 +1364,7 @@ export async function* streamResponse(
switch (event.type) { switch (event.type) {
case 'raw_model_stream_event': case 'raw_model_stream_event':
yield* handleRawModelStreamEvent(event, agentName!, turnMsgs, usageTracker, eventLogger, getAgentState); yield* handleRawModelStreamEvent(event, agentConfig, agentName!, turnMsgs, usageTracker, eventLogger, getAgentState);
break; break;
case 'run_item_stream_event': case 'run_item_stream_event':
@ -1523,9 +1497,6 @@ export async function* streamResponse(
} }
} }
// emit usage information
yield* emitEvent(logger, usageTracker.asEvent());
} }
// this is a sync version of streamResponse // this is a sync version of streamResponse
@ -1535,8 +1506,10 @@ export async function getResponse(
messages: z.infer<typeof Message>[], messages: z.infer<typeof Message>[],
): Promise<{ ): Promise<{
messages: z.infer<typeof ZOutMessage>[], messages: z.infer<typeof ZOutMessage>[],
usage: z.infer<typeof ZUsage>, usage: any,
}> { }> {
throw new Error("Not implemented!");
/*
const out: z.infer<typeof ZOutMessage>[] = []; const out: z.infer<typeof ZOutMessage>[] = [];
let usage: z.infer<typeof ZUsage> = { let usage: z.infer<typeof ZUsage> = {
tokens: { tokens: {
@ -1554,4 +1527,5 @@ export async function getResponse(
} }
} }
return { messages: out, usage }; return { messages: out, usage };
*/
} }

View file

@ -1,6 +1,6 @@
import { WithStringId } from './types/types'; import { WithStringId } from './types/types';
import { z } from 'zod'; import { z } from 'zod';
import { Customer, AuthorizeRequest, AuthorizeResponse, LogUsageRequest, UsageResponse, CustomerPortalSessionResponse, PricesResponse, UpdateSubscriptionPlanRequest, UpdateSubscriptionPlanResponse, ModelsResponse } from './types/billing_types'; import { Customer, AuthorizeRequest, AuthorizeResponse, LogUsageRequest, UsageResponse, CustomerPortalSessionResponse, PricesResponse, UpdateSubscriptionPlanRequest, UpdateSubscriptionPlanResponse, ModelsResponse, UsageItem } from './types/billing_types';
import { ObjectId } from 'mongodb'; import { ObjectId } from 'mongodb';
import { projectsCollection, usersCollection } from './mongodb'; import { projectsCollection, usersCollection } from './mongodb';
import { redirect } from 'next/navigation'; import { redirect } from 'next/navigation';
@ -23,6 +23,20 @@ const GUEST_BILLING_CUSTOMER = {
updatedAt: new Date().toISOString(), updatedAt: new Date().toISOString(),
}; };
export class UsageTracker{
private items: z.infer<typeof UsageItem>[] = [];
track(item: z.infer<typeof UsageItem>) {
this.items.push(item);
}
flush(): z.infer<typeof UsageItem>[] {
const items = this.items;
this.items = [];
return items;
}
}
export async function getCustomerIdForProject(projectId: string): Promise<string> { export async function getCustomerIdForProject(projectId: string): Promise<string> {
const project = await projectsCollection.findOne({ _id: projectId }); const project = await projectsCollection.findOne({ _id: projectId });
if (!project) { if (!project) {
@ -111,6 +125,7 @@ export async function authorize(customerId: string, request: z.infer<typeof Auth
} }
export async function logUsage(customerId: string, request: z.infer<typeof LogUsageRequest>) { export async function logUsage(customerId: string, request: z.infer<typeof LogUsageRequest>) {
console.log(`logging billing usage for customer ${customerId}`, JSON.stringify(request));
const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/log-usage`, { const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/log-usage`, {
method: 'POST', method: 'POST',
headers: { headers: {

View file

@ -13,6 +13,7 @@ import { COPILOT_MULTI_AGENT_EXAMPLE_1 } from "./example_multi_agent_1";
import { CURRENT_WORKFLOW_PROMPT } from "./current_workflow"; import { CURRENT_WORKFLOW_PROMPT } from "./current_workflow";
import { USE_COMPOSIO_TOOLS } from "../feature_flags"; import { USE_COMPOSIO_TOOLS } from "../feature_flags";
import { composio, getTool } from "../composio/composio"; import { composio, getTool } from "../composio/composio";
import { UsageTracker } from "../billing";
const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || ''; const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
const PROVIDER_BASE_URL = process.env.PROVIDER_BASE_URL || undefined; const PROVIDER_BASE_URL = process.env.PROVIDER_BASE_URL || undefined;
@ -119,7 +120,7 @@ ${JSON.stringify(simplifiedDataSources)}
return prompt; return prompt;
} }
async function searchRelevantTools(query: string): Promise<string> { async function searchRelevantTools(usageTracker: UsageTracker, query: string): Promise<string> {
const logger = new PrefixLogger("copilot-search-tools"); const logger = new PrefixLogger("copilot-search-tools");
console.log("🔧 TOOL CALL: searchRelevantTools", { query }); console.log("🔧 TOOL CALL: searchRelevantTools", { query });
@ -142,6 +143,13 @@ async function searchRelevantTools(query: string): Promise<string> {
return 'No tools found!'; return 'No tools found!';
} }
// track composio search tool usage
usageTracker.track({
type: "COMPOSIO_TOOL_USAGE",
toolSlug: "COMPOSIO_SEARCH_TOOLS",
context: "copilot.search_relevant_tools",
});
// parse results // parse results
const result = composioToolSearchResponseSchema.safeParse(searchResult.data); const result = composioToolSearchResponseSchema.safeParse(searchResult.data);
if (!result.success) { if (!result.success) {
@ -208,6 +216,7 @@ function updateLastUserMessage(
} }
export async function getEditAgentInstructionsResponse( export async function getEditAgentInstructionsResponse(
usageTracker: UsageTracker,
projectId: string, projectId: string,
context: z.infer<typeof CopilotChatContext> | null, context: z.infer<typeof CopilotChatContext> | null,
messages: z.infer<typeof CopilotMessage>[], messages: z.infer<typeof CopilotMessage>[],
@ -232,7 +241,7 @@ export async function getEditAgentInstructionsResponse(
system: COPILOT_INSTRUCTIONS_EDIT_AGENT, system: COPILOT_INSTRUCTIONS_EDIT_AGENT,
messages: messages, messages: messages,
})); }));
const { object } = await generateObject({ const { object, usage } = await generateObject({
model: openai(COPILOT_MODEL), model: openai(COPILOT_MODEL),
messages: [ messages: [
{ {
@ -246,10 +255,20 @@ export async function getEditAgentInstructionsResponse(
}), }),
}); });
// log usage
usageTracker.track({
type: "LLM_USAGE",
modelName: COPILOT_MODEL,
inputTokens: usage.promptTokens,
outputTokens: usage.completionTokens,
context: "copilot.llm_usage",
});
return object.agent_instructions; return object.agent_instructions;
} }
export async function* streamMultiAgentResponse( export async function* streamMultiAgentResponse(
usageTracker: UsageTracker,
projectId: string, projectId: string,
context: z.infer<typeof CopilotChatContext> | null, context: z.infer<typeof CopilotChatContext> | null,
messages: z.infer<typeof CopilotMessage>[], messages: z.infer<typeof CopilotMessage>[],
@ -297,7 +316,7 @@ export async function* streamMultiAgentResponse(
}), }),
execute: async ({ query }: { query: string }) => { execute: async ({ query }: { query: string }) => {
console.log("🎯 AI TOOL CALL: search_relevant_tools", { query }); console.log("🎯 AI TOOL CALL: search_relevant_tools", { query });
const result = await searchRelevantTools(query); const result = await searchRelevantTools(usageTracker, query);
console.log("✅ AI TOOL CALL COMPLETED: search_relevant_tools", { console.log("✅ AI TOOL CALL COMPLETED: search_relevant_tools", {
query, query,
resultLength: result.length resultLength: result.length
@ -341,6 +360,15 @@ export async function* streamMultiAgentResponse(
toolCallId: event.toolCallId, toolCallId: event.toolCallId,
result: event.result, result: event.result,
}; };
} else if (event.type === "step-finish") {
// log usage
usageTracker.track({
type: "LLM_USAGE",
modelName: COPILOT_MODEL,
inputTokens: event.usage.promptTokens,
outputTokens: event.usage.completionTokens,
context: "copilot.llm_usage",
});
} }
} }

View file

@ -2,12 +2,52 @@ import { z } from "zod";
export const SubscriptionPlan = z.enum(["free", "starter", "pro"]); export const SubscriptionPlan = z.enum(["free", "starter", "pro"]);
export const UsageType = z.enum([ export const UsageTypeKey = z.enum([
"copilot_requests", "LLM_USAGE",
"agent_messages", "EMBEDDING_MODEL_USAGE",
"rag_tokens", "COMPOSIO_TOOL_USAGE",
"FIRECRAWL_SCRAPE_USAGE",
]); ]);
export const LLMUsage = z.object({
type: z.literal(UsageTypeKey.Enum.LLM_USAGE),
modelName: z.string(),
inputTokens: z.number().positive(),
outputTokens: z.number().positive(),
context: z.string(),
});
export const EmbeddingModelUsage = z.object({
type: z.literal(UsageTypeKey.Enum.EMBEDDING_MODEL_USAGE),
modelName: z.string(),
tokens: z.number().positive(),
context: z.string(),
});
export const ComposioToolUsage = z.object({
type: z.literal(UsageTypeKey.Enum.COMPOSIO_TOOL_USAGE),
toolSlug: z.string(),
context: z.string(),
});
export const FirecrawlScrapeUsage = z.object({
type: z.literal(UsageTypeKey.Enum.FIRECRAWL_SCRAPE_USAGE),
context: z.string(),
});
export const UsageItem = z.discriminatedUnion("type", [
LLMUsage,
EmbeddingModelUsage,
ComposioToolUsage,
FirecrawlScrapeUsage,
]);
export const LogUsageRequest = z.object({
items: z.array(UsageItem),
});
export const CustomerUsageData = z.record(z.string(), z.number());
export const Customer = z.object({ export const Customer = z.object({
_id: z.string(), _id: z.string(),
userId: z.string(), userId: z.string(),
@ -19,36 +59,23 @@ export const Customer = z.object({
createdAt: z.string().datetime(), createdAt: z.string().datetime(),
updatedAt: z.string().datetime(), updatedAt: z.string().datetime(),
subscriptionPlanUpdatedAt: z.string().datetime().optional(), subscriptionPlanUpdatedAt: z.string().datetime().optional(),
usage: z.record(UsageType, z.number()).optional(), usage: CustomerUsageData.optional(),
usageUpdatedAt: z.string().datetime().optional(), usageUpdatedAt: z.string().datetime().optional(),
}); creditsOverride: z.number().optional(),
maxProjectsOverride: z.number().optional(),
export const LogUsageRequest = z.object({ agentModelsOverride: z.array(z.string()).optional(),
type: UsageType, });
amount: z.number().int().positive(),
});
export const AuthorizeRequest = z.discriminatedUnion("type", [ export const AuthorizeRequest = z.discriminatedUnion("type", [
z.object({
"type": z.literal("use_credits"),
}),
z.object({ z.object({
"type": z.literal("create_project"), "type": z.literal("create_project"),
"data": z.object({ "data": z.object({
"existingProjectCount": z.number(), "existingProjectCount": z.number(),
}), }),
}), }),
z.object({
"type": z.literal("enable_hosted_tool_server"),
"data": z.object({
"existingServerCount": z.number(),
}),
}),
z.object({
"type": z.literal("process_rag"),
"data": z.object({}),
}),
z.object({
"type": z.literal("copilot_request"),
"data": z.object({}),
}),
z.object({ z.object({
"type": z.literal("agent_response"), "type": z.literal("agent_response"),
"data": z.object({ "data": z.object({
@ -63,10 +90,9 @@ export const AuthorizeResponse = z.object({
}); });
export const UsageResponse = z.object({ export const UsageResponse = z.object({
usage: z.record(UsageType, z.object({ sanctionedCredits: z.number(),
usage: z.number(), availableCredits: z.number(),
total: z.number(), usage: CustomerUsageData,
})),
}); });
export const CustomerPortalSessionRequest = z.object({ export const CustomerPortalSessionRequest = z.object({

View file

@ -16,7 +16,7 @@ import crypto from 'crypto';
import path from 'path'; import path from 'path';
import { createOpenAI } from '@ai-sdk/openai'; import { createOpenAI } from '@ai-sdk/openai';
import { USE_BILLING, USE_GEMINI_FILE_PARSING } from '../lib/feature_flags'; import { USE_BILLING, USE_GEMINI_FILE_PARSING } from '../lib/feature_flags';
import { authorize, getCustomerIdForProject, logUsage } from '../lib/billing'; import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from '../lib/billing';
import { BillingError } from '@/src/entities/errors/common'; import { BillingError } from '@/src/entities/errors/common';
const FILE_PARSING_PROVIDER_API_KEY = process.env.FILE_PARSING_PROVIDER_API_KEY || process.env.OPENAI_API_KEY || ''; const FILE_PARSING_PROVIDER_API_KEY = process.env.FILE_PARSING_PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
@ -75,7 +75,7 @@ async function retryable<T>(fn: () => Promise<T>, maxAttempts: number = 3): Prom
} }
} }
async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } }): Promise<number> { async function runProcessPipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } }) {
const logger = _logger const logger = _logger
.child(doc._id.toString()) .child(doc._id.toString())
.child(doc.name); .child(doc.name);
@ -95,7 +95,7 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
if (!USE_GEMINI_FILE_PARSING) { if (!USE_GEMINI_FILE_PARSING) {
// Use OpenAI to extract text content // Use OpenAI to extract text content
logger.log("Extracting content using OpenAI"); logger.log("Extracting content using OpenAI");
const { text } = await generateText({ const { text, usage } = await generateText({
model: openai(FILE_PARSING_MODEL), model: openai(FILE_PARSING_MODEL),
system: extractPrompt, system: extractPrompt,
messages: [ messages: [
@ -112,6 +112,13 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
], ],
}); });
markdown = text; markdown = text;
usageTracker.track({
type: "LLM_USAGE",
modelName: FILE_PARSING_MODEL,
inputTokens: usage.promptTokens,
outputTokens: usage.completionTokens,
context: "rag.files.llm_usage",
});
} else { } else {
// Use Gemini to extract text content // Use Gemini to extract text content
logger.log("Extracting content using Gemini"); logger.log("Extracting content using Gemini");
@ -127,6 +134,13 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
extractPrompt, extractPrompt,
]); ]);
markdown = result.response.text(); markdown = result.response.text();
usageTracker.track({
type: "LLM_USAGE",
modelName: FILE_PARSING_MODEL,
inputTokens: result.response.usageMetadata?.promptTokenCount || 0,
outputTokens: result.response.usageMetadata?.candidatesTokenCount || 0,
context: "rag.files.llm_usage",
});
} }
// split into chunks // split into chunks
@ -139,6 +153,12 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
model: embeddingModel, model: embeddingModel,
values: splits.map((split) => split.pageContent) values: splits.map((split) => split.pageContent)
}); });
usageTracker.track({
type: "EMBEDDING_MODEL_USAGE",
modelName: embeddingModel.modelId,
tokens: usage.tokens,
context: "rag.files.embedding_usage",
});
// store embeddings in qdrant // store embeddings in qdrant
logger.log("Storing embeddings in Qdrant"); logger.log("Storing embeddings in Qdrant");
@ -170,8 +190,6 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
lastUpdatedAt: new Date().toISOString(), lastUpdatedAt: new Date().toISOString(),
} }
}); });
return usage.tokens;
} }
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> { async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
@ -339,8 +357,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
// authorize with billing // authorize with billing
if (USE_BILLING && billingCustomerId) { if (USE_BILLING && billingCustomerId) {
const authResponse = await authorize(billingCustomerId, { const authResponse = await authorize(billingCustomerId, {
type: "process_rag", type: "use_credits",
data: {},
}); });
if ('error' in authResponse) { if ('error' in authResponse) {
@ -349,16 +366,9 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
} }
const ldoc = doc as WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } }; const ldoc = doc as WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } };
const usageTracker = new UsageTracker();
try { try {
const usedTokens = await runProcessPipeline(logger, job, ldoc); await runProcessPipeline(logger, usageTracker, job, ldoc);
// log usage in billing
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
type: "rag_tokens",
amount: usedTokens,
});
}
} catch (e: any) { } catch (e: any) {
errors = true; errors = true;
logger.log("Error processing doc:", e); logger.log("Error processing doc:", e);
@ -371,6 +381,13 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
error: e.message, error: e.message,
} }
}); });
} finally {
// log usage in billing
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
items: usageTracker.flush(),
});
}
} }
} }

View file

@ -10,7 +10,7 @@ import { qdrantClient } from '../lib/qdrant';
import { PrefixLogger } from "../lib/utils"; import { PrefixLogger } from "../lib/utils";
import crypto from 'crypto'; import crypto from 'crypto';
import { USE_BILLING } from '../lib/feature_flags'; import { USE_BILLING } from '../lib/feature_flags';
import { authorize, getCustomerIdForProject, logUsage } from '../lib/billing'; import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from '../lib/billing';
import { BillingError } from '@/src/entities/errors/common'; import { BillingError } from '@/src/entities/errors/common';
const splitter = new RecursiveCharacterTextSplitter({ const splitter = new RecursiveCharacterTextSplitter({
@ -23,7 +23,7 @@ const second = 1000;
const minute = 60 * second; const minute = 60 * second;
const hour = 60 * minute; const hour = 60 * minute;
async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<number> { async function runProcessPipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>) {
const logger = _logger const logger = _logger
.child(doc._id.toString()) .child(doc._id.toString())
.child(doc.name); .child(doc.name);
@ -42,6 +42,12 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
model: embeddingModel, model: embeddingModel,
values: splits.map((split) => split.pageContent) values: splits.map((split) => split.pageContent)
}); });
usageTracker.track({
type: "EMBEDDING_MODEL_USAGE",
modelName: embeddingModel.modelId,
tokens: usage.tokens,
context: "rag.text.embedding_usage",
});
// store embeddings in qdrant // store embeddings in qdrant
logger.log("Storing embeddings in Qdrant"); logger.log("Storing embeddings in Qdrant");
@ -73,8 +79,6 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
lastUpdatedAt: new Date().toISOString(), lastUpdatedAt: new Date().toISOString(),
} }
}); });
return usage.tokens;
} }
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> { async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
@ -241,8 +245,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
// authorize with billing // authorize with billing
if (USE_BILLING && billingCustomerId) { if (USE_BILLING && billingCustomerId) {
const authResponse = await authorize(billingCustomerId, { const authResponse = await authorize(billingCustomerId, {
type: "process_rag", type: "use_credits",
data: {}
}); });
if ('error' in authResponse) { if ('error' in authResponse) {
@ -250,16 +253,9 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
} }
} }
const usageTracker = new UsageTracker();
try { try {
const usedTokens = await runProcessPipeline(logger, job, doc); await runProcessPipeline(logger, usageTracker, job, doc);
// log usage in billing
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
type: "rag_tokens",
amount: usedTokens,
});
}
} catch (e: any) { } catch (e: any) {
errors = true; errors = true;
logger.log("Error processing doc:", e); logger.log("Error processing doc:", e);
@ -272,6 +268,13 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
error: e.message, error: e.message,
} }
}); });
} finally {
// log usage in billing
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
items: usageTracker.flush(),
});
}
} }
} }

View file

@ -11,7 +11,7 @@ import { qdrantClient } from '../lib/qdrant';
import { PrefixLogger } from "../lib/utils"; import { PrefixLogger } from "../lib/utils";
import crypto from 'crypto'; import crypto from 'crypto';
import { USE_BILLING } from '../lib/feature_flags'; import { USE_BILLING } from '../lib/feature_flags';
import { authorize, getCustomerIdForProject, logUsage } from '../lib/billing'; import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from '../lib/billing';
import { BillingError } from '@/src/entities/errors/common'; import { BillingError } from '@/src/entities/errors/common';
const firecrawl = new FirecrawlApp({ apiKey: process.env.FIRECRAWL_API_KEY }); const firecrawl = new FirecrawlApp({ apiKey: process.env.FIRECRAWL_API_KEY });
@ -41,7 +41,7 @@ async function retryable<T>(fn: () => Promise<T>, maxAttempts: number = 3): Prom
} }
} }
async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<number> { async function runScrapePipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>) {
const logger = _logger const logger = _logger
.child(doc._id.toString()) .child(doc._id.toString())
.child(doc.name); .child(doc.name);
@ -62,6 +62,10 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<type
} }
return scrapeResult; return scrapeResult;
}, 3); // Retry up to 3 times }, 3); // Retry up to 3 times
usageTracker.track({
type: "FIRECRAWL_SCRAPE_USAGE",
context: "rag.urls.firecrawl_scrape",
});
// split into chunks // split into chunks
logger.log("Splitting into chunks"); logger.log("Splitting into chunks");
@ -73,6 +77,12 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<type
model: embeddingModel, model: embeddingModel,
values: splits.map((split) => split.pageContent) values: splits.map((split) => split.pageContent)
}); });
usageTracker.track({
type: "EMBEDDING_MODEL_USAGE",
modelName: embeddingModel.modelId,
tokens: usage.tokens,
context: "rag.urls.embedding_usage",
});
// store embeddings in qdrant // store embeddings in qdrant
logger.log("Storing embeddings in Qdrant"); logger.log("Storing embeddings in Qdrant");
@ -104,8 +114,6 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<type
lastUpdatedAt: new Date().toISOString(), lastUpdatedAt: new Date().toISOString(),
} }
}); });
return usage.tokens;
} }
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> { async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
@ -273,8 +281,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
// authorize with billing // authorize with billing
if (USE_BILLING && billingCustomerId) { if (USE_BILLING && billingCustomerId) {
const authResponse = await authorize(billingCustomerId, { const authResponse = await authorize(billingCustomerId, {
type: "process_rag", type: "use_credits",
data: {}
}); });
if ('error' in authResponse) { if ('error' in authResponse) {
@ -282,16 +289,9 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
} }
} }
const usageTracker = new UsageTracker();
try { try {
const usedTokens = await runScrapePipeline(logger, job, doc); await runScrapePipeline(logger, usageTracker, job, doc);
// log usage in billing
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
type: "rag_tokens",
amount: usedTokens,
});
}
} catch (e: any) { } catch (e: any) {
errors = true; errors = true;
logger.log("Error processing doc:", e); logger.log("Error processing doc:", e);
@ -304,6 +304,13 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
error: e.message, error: e.message,
} }
}); });
} finally {
// log usage in billing
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
items: usageTracker.flush(),
});
}
} }
} }

View file

@ -1,6 +1,6 @@
import { Reason, Turn, TurnEvent } from "@/src/entities/models/turn"; import { Reason, Turn, TurnEvent } from "@/src/entities/models/turn";
import { USE_BILLING } from "@/app/lib/feature_flags"; import { USE_BILLING } from "@/app/lib/feature_flags";
import { authorize, getCustomerIdForProject, logUsage } from "@/app/lib/billing"; import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from "@/app/lib/billing";
import { NotFoundError } from '@/src/entities/errors/common'; import { NotFoundError } from '@/src/entities/errors/common';
import { IConversationsRepository } from "@/src/application/repositories/conversations.repository.interface"; import { IConversationsRepository } from "@/src/application/repositories/conversations.repository.interface";
import { streamResponse } from "@/app/lib/agents"; import { streamResponse } from "@/app/lib/agents";
@ -69,6 +69,21 @@ export class RunConversationTurnUseCase implements IRunConversationTurnUseCase {
if (USE_BILLING) { if (USE_BILLING) {
// get billing customer id for project // get billing customer id for project
billingCustomerId = await getCustomerIdForProject(projectId); billingCustomerId = await getCustomerIdForProject(projectId);
// validate enough credits
const result = await authorize(billingCustomerId, {
type: "use_credits"
});
if (!result.success) {
yield {
type: "error",
error: result.error || 'Billing error',
isBillingError: true,
};
return;
}
// validate model usage
const agentModels = conversation.workflow.agents.reduce((acc, agent) => { const agentModels = conversation.workflow.agents.reduce((acc, agent) => {
acc.push(agent.model); acc.push(agent.model);
return acc; return acc;
@ -111,46 +126,50 @@ export class RunConversationTurnUseCase implements IRunConversationTurnUseCase {
conversation.workflow.mockTools = data.input.mockTools; conversation.workflow.mockTools = data.input.mockTools;
} }
// init usage tracker
const usageTracker = new UsageTracker();
// call agents runtime and handle generated messages // call agents runtime and handle generated messages
const outputMessages: z.infer<typeof Message>[] = []; try {
for await (const event of streamResponse(projectId, conversation.workflow, inputMessages)) { const outputMessages: z.infer<typeof Message>[] = [];
// handle msg events for await (const event of streamResponse(projectId, conversation.workflow, inputMessages, usageTracker)) {
if ("role" in event) { // handle msg events
// collect generated message if ("role" in event) {
const msg = { // collect generated message
...event, const msg = {
timestamp: new Date().toISOString(), ...event,
}; timestamp: new Date().toISOString(),
outputMessages.push(msg); };
outputMessages.push(msg);
// yield event // yield event
yield { yield {
type: "message", type: "message",
data: msg, data: msg,
}; };
} else {
// save turn data
const turn = await this.conversationsRepository.addTurn(data.conversationId, {
reason: data.reason,
input: data.input,
output: outputMessages,
});
// yield event
yield {
type: "done",
turn,
conversationId,
} }
} }
}
// Log billing usage // save turn data
if (USE_BILLING && billingCustomerId) { const turn = await this.conversationsRepository.addTurn(data.conversationId, {
await logUsage(billingCustomerId, { reason: data.reason,
type: "agent_messages", input: data.input,
amount: outputMessages.length, output: outputMessages,
}); });
// yield event
yield {
type: "done",
turn,
conversationId,
}
} finally {
// Log billing usage
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
items: usageTracker.flush(),
});
}
} }
} }
} }