billing + credits updates (#202)

This commit is contained in:
Ramnique Singh 2025-08-14 19:59:38 +05:30 committed by GitHub
parent 852e02e49e
commit eccfb4748f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 497 additions and 229 deletions

View file

@ -15,6 +15,7 @@ import { WithStringId } from "../lib/types/types";
import { getEditAgentInstructionsResponse } from "../lib/copilot/copilot";
import { container } from "@/di/container";
import { IUsageQuotaPolicy } from "@/src/application/policies/usage-quota.policy.interface";
import { UsageTracker } from "../lib/billing";
const usageQuotaPolicy = container.resolve<IUsageQuotaPolicy>('usageQuotaPolicy');
@ -32,8 +33,7 @@ export async function getCopilotResponseStream(
// Check billing authorization
const authResponse = await authorizeUserAction({
type: 'copilot_request',
data: {},
type: 'use_credits',
});
if (!authResponse.success) {
return { billingError: authResponse.error || 'Billing error' };
@ -75,8 +75,7 @@ export async function getCopilotAgentInstructions(
// Check billing authorization
const authResponse = await authorizeUserAction({
type: 'copilot_request',
data: {},
type: 'use_credits',
});
if (!authResponse.success) {
return { billingError: authResponse.error || 'Billing error' };
@ -93,8 +92,11 @@ export async function getCopilotAgentInstructions(
}
};
const usageTracker = new UsageTracker();
// call copilot api
const agent_instructions = await getEditAgentInstructionsResponse(
usageTracker,
projectId,
request.context,
request.messages,
@ -104,8 +106,7 @@ export async function getCopilotAgentInstructions(
// log the billing usage
if (USE_BILLING) {
await logUsage({
type: 'copilot_requests',
amount: 1,
items: usageTracker.flush(),
});
}

View file

@ -1,4 +1,4 @@
import { getCustomerIdForProject, logUsage } from "@/app/lib/billing";
import { getCustomerIdForProject, logUsage, UsageTracker } from "@/app/lib/billing";
import { USE_BILLING } from "@/app/lib/feature_flags";
import { redisClient } from "@/app/lib/redis";
import { CopilotAPIRequest } from "@/app/lib/types/copilot_types";
@ -21,6 +21,7 @@ export async function GET(request: Request, props: { params: Promise<{ streamId:
billingCustomerId = await getCustomerIdForProject(projectId);
}
const usageTracker = new UsageTracker();
const encoder = new TextEncoder();
let messageCount = 0;
@ -29,6 +30,7 @@ export async function GET(request: Request, props: { params: Promise<{ streamId:
try {
// Iterate over the copilot stream generator
for await (const event of streamMultiAgentResponse(
usageTracker,
projectId,
context,
messages,
@ -49,21 +51,20 @@ export async function GET(request: Request, props: { params: Promise<{ streamId:
}
controller.close();
// increment copilot request count in billing
} catch (error) {
console.error('Error processing copilot stream:', error);
controller.error(error);
} finally {
// log copilot usage
if (USE_BILLING && billingCustomerId) {
try {
await logUsage(billingCustomerId, {
type: "copilot_requests",
amount: 1,
items: usageTracker.flush(),
});
} catch (error) {
console.error("Error logging usage", error);
}
}
} catch (error) {
console.error('Error processing copilot stream:', error);
controller.error(error);
}
},
});

View file

@ -220,10 +220,10 @@ export async function POST(
// log billing usage
if (USE_BILLING && billingCustomerId) {
const agentMessageCount = convertedResponseMessages.filter(m => m.role === 'assistant').length;
await logUsage(billingCustomerId, {
type: 'agent_messages',
amount: agentMessageCount,
});
// await logUsage(billingCustomerId, {
// type: 'agent_messages',
// amount: agentMessageCount,
// });
}
logger.log(`Turn processing completed successfully`);

View file

@ -3,7 +3,7 @@
import { Progress, Badge, Chip } from "@heroui/react";
import { Button } from "@/components/ui/button";
import { Label } from "@/app/lib/components/label";
import { Customer, UsageResponse, UsageType } from "@/app/lib/types/billing_types";
import { Customer, UsageResponse } from "@/app/lib/types/billing_types";
import { z } from "zod";
import { tokens } from "@/app/styles/design-tokens";
import { SectionHeading } from "@/components/ui/section-heading";
@ -47,6 +47,15 @@ export function BillingPage({ customer, usage }: BillingPageProps) {
const displayStatus = getDisplayStatus(customer.subscriptionStatus);
const planInfo = planDetails[plan];
// Prepare usage metrics data
const usageData = Object.entries(usage.usage)
.map(([type, credits]) => ({
type,
credits,
totalUsedCredits: usage.sanctionedCredits - usage.availableCredits
}))
.sort((a, b) => b.credits - a.credits);
async function handleManageSubscription() {
const returnUrl = new URL('/billing/callback', window.location.origin);
returnUrl.searchParams.set('redirect', window.location.href);
@ -109,48 +118,175 @@ export function BillingPage({ customer, usage }: BillingPageProps) {
</div>
</section>
{/* Usage Metrics Panel */}
{/* Credits Overview Panel */}
<section className="card">
<div className="px-4 pt-4 pb-6">
<SectionHeading>
Usage Metrics
Credits Overview
</SectionHeading>
</div>
<HorizontalDivider />
<div className="p-6 space-y-6">
{Object.entries(usage.usage).map(([type, { usage: used, total }]) => {
const usageType = type as z.infer<typeof UsageType>;
const percentage = Math.min((used / total) * 100, 100);
const isOverLimit = used > total;
<div className="grid grid-cols-1 md:grid-cols-3 gap-6">
<div className="space-y-2">
<Label label="Sanctioned Credits" />
<p className={clsx(
tokens.typography.sizes.lg,
tokens.typography.weights.semibold,
tokens.colors.light.text.primary,
tokens.colors.dark.text.primary
)}>
{usage.sanctionedCredits.toLocaleString()}
</p>
<p className={clsx(
tokens.typography.sizes.sm,
tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary
)}>
Total credits allocated to your plan
</p>
</div>
<div className="space-y-2">
<Label label="Used Credits" />
<p className={clsx(
tokens.typography.sizes.lg,
tokens.typography.weights.semibold,
tokens.colors.light.text.primary,
tokens.colors.dark.text.primary
)}>
{(usage.sanctionedCredits - usage.availableCredits).toLocaleString()}
</p>
<p className={clsx(
tokens.typography.sizes.sm,
tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary
)}>
Credits consumed so far
</p>
</div>
<div className="space-y-2">
<Label label="Available Credits" />
<p className={clsx(
tokens.typography.sizes.lg,
tokens.typography.weights.semibold,
usage.availableCredits < 0 ? "text-red-500" : clsx(
tokens.colors.light.text.primary,
tokens.colors.dark.text.primary
)
)}>
{usage.availableCredits.toLocaleString()}
</p>
<p className={clsx(
tokens.typography.sizes.sm,
tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary
)}>
Credits remaining for use
</p>
</div>
</div>
{/* Warning for negative credits */}
{usage.availableCredits < 0 && (
<div className="p-4 bg-red-50 dark:bg-red-900/20 border border-red-200 dark:border-red-800 rounded-lg">
<p className={clsx(
tokens.typography.sizes.sm,
"text-red-700 dark:text-red-300"
)}>
You have exceeded your credit limit. Please upgrade your plan or contact support to avoid service interruptions.
</p>
</div>
)}
{/* Warning for high credit usage (>80%) */}
{usage.availableCredits >= 0 && ((usage.sanctionedCredits - usage.availableCredits) / usage.sanctionedCredits) > 0.8 && (
<div className="p-4 bg-yellow-50 dark:bg-yellow-900/20 border border-yellow-200 dark:border-yellow-800 rounded-lg">
<p className={clsx(
tokens.typography.sizes.sm,
"text-yellow-700 dark:text-yellow-300"
)}>
You have used more than 80% of your credits. Consider upgrading your plan to avoid interruptions.
</p>
</div>
)}
{/* Credits Progress Bar */}
<div className="space-y-2">
<div className="flex justify-between items-center">
<Label label="Credits Usage" />
<span className={clsx(
tokens.typography.sizes.sm,
tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary
)}>
{Math.round(((usage.sanctionedCredits - usage.availableCredits) / usage.sanctionedCredits) * 100)}%
</span>
</div>
<Progress
size="lg"
value={((usage.sanctionedCredits - usage.availableCredits) / usage.sanctionedCredits) * 100}
color={usage.availableCredits < 0 ? "danger" : "primary"}
className="h-4"
aria-label="Credits usage"
/>
</div>
</div>
</section>
return (
<div key={type} className="space-y-2">
<div className="flex justify-between items-center">
<div className="space-y-1">
<Label label={type.replace(/_/g, ' ')} />
<p className={clsx(
{/* Usage Metrics Panel */}
<section className="card">
<div className="px-4 pt-4 pb-6">
<SectionHeading>
Usage data
</SectionHeading>
</div>
<HorizontalDivider />
<div className="p-6 space-y-6">
{usageData.length === 0 ? (
<div className="text-center py-8">
<p className={clsx(
tokens.typography.sizes.sm,
tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary
)}>
No usage data yet
</p>
</div>
) : (
usageData.map(({ type, credits, totalUsedCredits }) => {
const percentage = totalUsedCredits > 0 ? (credits / totalUsedCredits) * 100 : 0;
return (
<div key={type} className="space-y-2">
<div className="flex justify-between items-center">
<div className="space-y-1">
<Label label={type.replace(/_/g, ' ')} />
<p className={clsx(
tokens.typography.sizes.sm,
tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary
)}>
{credits.toLocaleString()} credits
</p>
</div>
<span className={clsx(
tokens.typography.sizes.sm,
tokens.colors.light.text.secondary,
tokens.colors.dark.text.secondary
)}>
{used.toLocaleString()} / {total.toLocaleString()}
</p>
{Math.round(percentage)}%
</span>
</div>
{isOverLimit && (
<Badge color="danger" variant="flat">
Over Limit
</Badge>
)}
<Progress
value={percentage}
color="default"
className="h-2"
aria-label={`${type} credits usage`}
/>
</div>
<Progress
value={percentage}
color={isOverLimit ? "danger" : "primary"}
className="h-2"
aria-label={`${type} usage`}
/>
</div>
);
})}
);
})
)}
</div>
</section>
</div>

View file

@ -16,6 +16,7 @@ import { qdrantClient } from '../lib/qdrant';
import { EmbeddingRecord } from "./types/datasource_types";
import { WorkflowAgent, WorkflowTool } from "./types/workflow_types";
import { PrefixLogger } from "./utils";
import { UsageTracker } from "./billing";
// Provider configuration
const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
@ -30,6 +31,7 @@ const openai = createOpenAI({
// Helper to handle mock tool responses
export async function invokeMockTool(
logger: PrefixLogger,
usageTracker: UsageTracker,
toolName: string,
args: string,
description: string,
@ -49,18 +51,28 @@ export async function invokeMockTool(
content: `Generate a realistic response for the tool '${toolName}' with these parameters: ${args}. The response should be concise and focused on what the tool would actually return.`
}];
const { text } = await generateText({
const { text, usage } = await generateText({
model: openai(MODEL),
messages,
});
logger.log(`generated text: ${text}`);
// track usage
usageTracker.track({
type: "LLM_USAGE",
modelName: MODEL,
inputTokens: usage.promptTokens,
outputTokens: usage.completionTokens,
context: "agents_runtime.mock_tool",
});
return text;
}
// Helper to handle RAG tool calls
export async function invokeRagTool(
logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string,
query: string,
sourceIds: string[],
@ -81,11 +93,21 @@ export async function invokeRagTool(
logger.log(`k: ${k}`);
// Create embedding for question
const { embedding } = await embed({
const { embedding, usage } = await embed({
model: embeddingModel,
value: query,
});
// track usage
// track usage
usageTracker.track({
type: "EMBEDDING_MODEL_USAGE",
modelName: embeddingModel.modelId,
tokens: usage.tokens,
context: "agents_runtime.rag_tool.embedding_usage",
});
// Fetch all data sources for this project
const sources = await dataSourcesCollection.find({
projectId: projectId,
@ -154,6 +176,7 @@ export async function invokeRagTool(
export async function invokeWebhookTool(
logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string,
name: string,
input: any,
@ -233,6 +256,7 @@ export async function invokeWebhookTool(
// Helper to handle MCP tool calls
export async function invokeMcpTool(
logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string,
name: string,
input: any,
@ -269,6 +293,7 @@ export async function invokeMcpTool(
// Helper to handle composio tool calls
export async function invokeComposioTool(
logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string,
name: string,
composioData: z.infer<typeof WorkflowTool>['composioData'] & {},
@ -299,12 +324,21 @@ export async function invokeComposioTool(
connectedAccountId: connectedAccountId,
});
logger.log(`composio tool result: ${JSON.stringify(result)}`);
// track usage
usageTracker.track({
type: "COMPOSIO_TOOL_USAGE",
toolSlug: slug,
context: "agents_runtime.composio_tool",
});
return result.data;
}
// Helper to create RAG tool
export function createRagTool(
logger: PrefixLogger,
usageTracker: UsageTracker,
config: z.infer<typeof WorkflowAgent>,
projectId: string
): Tool {
@ -321,6 +355,7 @@ export function createRagTool(
async execute(input: { query: string }) {
const results = await invokeRagTool(
logger,
usageTracker,
projectId,
input.query,
config.ragDataSources || [],
@ -337,6 +372,7 @@ export function createRagTool(
// Helper to create a mock tool
export function createMockTool(
logger: PrefixLogger,
usageTracker: UsageTracker,
config: z.infer<typeof WorkflowTool>,
): Tool {
return tool({
@ -353,6 +389,7 @@ export function createMockTool(
try {
const result = await invokeMockTool(
logger,
usageTracker,
config.name,
JSON.stringify(input),
config.description,
@ -374,6 +411,7 @@ export function createMockTool(
// Helper to create a webhook tool
export function createWebhookTool(
logger: PrefixLogger,
usageTracker: UsageTracker,
config: z.infer<typeof WorkflowTool>,
projectId: string,
): Tool {
@ -391,7 +429,7 @@ export function createWebhookTool(
},
async execute(input: any) {
try {
const result = await invokeWebhookTool(logger, projectId, name, input);
const result = await invokeWebhookTool(logger, usageTracker, projectId, name, input);
return JSON.stringify({
result,
});
@ -408,6 +446,7 @@ export function createWebhookTool(
// Helper to create an mcp tool
export function createMcpTool(
logger: PrefixLogger,
usageTracker: UsageTracker,
config: z.infer<typeof WorkflowTool>,
projectId: string
): Tool {
@ -425,7 +464,7 @@ export function createMcpTool(
},
async execute(input: any) {
try {
const result = await invokeMcpTool(logger, projectId, name, input, mcpServerName || '');
const result = await invokeMcpTool(logger, usageTracker, projectId, name, input, mcpServerName || '');
return JSON.stringify({
result,
});
@ -442,6 +481,7 @@ export function createMcpTool(
// Helper to create a composio tool
export function createComposioTool(
logger: PrefixLogger,
usageTracker: UsageTracker,
config: z.infer<typeof WorkflowTool>,
projectId: string
): Tool {
@ -463,7 +503,7 @@ export function createComposioTool(
},
async execute(input: any) {
try {
const result = await invokeComposioTool(logger, projectId, name, composioData, input);
const result = await invokeComposioTool(logger, usageTracker, projectId, name, composioData, input);
return JSON.stringify({
result,
});
@ -479,6 +519,7 @@ export function createComposioTool(
export function createTools(
logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string,
workflow: { tools: z.infer<typeof WorkflowTool>[] },
toolConfig: Record<string, z.infer<typeof WorkflowTool>>,
@ -492,16 +533,16 @@ export function createTools(
toolLogger.log(`creating tool: ${toolName} (type: ${config.mockTool ? 'mock' : config.isMcp ? 'mcp' : config.isComposio ? 'composio' : 'webhook'})`);
if (config.mockTool) {
tools[toolName] = createMockTool(logger, config);
tools[toolName] = createMockTool(logger, usageTracker, config);
toolLogger.log(`✓ created mock tool: ${toolName}`);
} else if (config.isMcp) {
tools[toolName] = createMcpTool(logger, config, projectId);
tools[toolName] = createMcpTool(logger, usageTracker, config, projectId);
toolLogger.log(`✓ created mcp tool: ${toolName} (server: ${config.mcpServerName || 'unknown'})`);
} else if (config.isComposio) {
tools[toolName] = createComposioTool(logger, config, projectId);
tools[toolName] = createComposioTool(logger, usageTracker, config, projectId);
toolLogger.log(`✓ created composio tool: ${toolName}`);
} else {
tools[toolName] = createWebhookTool(logger, config, projectId);
tools[toolName] = createWebhookTool(logger, usageTracker, config, projectId);
toolLogger.log(`✓ created webhook tool: ${toolName} (fallback)`);
}
}

View file

@ -12,6 +12,8 @@ import { ConnectedEntity, sanitizeTextWithMentions, Workflow, WorkflowAgent, Wor
import { CHILD_TRANSFER_RELATED_INSTRUCTIONS, CONVERSATION_TYPE_INSTRUCTIONS, PIPELINE_TYPE_INSTRUCTIONS, RAG_INSTRUCTIONS, TASK_TYPE_INSTRUCTIONS } from "./agent_instructions";
import { PrefixLogger } from "./utils";
import { Message, AssistantMessage, AssistantMessageWithToolCalls, ToolMessage } from "./types/types";
import { UsageTracker } from "./billing";
// Native handoff support
import { createAgentHandoff, getSchemaForAgent, createContextFilterForAgent } from "./agent-handoffs";
import { PipelineStateManager } from "./pipeline-state-manager";
@ -78,14 +80,6 @@ const openai = createOpenAI({
baseURL: PROVIDER_BASE_URL,
});
const ZUsage = z.object({
tokens: z.object({
total: z.number(),
prompt: z.number(),
completion: z.number(),
}),
});
const ZOutMessage = z.union([
AssistantMessage,
AssistantMessageWithToolCalls,
@ -95,6 +89,7 @@ const ZOutMessage = z.union([
// Helper to create an agent
function createAgent(
logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string,
config: z.infer<typeof WorkflowAgent>,
tools: Record<string, Tool>,
@ -145,7 +140,7 @@ ${CHILD_TRANSFER_RELATED_INSTRUCTIONS}
// Add RAG tool if needed
if (config.ragDataSources?.length) {
const ragTool = createRagTool(logger, config, projectId);
const ragTool = createRagTool(logger, usageTracker, config, projectId);
agentTools.push(ragTool);
// update instructions to include RAG instructions
@ -269,8 +264,8 @@ function getStartOfTurnAgentName(
// Logs an event and then yields it
async function* emitEvent(
logger: PrefixLogger,
event: z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>,
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
event: z.infer<typeof ZOutMessage>,
): AsyncIterable<z.infer<typeof ZOutMessage>> {
logger.log(`-> emitting event: ${JSON.stringify(event)}`);
yield event;
return;
@ -321,30 +316,6 @@ class AgentTransferCounter {
}
}
class UsageTracker {
private usage: {
total: number;
prompt: number;
completion: number;
} = { total: 0, prompt: 0, completion: 0 };
increment(total: number, prompt: number, completion: number): void {
this.usage.total += total;
this.usage.prompt += prompt;
this.usage.completion += completion;
}
get(): { total: number, prompt: number, completion: number } {
return this.usage;
}
asEvent(): z.infer<typeof ZUsage> {
return {
tokens: this.usage,
};
}
}
function ensureSystemMessage(logger: PrefixLogger, messages: z.infer<typeof Message>[]) {
logger = logger.child(`ensureSystemMessage`);
@ -396,7 +367,7 @@ function mapConfig(workflow: z.infer<typeof Workflow>): {
return { agentConfig, toolConfig, promptConfig, pipelineConfig };
}
async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer<typeof Workflow>): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer<typeof Workflow>): AsyncIterable<z.infer<typeof ZOutMessage>> {
// find the greeting prompt
const prompt = workflow.prompts.find(p => p.type === 'greeting')?.prompt || 'How can I help you today?';
logger.log(`greeting turn: ${prompt}`);
@ -408,15 +379,13 @@ async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer<typeof
agentName: workflow.startAgent,
responseType: 'external',
});
// emit final usage information
yield* emitEvent(logger, new UsageTracker().asEvent());
}
// Enhanced agent creation with native handoff support
function createAgentsWithNativeHandoffs(
logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string,
workflow: z.infer<typeof Workflow>,
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
@ -447,6 +416,7 @@ function createAgentsWithNativeHandoffs(
const { agent, entities } = createAgent(
logger,
usageTracker,
projectId,
config,
tools,
@ -560,6 +530,7 @@ function createAgentsWithNativeHandoffs(
// Legacy agent creation (existing implementation)
function createAgentsLegacy(
logger: PrefixLogger,
usageTracker: UsageTracker,
projectId: string,
workflow: z.infer<typeof Workflow>,
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
@ -595,6 +566,7 @@ function createAgentsLegacy(
const { agent, entities } = createAgent(
logger,
usageTracker,
projectId,
config,
tools,
@ -763,12 +735,13 @@ function maybeInjectGiveUpControlInstructions(
// Handle raw model stream events
async function* handleRawModelStreamEvent(
event: any,
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
agentName: string,
turnMsgs: z.infer<typeof Message>[],
usageTracker: UsageTracker,
eventLogger: PrefixLogger,
getAgentState?: (agentName: string) => AgentState
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
): AsyncIterable<z.infer<typeof ZOutMessage>> {
if (event.data.type === 'response_done') {
// Count tool calls (excluding transfer_to_* calls)
const toolCallCount = event.data.response.output.filter(
@ -809,12 +782,13 @@ async function* handleRawModelStreamEvent(
}
// update usage information
usageTracker.increment(
event.data.response.usage.totalTokens,
event.data.response.usage.inputTokens,
event.data.response.usage.outputTokens
);
eventLogger.log(`updated usage information: ${JSON.stringify(usageTracker.get())}`);
usageTracker.track({
type: "LLM_USAGE",
modelName: agentConfig[agentName]?.model || "unknown",
inputTokens: event.data.response.usage.inputTokens,
outputTokens: event.data.response.usage.outputTokens,
context: "agents_runtime.llm_usage",
});
}
}
@ -833,7 +807,7 @@ async function* handleNativeHandoffEvent(
originalHandoffs: Record<string, any[]>,
eventLogger: PrefixLogger,
loopLogger: PrefixLogger
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage> | { newAgentName: string; shouldContinue?: boolean }> {
): AsyncIterable<z.infer<typeof ZOutMessage> | { newAgentName: string; shouldContinue?: boolean }> {
eventLogger.log(`🔄 NATIVE HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`);
// skip if its the same agent
@ -943,7 +917,7 @@ async function* handleHandoffEvent(
originalHandoffs: Record<string, Agent[]>,
eventLogger: PrefixLogger,
loopLogger: PrefixLogger
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage> | { newAgentName: string }> {
): AsyncIterable<z.infer<typeof ZOutMessage> | { newAgentName: string }> {
eventLogger.log(`🔄 HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`);
// skip if its the same agent
@ -1010,7 +984,7 @@ async function* handleToolCallResult(
event: any,
turnMsgs: z.infer<typeof Message>[],
eventLogger: PrefixLogger
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
): AsyncIterable<z.infer<typeof ZOutMessage>> {
const m: z.infer<typeof Message> = {
role: 'tool',
content: event.item.rawItem.output.text,
@ -1039,7 +1013,7 @@ async function* handleMessageOutput(
eventLogger: PrefixLogger,
loopLogger: PrefixLogger,
getAgentState: (agentName: string) => AgentState
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage> | { newAgentName: string | null; shouldContinue: boolean }> {
): AsyncIterable<z.infer<typeof ZOutMessage> | { newAgentName: string | null; shouldContinue: boolean }> {
// check response visibility - could be an agent or pipeline
const agentConfigObj = agentConfig[agentName];
const pipelineConfigObj = pipelineConfig[agentName];
@ -1243,7 +1217,8 @@ export async function* streamResponse(
projectId: string,
workflow: z.infer<typeof Workflow>,
messages: z.infer<typeof Message>[],
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
usageTracker: UsageTracker,
): AsyncIterable<z.infer<typeof ZOutMessage>> {
// Divider log for tracking agent loop start
console.log('-------------------- AGENT LOOP START --------------------');
// set up logging
@ -1275,11 +1250,11 @@ export async function* streamResponse(
logger.log(`initialized stack: ${JSON.stringify(stack)}`);
// create tools
const tools = createTools(logger, projectId, workflow, toolConfig);
const tools = createTools(logger, usageTracker, projectId, workflow, toolConfig);
// create agents with feature flag support
const createAgentsFunction = USE_NATIVE_HANDOFFS ? createAgentsWithNativeHandoffs : createAgentsLegacy;
const { agents, originalInstructions, originalHandoffs } = createAgentsFunction(logger, projectId, workflow, agentConfig, tools, promptConfig, pipelineConfig);
const { agents, originalInstructions, originalHandoffs } = createAgentsFunction(logger, usageTracker, projectId, workflow, agentConfig, tools, promptConfig, pipelineConfig);
logger.log(`Using ${USE_NATIVE_HANDOFFS ? 'NATIVE SDK' : 'LEGACY'} handoffs`);
@ -1296,7 +1271,6 @@ export async function* streamResponse(
let agentName: string | null = startOfTurnAgentName;
// start the turn loop
const usageTracker = new UsageTracker();
const turnMsgs: z.infer<typeof Message>[] = [...messages];
// Initialize agent state tracking for tool call completion
@ -1390,7 +1364,7 @@ export async function* streamResponse(
switch (event.type) {
case 'raw_model_stream_event':
yield* handleRawModelStreamEvent(event, agentName!, turnMsgs, usageTracker, eventLogger, getAgentState);
yield* handleRawModelStreamEvent(event, agentConfig, agentName!, turnMsgs, usageTracker, eventLogger, getAgentState);
break;
case 'run_item_stream_event':
@ -1523,9 +1497,6 @@ export async function* streamResponse(
}
}
// emit usage information
yield* emitEvent(logger, usageTracker.asEvent());
}
// this is a sync version of streamResponse
@ -1535,8 +1506,10 @@ export async function getResponse(
messages: z.infer<typeof Message>[],
): Promise<{
messages: z.infer<typeof ZOutMessage>[],
usage: z.infer<typeof ZUsage>,
usage: any,
}> {
throw new Error("Not implemented!");
/*
const out: z.infer<typeof ZOutMessage>[] = [];
let usage: z.infer<typeof ZUsage> = {
tokens: {
@ -1554,4 +1527,5 @@ export async function getResponse(
}
}
return { messages: out, usage };
*/
}

View file

@ -1,6 +1,6 @@
import { WithStringId } from './types/types';
import { z } from 'zod';
import { Customer, AuthorizeRequest, AuthorizeResponse, LogUsageRequest, UsageResponse, CustomerPortalSessionResponse, PricesResponse, UpdateSubscriptionPlanRequest, UpdateSubscriptionPlanResponse, ModelsResponse } from './types/billing_types';
import { Customer, AuthorizeRequest, AuthorizeResponse, LogUsageRequest, UsageResponse, CustomerPortalSessionResponse, PricesResponse, UpdateSubscriptionPlanRequest, UpdateSubscriptionPlanResponse, ModelsResponse, UsageItem } from './types/billing_types';
import { ObjectId } from 'mongodb';
import { projectsCollection, usersCollection } from './mongodb';
import { redirect } from 'next/navigation';
@ -23,6 +23,20 @@ const GUEST_BILLING_CUSTOMER = {
updatedAt: new Date().toISOString(),
};
export class UsageTracker{
private items: z.infer<typeof UsageItem>[] = [];
track(item: z.infer<typeof UsageItem>) {
this.items.push(item);
}
flush(): z.infer<typeof UsageItem>[] {
const items = this.items;
this.items = [];
return items;
}
}
export async function getCustomerIdForProject(projectId: string): Promise<string> {
const project = await projectsCollection.findOne({ _id: projectId });
if (!project) {
@ -111,6 +125,7 @@ export async function authorize(customerId: string, request: z.infer<typeof Auth
}
export async function logUsage(customerId: string, request: z.infer<typeof LogUsageRequest>) {
console.log(`logging billing usage for customer ${customerId}`, JSON.stringify(request));
const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/log-usage`, {
method: 'POST',
headers: {

View file

@ -13,6 +13,7 @@ import { COPILOT_MULTI_AGENT_EXAMPLE_1 } from "./example_multi_agent_1";
import { CURRENT_WORKFLOW_PROMPT } from "./current_workflow";
import { USE_COMPOSIO_TOOLS } from "../feature_flags";
import { composio, getTool } from "../composio/composio";
import { UsageTracker } from "../billing";
const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
const PROVIDER_BASE_URL = process.env.PROVIDER_BASE_URL || undefined;
@ -119,7 +120,7 @@ ${JSON.stringify(simplifiedDataSources)}
return prompt;
}
async function searchRelevantTools(query: string): Promise<string> {
async function searchRelevantTools(usageTracker: UsageTracker, query: string): Promise<string> {
const logger = new PrefixLogger("copilot-search-tools");
console.log("🔧 TOOL CALL: searchRelevantTools", { query });
@ -142,6 +143,13 @@ async function searchRelevantTools(query: string): Promise<string> {
return 'No tools found!';
}
// track composio search tool usage
usageTracker.track({
type: "COMPOSIO_TOOL_USAGE",
toolSlug: "COMPOSIO_SEARCH_TOOLS",
context: "copilot.search_relevant_tools",
});
// parse results
const result = composioToolSearchResponseSchema.safeParse(searchResult.data);
if (!result.success) {
@ -208,6 +216,7 @@ function updateLastUserMessage(
}
export async function getEditAgentInstructionsResponse(
usageTracker: UsageTracker,
projectId: string,
context: z.infer<typeof CopilotChatContext> | null,
messages: z.infer<typeof CopilotMessage>[],
@ -232,7 +241,7 @@ export async function getEditAgentInstructionsResponse(
system: COPILOT_INSTRUCTIONS_EDIT_AGENT,
messages: messages,
}));
const { object } = await generateObject({
const { object, usage } = await generateObject({
model: openai(COPILOT_MODEL),
messages: [
{
@ -246,10 +255,20 @@ export async function getEditAgentInstructionsResponse(
}),
});
// log usage
usageTracker.track({
type: "LLM_USAGE",
modelName: COPILOT_MODEL,
inputTokens: usage.promptTokens,
outputTokens: usage.completionTokens,
context: "copilot.llm_usage",
});
return object.agent_instructions;
}
export async function* streamMultiAgentResponse(
usageTracker: UsageTracker,
projectId: string,
context: z.infer<typeof CopilotChatContext> | null,
messages: z.infer<typeof CopilotMessage>[],
@ -297,7 +316,7 @@ export async function* streamMultiAgentResponse(
}),
execute: async ({ query }: { query: string }) => {
console.log("🎯 AI TOOL CALL: search_relevant_tools", { query });
const result = await searchRelevantTools(query);
const result = await searchRelevantTools(usageTracker, query);
console.log("✅ AI TOOL CALL COMPLETED: search_relevant_tools", {
query,
resultLength: result.length
@ -341,6 +360,15 @@ export async function* streamMultiAgentResponse(
toolCallId: event.toolCallId,
result: event.result,
};
} else if (event.type === "step-finish") {
// log usage
usageTracker.track({
type: "LLM_USAGE",
modelName: COPILOT_MODEL,
inputTokens: event.usage.promptTokens,
outputTokens: event.usage.completionTokens,
context: "copilot.llm_usage",
});
}
}

View file

@ -2,12 +2,52 @@ import { z } from "zod";
export const SubscriptionPlan = z.enum(["free", "starter", "pro"]);
export const UsageType = z.enum([
"copilot_requests",
"agent_messages",
"rag_tokens",
export const UsageTypeKey = z.enum([
"LLM_USAGE",
"EMBEDDING_MODEL_USAGE",
"COMPOSIO_TOOL_USAGE",
"FIRECRAWL_SCRAPE_USAGE",
]);
export const LLMUsage = z.object({
type: z.literal(UsageTypeKey.Enum.LLM_USAGE),
modelName: z.string(),
inputTokens: z.number().positive(),
outputTokens: z.number().positive(),
context: z.string(),
});
export const EmbeddingModelUsage = z.object({
type: z.literal(UsageTypeKey.Enum.EMBEDDING_MODEL_USAGE),
modelName: z.string(),
tokens: z.number().positive(),
context: z.string(),
});
export const ComposioToolUsage = z.object({
type: z.literal(UsageTypeKey.Enum.COMPOSIO_TOOL_USAGE),
toolSlug: z.string(),
context: z.string(),
});
export const FirecrawlScrapeUsage = z.object({
type: z.literal(UsageTypeKey.Enum.FIRECRAWL_SCRAPE_USAGE),
context: z.string(),
});
export const UsageItem = z.discriminatedUnion("type", [
LLMUsage,
EmbeddingModelUsage,
ComposioToolUsage,
FirecrawlScrapeUsage,
]);
export const LogUsageRequest = z.object({
items: z.array(UsageItem),
});
export const CustomerUsageData = z.record(z.string(), z.number());
export const Customer = z.object({
_id: z.string(),
userId: z.string(),
@ -19,36 +59,23 @@ export const Customer = z.object({
createdAt: z.string().datetime(),
updatedAt: z.string().datetime(),
subscriptionPlanUpdatedAt: z.string().datetime().optional(),
usage: z.record(UsageType, z.number()).optional(),
usage: CustomerUsageData.optional(),
usageUpdatedAt: z.string().datetime().optional(),
});
export const LogUsageRequest = z.object({
type: UsageType,
amount: z.number().int().positive(),
});
creditsOverride: z.number().optional(),
maxProjectsOverride: z.number().optional(),
agentModelsOverride: z.array(z.string()).optional(),
});
export const AuthorizeRequest = z.discriminatedUnion("type", [
z.object({
"type": z.literal("use_credits"),
}),
z.object({
"type": z.literal("create_project"),
"data": z.object({
"existingProjectCount": z.number(),
}),
}),
z.object({
"type": z.literal("enable_hosted_tool_server"),
"data": z.object({
"existingServerCount": z.number(),
}),
}),
z.object({
"type": z.literal("process_rag"),
"data": z.object({}),
}),
z.object({
"type": z.literal("copilot_request"),
"data": z.object({}),
}),
z.object({
"type": z.literal("agent_response"),
"data": z.object({
@ -63,10 +90,9 @@ export const AuthorizeResponse = z.object({
});
export const UsageResponse = z.object({
usage: z.record(UsageType, z.object({
usage: z.number(),
total: z.number(),
})),
sanctionedCredits: z.number(),
availableCredits: z.number(),
usage: CustomerUsageData,
});
export const CustomerPortalSessionRequest = z.object({

View file

@ -16,7 +16,7 @@ import crypto from 'crypto';
import path from 'path';
import { createOpenAI } from '@ai-sdk/openai';
import { USE_BILLING, USE_GEMINI_FILE_PARSING } from '../lib/feature_flags';
import { authorize, getCustomerIdForProject, logUsage } from '../lib/billing';
import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from '../lib/billing';
import { BillingError } from '@/src/entities/errors/common';
const FILE_PARSING_PROVIDER_API_KEY = process.env.FILE_PARSING_PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
@ -75,7 +75,7 @@ async function retryable<T>(fn: () => Promise<T>, maxAttempts: number = 3): Prom
}
}
async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } }): Promise<number> {
async function runProcessPipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } }) {
const logger = _logger
.child(doc._id.toString())
.child(doc.name);
@ -95,7 +95,7 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
if (!USE_GEMINI_FILE_PARSING) {
// Use OpenAI to extract text content
logger.log("Extracting content using OpenAI");
const { text } = await generateText({
const { text, usage } = await generateText({
model: openai(FILE_PARSING_MODEL),
system: extractPrompt,
messages: [
@ -112,6 +112,13 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
],
});
markdown = text;
usageTracker.track({
type: "LLM_USAGE",
modelName: FILE_PARSING_MODEL,
inputTokens: usage.promptTokens,
outputTokens: usage.completionTokens,
context: "rag.files.llm_usage",
});
} else {
// Use Gemini to extract text content
logger.log("Extracting content using Gemini");
@ -127,6 +134,13 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
extractPrompt,
]);
markdown = result.response.text();
usageTracker.track({
type: "LLM_USAGE",
modelName: FILE_PARSING_MODEL,
inputTokens: result.response.usageMetadata?.promptTokenCount || 0,
outputTokens: result.response.usageMetadata?.candidatesTokenCount || 0,
context: "rag.files.llm_usage",
});
}
// split into chunks
@ -139,6 +153,12 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
model: embeddingModel,
values: splits.map((split) => split.pageContent)
});
usageTracker.track({
type: "EMBEDDING_MODEL_USAGE",
modelName: embeddingModel.modelId,
tokens: usage.tokens,
context: "rag.files.embedding_usage",
});
// store embeddings in qdrant
logger.log("Storing embeddings in Qdrant");
@ -170,8 +190,6 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
lastUpdatedAt: new Date().toISOString(),
}
});
return usage.tokens;
}
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
@ -339,8 +357,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
// authorize with billing
if (USE_BILLING && billingCustomerId) {
const authResponse = await authorize(billingCustomerId, {
type: "process_rag",
data: {},
type: "use_credits",
});
if ('error' in authResponse) {
@ -349,16 +366,9 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
}
const ldoc = doc as WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } };
const usageTracker = new UsageTracker();
try {
const usedTokens = await runProcessPipeline(logger, job, ldoc);
// log usage in billing
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
type: "rag_tokens",
amount: usedTokens,
});
}
await runProcessPipeline(logger, usageTracker, job, ldoc);
} catch (e: any) {
errors = true;
logger.log("Error processing doc:", e);
@ -371,6 +381,13 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
error: e.message,
}
});
} finally {
// log usage in billing
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
items: usageTracker.flush(),
});
}
}
}

View file

@ -10,7 +10,7 @@ import { qdrantClient } from '../lib/qdrant';
import { PrefixLogger } from "../lib/utils";
import crypto from 'crypto';
import { USE_BILLING } from '../lib/feature_flags';
import { authorize, getCustomerIdForProject, logUsage } from '../lib/billing';
import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from '../lib/billing';
import { BillingError } from '@/src/entities/errors/common';
const splitter = new RecursiveCharacterTextSplitter({
@ -23,7 +23,7 @@ const second = 1000;
const minute = 60 * second;
const hour = 60 * minute;
async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<number> {
async function runProcessPipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>) {
const logger = _logger
.child(doc._id.toString())
.child(doc.name);
@ -42,6 +42,12 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
model: embeddingModel,
values: splits.map((split) => split.pageContent)
});
usageTracker.track({
type: "EMBEDDING_MODEL_USAGE",
modelName: embeddingModel.modelId,
tokens: usage.tokens,
context: "rag.text.embedding_usage",
});
// store embeddings in qdrant
logger.log("Storing embeddings in Qdrant");
@ -73,8 +79,6 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
lastUpdatedAt: new Date().toISOString(),
}
});
return usage.tokens;
}
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
@ -241,8 +245,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
// authorize with billing
if (USE_BILLING && billingCustomerId) {
const authResponse = await authorize(billingCustomerId, {
type: "process_rag",
data: {}
type: "use_credits",
});
if ('error' in authResponse) {
@ -250,16 +253,9 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
}
}
const usageTracker = new UsageTracker();
try {
const usedTokens = await runProcessPipeline(logger, job, doc);
// log usage in billing
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
type: "rag_tokens",
amount: usedTokens,
});
}
await runProcessPipeline(logger, usageTracker, job, doc);
} catch (e: any) {
errors = true;
logger.log("Error processing doc:", e);
@ -272,6 +268,13 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
error: e.message,
}
});
} finally {
// log usage in billing
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
items: usageTracker.flush(),
});
}
}
}

View file

@ -11,7 +11,7 @@ import { qdrantClient } from '../lib/qdrant';
import { PrefixLogger } from "../lib/utils";
import crypto from 'crypto';
import { USE_BILLING } from '../lib/feature_flags';
import { authorize, getCustomerIdForProject, logUsage } from '../lib/billing';
import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from '../lib/billing';
import { BillingError } from '@/src/entities/errors/common';
const firecrawl = new FirecrawlApp({ apiKey: process.env.FIRECRAWL_API_KEY });
@ -41,7 +41,7 @@ async function retryable<T>(fn: () => Promise<T>, maxAttempts: number = 3): Prom
}
}
async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<number> {
async function runScrapePipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>) {
const logger = _logger
.child(doc._id.toString())
.child(doc.name);
@ -62,6 +62,10 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<type
}
return scrapeResult;
}, 3); // Retry up to 3 times
usageTracker.track({
type: "FIRECRAWL_SCRAPE_USAGE",
context: "rag.urls.firecrawl_scrape",
});
// split into chunks
logger.log("Splitting into chunks");
@ -73,6 +77,12 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<type
model: embeddingModel,
values: splits.map((split) => split.pageContent)
});
usageTracker.track({
type: "EMBEDDING_MODEL_USAGE",
modelName: embeddingModel.modelId,
tokens: usage.tokens,
context: "rag.urls.embedding_usage",
});
// store embeddings in qdrant
logger.log("Storing embeddings in Qdrant");
@ -104,8 +114,6 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<type
lastUpdatedAt: new Date().toISOString(),
}
});
return usage.tokens;
}
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
@ -273,8 +281,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
// authorize with billing
if (USE_BILLING && billingCustomerId) {
const authResponse = await authorize(billingCustomerId, {
type: "process_rag",
data: {}
type: "use_credits",
});
if ('error' in authResponse) {
@ -282,16 +289,9 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
}
}
const usageTracker = new UsageTracker();
try {
const usedTokens = await runScrapePipeline(logger, job, doc);
// log usage in billing
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
type: "rag_tokens",
amount: usedTokens,
});
}
await runScrapePipeline(logger, usageTracker, job, doc);
} catch (e: any) {
errors = true;
logger.log("Error processing doc:", e);
@ -304,6 +304,13 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
error: e.message,
}
});
} finally {
// log usage in billing
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
items: usageTracker.flush(),
});
}
}
}

View file

@ -1,6 +1,6 @@
import { Reason, Turn, TurnEvent } from "@/src/entities/models/turn";
import { USE_BILLING } from "@/app/lib/feature_flags";
import { authorize, getCustomerIdForProject, logUsage } from "@/app/lib/billing";
import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from "@/app/lib/billing";
import { NotFoundError } from '@/src/entities/errors/common';
import { IConversationsRepository } from "@/src/application/repositories/conversations.repository.interface";
import { streamResponse } from "@/app/lib/agents";
@ -69,6 +69,21 @@ export class RunConversationTurnUseCase implements IRunConversationTurnUseCase {
if (USE_BILLING) {
// get billing customer id for project
billingCustomerId = await getCustomerIdForProject(projectId);
// validate enough credits
const result = await authorize(billingCustomerId, {
type: "use_credits"
});
if (!result.success) {
yield {
type: "error",
error: result.error || 'Billing error',
isBillingError: true,
};
return;
}
// validate model usage
const agentModels = conversation.workflow.agents.reduce((acc, agent) => {
acc.push(agent.model);
return acc;
@ -111,46 +126,50 @@ export class RunConversationTurnUseCase implements IRunConversationTurnUseCase {
conversation.workflow.mockTools = data.input.mockTools;
}
// init usage tracker
const usageTracker = new UsageTracker();
// call agents runtime and handle generated messages
const outputMessages: z.infer<typeof Message>[] = [];
for await (const event of streamResponse(projectId, conversation.workflow, inputMessages)) {
// handle msg events
if ("role" in event) {
// collect generated message
const msg = {
...event,
timestamp: new Date().toISOString(),
};
outputMessages.push(msg);
try {
const outputMessages: z.infer<typeof Message>[] = [];
for await (const event of streamResponse(projectId, conversation.workflow, inputMessages, usageTracker)) {
// handle msg events
if ("role" in event) {
// collect generated message
const msg = {
...event,
timestamp: new Date().toISOString(),
};
outputMessages.push(msg);
// yield event
yield {
type: "message",
data: msg,
};
} else {
// save turn data
const turn = await this.conversationsRepository.addTurn(data.conversationId, {
reason: data.reason,
input: data.input,
output: outputMessages,
});
// yield event
yield {
type: "done",
turn,
conversationId,
// yield event
yield {
type: "message",
data: msg,
};
}
}
}
// Log billing usage
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
type: "agent_messages",
amount: outputMessages.length,
// save turn data
const turn = await this.conversationsRepository.addTurn(data.conversationId, {
reason: data.reason,
input: data.input,
output: outputMessages,
});
// yield event
yield {
type: "done",
turn,
conversationId,
}
} finally {
// Log billing usage
if (USE_BILLING && billingCustomerId) {
await logUsage(billingCustomerId, {
items: usageTracker.flush(),
});
}
}
}
}