mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-05-08 23:02:41 +02:00
billing + credits updates (#202)
This commit is contained in:
parent
852e02e49e
commit
eccfb4748f
13 changed files with 497 additions and 229 deletions
|
|
@ -15,6 +15,7 @@ import { WithStringId } from "../lib/types/types";
|
|||
import { getEditAgentInstructionsResponse } from "../lib/copilot/copilot";
|
||||
import { container } from "@/di/container";
|
||||
import { IUsageQuotaPolicy } from "@/src/application/policies/usage-quota.policy.interface";
|
||||
import { UsageTracker } from "../lib/billing";
|
||||
|
||||
const usageQuotaPolicy = container.resolve<IUsageQuotaPolicy>('usageQuotaPolicy');
|
||||
|
||||
|
|
@ -32,8 +33,7 @@ export async function getCopilotResponseStream(
|
|||
|
||||
// Check billing authorization
|
||||
const authResponse = await authorizeUserAction({
|
||||
type: 'copilot_request',
|
||||
data: {},
|
||||
type: 'use_credits',
|
||||
});
|
||||
if (!authResponse.success) {
|
||||
return { billingError: authResponse.error || 'Billing error' };
|
||||
|
|
@ -75,8 +75,7 @@ export async function getCopilotAgentInstructions(
|
|||
|
||||
// Check billing authorization
|
||||
const authResponse = await authorizeUserAction({
|
||||
type: 'copilot_request',
|
||||
data: {},
|
||||
type: 'use_credits',
|
||||
});
|
||||
if (!authResponse.success) {
|
||||
return { billingError: authResponse.error || 'Billing error' };
|
||||
|
|
@ -93,8 +92,11 @@ export async function getCopilotAgentInstructions(
|
|||
}
|
||||
};
|
||||
|
||||
const usageTracker = new UsageTracker();
|
||||
|
||||
// call copilot api
|
||||
const agent_instructions = await getEditAgentInstructionsResponse(
|
||||
usageTracker,
|
||||
projectId,
|
||||
request.context,
|
||||
request.messages,
|
||||
|
|
@ -104,8 +106,7 @@ export async function getCopilotAgentInstructions(
|
|||
// log the billing usage
|
||||
if (USE_BILLING) {
|
||||
await logUsage({
|
||||
type: 'copilot_requests',
|
||||
amount: 1,
|
||||
items: usageTracker.flush(),
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { getCustomerIdForProject, logUsage } from "@/app/lib/billing";
|
||||
import { getCustomerIdForProject, logUsage, UsageTracker } from "@/app/lib/billing";
|
||||
import { USE_BILLING } from "@/app/lib/feature_flags";
|
||||
import { redisClient } from "@/app/lib/redis";
|
||||
import { CopilotAPIRequest } from "@/app/lib/types/copilot_types";
|
||||
|
|
@ -21,6 +21,7 @@ export async function GET(request: Request, props: { params: Promise<{ streamId:
|
|||
billingCustomerId = await getCustomerIdForProject(projectId);
|
||||
}
|
||||
|
||||
const usageTracker = new UsageTracker();
|
||||
const encoder = new TextEncoder();
|
||||
let messageCount = 0;
|
||||
|
||||
|
|
@ -29,6 +30,7 @@ export async function GET(request: Request, props: { params: Promise<{ streamId:
|
|||
try {
|
||||
// Iterate over the copilot stream generator
|
||||
for await (const event of streamMultiAgentResponse(
|
||||
usageTracker,
|
||||
projectId,
|
||||
context,
|
||||
messages,
|
||||
|
|
@ -49,21 +51,20 @@ export async function GET(request: Request, props: { params: Promise<{ streamId:
|
|||
}
|
||||
|
||||
controller.close();
|
||||
|
||||
// increment copilot request count in billing
|
||||
} catch (error) {
|
||||
console.error('Error processing copilot stream:', error);
|
||||
controller.error(error);
|
||||
} finally {
|
||||
// log copilot usage
|
||||
if (USE_BILLING && billingCustomerId) {
|
||||
try {
|
||||
await logUsage(billingCustomerId, {
|
||||
type: "copilot_requests",
|
||||
amount: 1,
|
||||
items: usageTracker.flush(),
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error logging usage", error);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error processing copilot stream:', error);
|
||||
controller.error(error);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
|
|
|||
|
|
@ -220,10 +220,10 @@ export async function POST(
|
|||
// log billing usage
|
||||
if (USE_BILLING && billingCustomerId) {
|
||||
const agentMessageCount = convertedResponseMessages.filter(m => m.role === 'assistant').length;
|
||||
await logUsage(billingCustomerId, {
|
||||
type: 'agent_messages',
|
||||
amount: agentMessageCount,
|
||||
});
|
||||
// await logUsage(billingCustomerId, {
|
||||
// type: 'agent_messages',
|
||||
// amount: agentMessageCount,
|
||||
// });
|
||||
}
|
||||
|
||||
logger.log(`Turn processing completed successfully`);
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import { Progress, Badge, Chip } from "@heroui/react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Label } from "@/app/lib/components/label";
|
||||
import { Customer, UsageResponse, UsageType } from "@/app/lib/types/billing_types";
|
||||
import { Customer, UsageResponse } from "@/app/lib/types/billing_types";
|
||||
import { z } from "zod";
|
||||
import { tokens } from "@/app/styles/design-tokens";
|
||||
import { SectionHeading } from "@/components/ui/section-heading";
|
||||
|
|
@ -47,6 +47,15 @@ export function BillingPage({ customer, usage }: BillingPageProps) {
|
|||
const displayStatus = getDisplayStatus(customer.subscriptionStatus);
|
||||
const planInfo = planDetails[plan];
|
||||
|
||||
// Prepare usage metrics data
|
||||
const usageData = Object.entries(usage.usage)
|
||||
.map(([type, credits]) => ({
|
||||
type,
|
||||
credits,
|
||||
totalUsedCredits: usage.sanctionedCredits - usage.availableCredits
|
||||
}))
|
||||
.sort((a, b) => b.credits - a.credits);
|
||||
|
||||
async function handleManageSubscription() {
|
||||
const returnUrl = new URL('/billing/callback', window.location.origin);
|
||||
returnUrl.searchParams.set('redirect', window.location.href);
|
||||
|
|
@ -109,48 +118,175 @@ export function BillingPage({ customer, usage }: BillingPageProps) {
|
|||
</div>
|
||||
</section>
|
||||
|
||||
{/* Usage Metrics Panel */}
|
||||
{/* Credits Overview Panel */}
|
||||
<section className="card">
|
||||
<div className="px-4 pt-4 pb-6">
|
||||
<SectionHeading>
|
||||
Usage Metrics
|
||||
Credits Overview
|
||||
</SectionHeading>
|
||||
</div>
|
||||
<HorizontalDivider />
|
||||
<div className="p-6 space-y-6">
|
||||
{Object.entries(usage.usage).map(([type, { usage: used, total }]) => {
|
||||
const usageType = type as z.infer<typeof UsageType>;
|
||||
const percentage = Math.min((used / total) * 100, 100);
|
||||
const isOverLimit = used > total;
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 gap-6">
|
||||
<div className="space-y-2">
|
||||
<Label label="Sanctioned Credits" />
|
||||
<p className={clsx(
|
||||
tokens.typography.sizes.lg,
|
||||
tokens.typography.weights.semibold,
|
||||
tokens.colors.light.text.primary,
|
||||
tokens.colors.dark.text.primary
|
||||
)}>
|
||||
{usage.sanctionedCredits.toLocaleString()}
|
||||
</p>
|
||||
<p className={clsx(
|
||||
tokens.typography.sizes.sm,
|
||||
tokens.colors.light.text.secondary,
|
||||
tokens.colors.dark.text.secondary
|
||||
)}>
|
||||
Total credits allocated to your plan
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label label="Used Credits" />
|
||||
<p className={clsx(
|
||||
tokens.typography.sizes.lg,
|
||||
tokens.typography.weights.semibold,
|
||||
tokens.colors.light.text.primary,
|
||||
tokens.colors.dark.text.primary
|
||||
)}>
|
||||
{(usage.sanctionedCredits - usage.availableCredits).toLocaleString()}
|
||||
</p>
|
||||
<p className={clsx(
|
||||
tokens.typography.sizes.sm,
|
||||
tokens.colors.light.text.secondary,
|
||||
tokens.colors.dark.text.secondary
|
||||
)}>
|
||||
Credits consumed so far
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label label="Available Credits" />
|
||||
<p className={clsx(
|
||||
tokens.typography.sizes.lg,
|
||||
tokens.typography.weights.semibold,
|
||||
usage.availableCredits < 0 ? "text-red-500" : clsx(
|
||||
tokens.colors.light.text.primary,
|
||||
tokens.colors.dark.text.primary
|
||||
)
|
||||
)}>
|
||||
{usage.availableCredits.toLocaleString()}
|
||||
</p>
|
||||
<p className={clsx(
|
||||
tokens.typography.sizes.sm,
|
||||
tokens.colors.light.text.secondary,
|
||||
tokens.colors.dark.text.secondary
|
||||
)}>
|
||||
Credits remaining for use
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
return (
|
||||
<div key={type} className="space-y-2">
|
||||
<div className="flex justify-between items-center">
|
||||
<div className="space-y-1">
|
||||
<Label label={type.replace(/_/g, ' ')} />
|
||||
<p className={clsx(
|
||||
{/* Warning for negative credits */}
|
||||
{usage.availableCredits < 0 && (
|
||||
<div className="p-4 bg-red-50 dark:bg-red-900/20 border border-red-200 dark:border-red-800 rounded-lg">
|
||||
<p className={clsx(
|
||||
tokens.typography.sizes.sm,
|
||||
"text-red-700 dark:text-red-300"
|
||||
)}>
|
||||
⚠️ You have exceeded your credit limit. Please upgrade your plan or contact support to avoid service interruptions.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Warning for high credit usage (>80%) */}
|
||||
{usage.availableCredits >= 0 && ((usage.sanctionedCredits - usage.availableCredits) / usage.sanctionedCredits) > 0.8 && (
|
||||
<div className="p-4 bg-yellow-50 dark:bg-yellow-900/20 border border-yellow-200 dark:border-yellow-800 rounded-lg">
|
||||
<p className={clsx(
|
||||
tokens.typography.sizes.sm,
|
||||
"text-yellow-700 dark:text-yellow-300"
|
||||
)}>
|
||||
⚠️ You have used more than 80% of your credits. Consider upgrading your plan to avoid interruptions.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Credits Progress Bar */}
|
||||
<div className="space-y-2">
|
||||
<div className="flex justify-between items-center">
|
||||
<Label label="Credits Usage" />
|
||||
<span className={clsx(
|
||||
tokens.typography.sizes.sm,
|
||||
tokens.colors.light.text.secondary,
|
||||
tokens.colors.dark.text.secondary
|
||||
)}>
|
||||
{Math.round(((usage.sanctionedCredits - usage.availableCredits) / usage.sanctionedCredits) * 100)}%
|
||||
</span>
|
||||
</div>
|
||||
<Progress
|
||||
size="lg"
|
||||
value={((usage.sanctionedCredits - usage.availableCredits) / usage.sanctionedCredits) * 100}
|
||||
color={usage.availableCredits < 0 ? "danger" : "primary"}
|
||||
className="h-4"
|
||||
aria-label="Credits usage"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
{/* Usage Metrics Panel */}
|
||||
<section className="card">
|
||||
<div className="px-4 pt-4 pb-6">
|
||||
<SectionHeading>
|
||||
Usage data
|
||||
</SectionHeading>
|
||||
</div>
|
||||
<HorizontalDivider />
|
||||
<div className="p-6 space-y-6">
|
||||
{usageData.length === 0 ? (
|
||||
<div className="text-center py-8">
|
||||
<p className={clsx(
|
||||
tokens.typography.sizes.sm,
|
||||
tokens.colors.light.text.secondary,
|
||||
tokens.colors.dark.text.secondary
|
||||
)}>
|
||||
No usage data yet
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
usageData.map(({ type, credits, totalUsedCredits }) => {
|
||||
const percentage = totalUsedCredits > 0 ? (credits / totalUsedCredits) * 100 : 0;
|
||||
|
||||
return (
|
||||
<div key={type} className="space-y-2">
|
||||
<div className="flex justify-between items-center">
|
||||
<div className="space-y-1">
|
||||
<Label label={type.replace(/_/g, ' ')} />
|
||||
<p className={clsx(
|
||||
tokens.typography.sizes.sm,
|
||||
tokens.colors.light.text.secondary,
|
||||
tokens.colors.dark.text.secondary
|
||||
)}>
|
||||
{credits.toLocaleString()} credits
|
||||
</p>
|
||||
</div>
|
||||
<span className={clsx(
|
||||
tokens.typography.sizes.sm,
|
||||
tokens.colors.light.text.secondary,
|
||||
tokens.colors.dark.text.secondary
|
||||
)}>
|
||||
{used.toLocaleString()} / {total.toLocaleString()}
|
||||
</p>
|
||||
{Math.round(percentage)}%
|
||||
</span>
|
||||
</div>
|
||||
{isOverLimit && (
|
||||
<Badge color="danger" variant="flat">
|
||||
Over Limit
|
||||
</Badge>
|
||||
)}
|
||||
<Progress
|
||||
value={percentage}
|
||||
color="default"
|
||||
className="h-2"
|
||||
aria-label={`${type} credits usage`}
|
||||
/>
|
||||
</div>
|
||||
<Progress
|
||||
value={percentage}
|
||||
color={isOverLimit ? "danger" : "primary"}
|
||||
className="h-2"
|
||||
aria-label={`${type} usage`}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
);
|
||||
})
|
||||
)}
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import { qdrantClient } from '../lib/qdrant';
|
|||
import { EmbeddingRecord } from "./types/datasource_types";
|
||||
import { WorkflowAgent, WorkflowTool } from "./types/workflow_types";
|
||||
import { PrefixLogger } from "./utils";
|
||||
import { UsageTracker } from "./billing";
|
||||
|
||||
// Provider configuration
|
||||
const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
|
||||
|
|
@ -30,6 +31,7 @@ const openai = createOpenAI({
|
|||
// Helper to handle mock tool responses
|
||||
export async function invokeMockTool(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
toolName: string,
|
||||
args: string,
|
||||
description: string,
|
||||
|
|
@ -49,18 +51,28 @@ export async function invokeMockTool(
|
|||
content: `Generate a realistic response for the tool '${toolName}' with these parameters: ${args}. The response should be concise and focused on what the tool would actually return.`
|
||||
}];
|
||||
|
||||
const { text } = await generateText({
|
||||
const { text, usage } = await generateText({
|
||||
model: openai(MODEL),
|
||||
messages,
|
||||
});
|
||||
logger.log(`generated text: ${text}`);
|
||||
|
||||
// track usage
|
||||
usageTracker.track({
|
||||
type: "LLM_USAGE",
|
||||
modelName: MODEL,
|
||||
inputTokens: usage.promptTokens,
|
||||
outputTokens: usage.completionTokens,
|
||||
context: "agents_runtime.mock_tool",
|
||||
});
|
||||
|
||||
return text;
|
||||
}
|
||||
|
||||
// Helper to handle RAG tool calls
|
||||
export async function invokeRagTool(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
projectId: string,
|
||||
query: string,
|
||||
sourceIds: string[],
|
||||
|
|
@ -81,11 +93,21 @@ export async function invokeRagTool(
|
|||
logger.log(`k: ${k}`);
|
||||
|
||||
// Create embedding for question
|
||||
const { embedding } = await embed({
|
||||
const { embedding, usage } = await embed({
|
||||
model: embeddingModel,
|
||||
value: query,
|
||||
});
|
||||
|
||||
// track usage
|
||||
|
||||
// track usage
|
||||
usageTracker.track({
|
||||
type: "EMBEDDING_MODEL_USAGE",
|
||||
modelName: embeddingModel.modelId,
|
||||
tokens: usage.tokens,
|
||||
context: "agents_runtime.rag_tool.embedding_usage",
|
||||
});
|
||||
|
||||
// Fetch all data sources for this project
|
||||
const sources = await dataSourcesCollection.find({
|
||||
projectId: projectId,
|
||||
|
|
@ -154,6 +176,7 @@ export async function invokeRagTool(
|
|||
|
||||
export async function invokeWebhookTool(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
projectId: string,
|
||||
name: string,
|
||||
input: any,
|
||||
|
|
@ -233,6 +256,7 @@ export async function invokeWebhookTool(
|
|||
// Helper to handle MCP tool calls
|
||||
export async function invokeMcpTool(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
projectId: string,
|
||||
name: string,
|
||||
input: any,
|
||||
|
|
@ -269,6 +293,7 @@ export async function invokeMcpTool(
|
|||
// Helper to handle composio tool calls
|
||||
export async function invokeComposioTool(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
projectId: string,
|
||||
name: string,
|
||||
composioData: z.infer<typeof WorkflowTool>['composioData'] & {},
|
||||
|
|
@ -299,12 +324,21 @@ export async function invokeComposioTool(
|
|||
connectedAccountId: connectedAccountId,
|
||||
});
|
||||
logger.log(`composio tool result: ${JSON.stringify(result)}`);
|
||||
|
||||
// track usage
|
||||
usageTracker.track({
|
||||
type: "COMPOSIO_TOOL_USAGE",
|
||||
toolSlug: slug,
|
||||
context: "agents_runtime.composio_tool",
|
||||
});
|
||||
|
||||
return result.data;
|
||||
}
|
||||
|
||||
// Helper to create RAG tool
|
||||
export function createRagTool(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
config: z.infer<typeof WorkflowAgent>,
|
||||
projectId: string
|
||||
): Tool {
|
||||
|
|
@ -321,6 +355,7 @@ export function createRagTool(
|
|||
async execute(input: { query: string }) {
|
||||
const results = await invokeRagTool(
|
||||
logger,
|
||||
usageTracker,
|
||||
projectId,
|
||||
input.query,
|
||||
config.ragDataSources || [],
|
||||
|
|
@ -337,6 +372,7 @@ export function createRagTool(
|
|||
// Helper to create a mock tool
|
||||
export function createMockTool(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
config: z.infer<typeof WorkflowTool>,
|
||||
): Tool {
|
||||
return tool({
|
||||
|
|
@ -353,6 +389,7 @@ export function createMockTool(
|
|||
try {
|
||||
const result = await invokeMockTool(
|
||||
logger,
|
||||
usageTracker,
|
||||
config.name,
|
||||
JSON.stringify(input),
|
||||
config.description,
|
||||
|
|
@ -374,6 +411,7 @@ export function createMockTool(
|
|||
// Helper to create a webhook tool
|
||||
export function createWebhookTool(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
config: z.infer<typeof WorkflowTool>,
|
||||
projectId: string,
|
||||
): Tool {
|
||||
|
|
@ -391,7 +429,7 @@ export function createWebhookTool(
|
|||
},
|
||||
async execute(input: any) {
|
||||
try {
|
||||
const result = await invokeWebhookTool(logger, projectId, name, input);
|
||||
const result = await invokeWebhookTool(logger, usageTracker, projectId, name, input);
|
||||
return JSON.stringify({
|
||||
result,
|
||||
});
|
||||
|
|
@ -408,6 +446,7 @@ export function createWebhookTool(
|
|||
// Helper to create an mcp tool
|
||||
export function createMcpTool(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
config: z.infer<typeof WorkflowTool>,
|
||||
projectId: string
|
||||
): Tool {
|
||||
|
|
@ -425,7 +464,7 @@ export function createMcpTool(
|
|||
},
|
||||
async execute(input: any) {
|
||||
try {
|
||||
const result = await invokeMcpTool(logger, projectId, name, input, mcpServerName || '');
|
||||
const result = await invokeMcpTool(logger, usageTracker, projectId, name, input, mcpServerName || '');
|
||||
return JSON.stringify({
|
||||
result,
|
||||
});
|
||||
|
|
@ -442,6 +481,7 @@ export function createMcpTool(
|
|||
// Helper to create a composio tool
|
||||
export function createComposioTool(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
config: z.infer<typeof WorkflowTool>,
|
||||
projectId: string
|
||||
): Tool {
|
||||
|
|
@ -463,7 +503,7 @@ export function createComposioTool(
|
|||
},
|
||||
async execute(input: any) {
|
||||
try {
|
||||
const result = await invokeComposioTool(logger, projectId, name, composioData, input);
|
||||
const result = await invokeComposioTool(logger, usageTracker, projectId, name, composioData, input);
|
||||
return JSON.stringify({
|
||||
result,
|
||||
});
|
||||
|
|
@ -479,6 +519,7 @@ export function createComposioTool(
|
|||
|
||||
export function createTools(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
projectId: string,
|
||||
workflow: { tools: z.infer<typeof WorkflowTool>[] },
|
||||
toolConfig: Record<string, z.infer<typeof WorkflowTool>>,
|
||||
|
|
@ -492,16 +533,16 @@ export function createTools(
|
|||
toolLogger.log(`creating tool: ${toolName} (type: ${config.mockTool ? 'mock' : config.isMcp ? 'mcp' : config.isComposio ? 'composio' : 'webhook'})`);
|
||||
|
||||
if (config.mockTool) {
|
||||
tools[toolName] = createMockTool(logger, config);
|
||||
tools[toolName] = createMockTool(logger, usageTracker, config);
|
||||
toolLogger.log(`✓ created mock tool: ${toolName}`);
|
||||
} else if (config.isMcp) {
|
||||
tools[toolName] = createMcpTool(logger, config, projectId);
|
||||
tools[toolName] = createMcpTool(logger, usageTracker, config, projectId);
|
||||
toolLogger.log(`✓ created mcp tool: ${toolName} (server: ${config.mcpServerName || 'unknown'})`);
|
||||
} else if (config.isComposio) {
|
||||
tools[toolName] = createComposioTool(logger, config, projectId);
|
||||
tools[toolName] = createComposioTool(logger, usageTracker, config, projectId);
|
||||
toolLogger.log(`✓ created composio tool: ${toolName}`);
|
||||
} else {
|
||||
tools[toolName] = createWebhookTool(logger, config, projectId);
|
||||
tools[toolName] = createWebhookTool(logger, usageTracker, config, projectId);
|
||||
toolLogger.log(`✓ created webhook tool: ${toolName} (fallback)`);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ import { ConnectedEntity, sanitizeTextWithMentions, Workflow, WorkflowAgent, Wor
|
|||
import { CHILD_TRANSFER_RELATED_INSTRUCTIONS, CONVERSATION_TYPE_INSTRUCTIONS, PIPELINE_TYPE_INSTRUCTIONS, RAG_INSTRUCTIONS, TASK_TYPE_INSTRUCTIONS } from "./agent_instructions";
|
||||
import { PrefixLogger } from "./utils";
|
||||
import { Message, AssistantMessage, AssistantMessageWithToolCalls, ToolMessage } from "./types/types";
|
||||
import { UsageTracker } from "./billing";
|
||||
|
||||
// Native handoff support
|
||||
import { createAgentHandoff, getSchemaForAgent, createContextFilterForAgent } from "./agent-handoffs";
|
||||
import { PipelineStateManager } from "./pipeline-state-manager";
|
||||
|
|
@ -78,14 +80,6 @@ const openai = createOpenAI({
|
|||
baseURL: PROVIDER_BASE_URL,
|
||||
});
|
||||
|
||||
const ZUsage = z.object({
|
||||
tokens: z.object({
|
||||
total: z.number(),
|
||||
prompt: z.number(),
|
||||
completion: z.number(),
|
||||
}),
|
||||
});
|
||||
|
||||
const ZOutMessage = z.union([
|
||||
AssistantMessage,
|
||||
AssistantMessageWithToolCalls,
|
||||
|
|
@ -95,6 +89,7 @@ const ZOutMessage = z.union([
|
|||
// Helper to create an agent
|
||||
function createAgent(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
projectId: string,
|
||||
config: z.infer<typeof WorkflowAgent>,
|
||||
tools: Record<string, Tool>,
|
||||
|
|
@ -145,7 +140,7 @@ ${CHILD_TRANSFER_RELATED_INSTRUCTIONS}
|
|||
|
||||
// Add RAG tool if needed
|
||||
if (config.ragDataSources?.length) {
|
||||
const ragTool = createRagTool(logger, config, projectId);
|
||||
const ragTool = createRagTool(logger, usageTracker, config, projectId);
|
||||
agentTools.push(ragTool);
|
||||
|
||||
// update instructions to include RAG instructions
|
||||
|
|
@ -269,8 +264,8 @@ function getStartOfTurnAgentName(
|
|||
// Logs an event and then yields it
|
||||
async function* emitEvent(
|
||||
logger: PrefixLogger,
|
||||
event: z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>,
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
|
||||
event: z.infer<typeof ZOutMessage>,
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage>> {
|
||||
logger.log(`-> emitting event: ${JSON.stringify(event)}`);
|
||||
yield event;
|
||||
return;
|
||||
|
|
@ -321,30 +316,6 @@ class AgentTransferCounter {
|
|||
}
|
||||
}
|
||||
|
||||
class UsageTracker {
|
||||
private usage: {
|
||||
total: number;
|
||||
prompt: number;
|
||||
completion: number;
|
||||
} = { total: 0, prompt: 0, completion: 0 };
|
||||
|
||||
increment(total: number, prompt: number, completion: number): void {
|
||||
this.usage.total += total;
|
||||
this.usage.prompt += prompt;
|
||||
this.usage.completion += completion;
|
||||
}
|
||||
|
||||
get(): { total: number, prompt: number, completion: number } {
|
||||
return this.usage;
|
||||
}
|
||||
|
||||
asEvent(): z.infer<typeof ZUsage> {
|
||||
return {
|
||||
tokens: this.usage,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
function ensureSystemMessage(logger: PrefixLogger, messages: z.infer<typeof Message>[]) {
|
||||
logger = logger.child(`ensureSystemMessage`);
|
||||
|
||||
|
|
@ -396,7 +367,7 @@ function mapConfig(workflow: z.infer<typeof Workflow>): {
|
|||
return { agentConfig, toolConfig, promptConfig, pipelineConfig };
|
||||
}
|
||||
|
||||
async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer<typeof Workflow>): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
|
||||
async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer<typeof Workflow>): AsyncIterable<z.infer<typeof ZOutMessage>> {
|
||||
// find the greeting prompt
|
||||
const prompt = workflow.prompts.find(p => p.type === 'greeting')?.prompt || 'How can I help you today?';
|
||||
logger.log(`greeting turn: ${prompt}`);
|
||||
|
|
@ -408,15 +379,13 @@ async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer<typeof
|
|||
agentName: workflow.startAgent,
|
||||
responseType: 'external',
|
||||
});
|
||||
|
||||
// emit final usage information
|
||||
yield* emitEvent(logger, new UsageTracker().asEvent());
|
||||
}
|
||||
|
||||
|
||||
// Enhanced agent creation with native handoff support
|
||||
function createAgentsWithNativeHandoffs(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
projectId: string,
|
||||
workflow: z.infer<typeof Workflow>,
|
||||
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
|
||||
|
|
@ -447,6 +416,7 @@ function createAgentsWithNativeHandoffs(
|
|||
|
||||
const { agent, entities } = createAgent(
|
||||
logger,
|
||||
usageTracker,
|
||||
projectId,
|
||||
config,
|
||||
tools,
|
||||
|
|
@ -560,6 +530,7 @@ function createAgentsWithNativeHandoffs(
|
|||
// Legacy agent creation (existing implementation)
|
||||
function createAgentsLegacy(
|
||||
logger: PrefixLogger,
|
||||
usageTracker: UsageTracker,
|
||||
projectId: string,
|
||||
workflow: z.infer<typeof Workflow>,
|
||||
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
|
||||
|
|
@ -595,6 +566,7 @@ function createAgentsLegacy(
|
|||
|
||||
const { agent, entities } = createAgent(
|
||||
logger,
|
||||
usageTracker,
|
||||
projectId,
|
||||
config,
|
||||
tools,
|
||||
|
|
@ -763,12 +735,13 @@ function maybeInjectGiveUpControlInstructions(
|
|||
// Handle raw model stream events
|
||||
async function* handleRawModelStreamEvent(
|
||||
event: any,
|
||||
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
|
||||
agentName: string,
|
||||
turnMsgs: z.infer<typeof Message>[],
|
||||
usageTracker: UsageTracker,
|
||||
eventLogger: PrefixLogger,
|
||||
getAgentState?: (agentName: string) => AgentState
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage>> {
|
||||
if (event.data.type === 'response_done') {
|
||||
// Count tool calls (excluding transfer_to_* calls)
|
||||
const toolCallCount = event.data.response.output.filter(
|
||||
|
|
@ -809,12 +782,13 @@ async function* handleRawModelStreamEvent(
|
|||
}
|
||||
|
||||
// update usage information
|
||||
usageTracker.increment(
|
||||
event.data.response.usage.totalTokens,
|
||||
event.data.response.usage.inputTokens,
|
||||
event.data.response.usage.outputTokens
|
||||
);
|
||||
eventLogger.log(`updated usage information: ${JSON.stringify(usageTracker.get())}`);
|
||||
usageTracker.track({
|
||||
type: "LLM_USAGE",
|
||||
modelName: agentConfig[agentName]?.model || "unknown",
|
||||
inputTokens: event.data.response.usage.inputTokens,
|
||||
outputTokens: event.data.response.usage.outputTokens,
|
||||
context: "agents_runtime.llm_usage",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -833,7 +807,7 @@ async function* handleNativeHandoffEvent(
|
|||
originalHandoffs: Record<string, any[]>,
|
||||
eventLogger: PrefixLogger,
|
||||
loopLogger: PrefixLogger
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage> | { newAgentName: string; shouldContinue?: boolean }> {
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage> | { newAgentName: string; shouldContinue?: boolean }> {
|
||||
eventLogger.log(`🔄 NATIVE HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`);
|
||||
|
||||
// skip if its the same agent
|
||||
|
|
@ -943,7 +917,7 @@ async function* handleHandoffEvent(
|
|||
originalHandoffs: Record<string, Agent[]>,
|
||||
eventLogger: PrefixLogger,
|
||||
loopLogger: PrefixLogger
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage> | { newAgentName: string }> {
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage> | { newAgentName: string }> {
|
||||
eventLogger.log(`🔄 HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`);
|
||||
|
||||
// skip if its the same agent
|
||||
|
|
@ -1010,7 +984,7 @@ async function* handleToolCallResult(
|
|||
event: any,
|
||||
turnMsgs: z.infer<typeof Message>[],
|
||||
eventLogger: PrefixLogger
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage>> {
|
||||
const m: z.infer<typeof Message> = {
|
||||
role: 'tool',
|
||||
content: event.item.rawItem.output.text,
|
||||
|
|
@ -1039,7 +1013,7 @@ async function* handleMessageOutput(
|
|||
eventLogger: PrefixLogger,
|
||||
loopLogger: PrefixLogger,
|
||||
getAgentState: (agentName: string) => AgentState
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage> | { newAgentName: string | null; shouldContinue: boolean }> {
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage> | { newAgentName: string | null; shouldContinue: boolean }> {
|
||||
// check response visibility - could be an agent or pipeline
|
||||
const agentConfigObj = agentConfig[agentName];
|
||||
const pipelineConfigObj = pipelineConfig[agentName];
|
||||
|
|
@ -1243,7 +1217,8 @@ export async function* streamResponse(
|
|||
projectId: string,
|
||||
workflow: z.infer<typeof Workflow>,
|
||||
messages: z.infer<typeof Message>[],
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
|
||||
usageTracker: UsageTracker,
|
||||
): AsyncIterable<z.infer<typeof ZOutMessage>> {
|
||||
// Divider log for tracking agent loop start
|
||||
console.log('-------------------- AGENT LOOP START --------------------');
|
||||
// set up logging
|
||||
|
|
@ -1275,11 +1250,11 @@ export async function* streamResponse(
|
|||
logger.log(`initialized stack: ${JSON.stringify(stack)}`);
|
||||
|
||||
// create tools
|
||||
const tools = createTools(logger, projectId, workflow, toolConfig);
|
||||
const tools = createTools(logger, usageTracker, projectId, workflow, toolConfig);
|
||||
|
||||
// create agents with feature flag support
|
||||
const createAgentsFunction = USE_NATIVE_HANDOFFS ? createAgentsWithNativeHandoffs : createAgentsLegacy;
|
||||
const { agents, originalInstructions, originalHandoffs } = createAgentsFunction(logger, projectId, workflow, agentConfig, tools, promptConfig, pipelineConfig);
|
||||
const { agents, originalInstructions, originalHandoffs } = createAgentsFunction(logger, usageTracker, projectId, workflow, agentConfig, tools, promptConfig, pipelineConfig);
|
||||
|
||||
logger.log(`Using ${USE_NATIVE_HANDOFFS ? 'NATIVE SDK' : 'LEGACY'} handoffs`);
|
||||
|
||||
|
|
@ -1296,7 +1271,6 @@ export async function* streamResponse(
|
|||
let agentName: string | null = startOfTurnAgentName;
|
||||
|
||||
// start the turn loop
|
||||
const usageTracker = new UsageTracker();
|
||||
const turnMsgs: z.infer<typeof Message>[] = [...messages];
|
||||
|
||||
// Initialize agent state tracking for tool call completion
|
||||
|
|
@ -1390,7 +1364,7 @@ export async function* streamResponse(
|
|||
|
||||
switch (event.type) {
|
||||
case 'raw_model_stream_event':
|
||||
yield* handleRawModelStreamEvent(event, agentName!, turnMsgs, usageTracker, eventLogger, getAgentState);
|
||||
yield* handleRawModelStreamEvent(event, agentConfig, agentName!, turnMsgs, usageTracker, eventLogger, getAgentState);
|
||||
break;
|
||||
|
||||
case 'run_item_stream_event':
|
||||
|
|
@ -1523,9 +1497,6 @@ export async function* streamResponse(
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
// emit usage information
|
||||
yield* emitEvent(logger, usageTracker.asEvent());
|
||||
}
|
||||
|
||||
// this is a sync version of streamResponse
|
||||
|
|
@ -1535,8 +1506,10 @@ export async function getResponse(
|
|||
messages: z.infer<typeof Message>[],
|
||||
): Promise<{
|
||||
messages: z.infer<typeof ZOutMessage>[],
|
||||
usage: z.infer<typeof ZUsage>,
|
||||
usage: any,
|
||||
}> {
|
||||
throw new Error("Not implemented!");
|
||||
/*
|
||||
const out: z.infer<typeof ZOutMessage>[] = [];
|
||||
let usage: z.infer<typeof ZUsage> = {
|
||||
tokens: {
|
||||
|
|
@ -1554,4 +1527,5 @@ export async function getResponse(
|
|||
}
|
||||
}
|
||||
return { messages: out, usage };
|
||||
*/
|
||||
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import { WithStringId } from './types/types';
|
||||
import { z } from 'zod';
|
||||
import { Customer, AuthorizeRequest, AuthorizeResponse, LogUsageRequest, UsageResponse, CustomerPortalSessionResponse, PricesResponse, UpdateSubscriptionPlanRequest, UpdateSubscriptionPlanResponse, ModelsResponse } from './types/billing_types';
|
||||
import { Customer, AuthorizeRequest, AuthorizeResponse, LogUsageRequest, UsageResponse, CustomerPortalSessionResponse, PricesResponse, UpdateSubscriptionPlanRequest, UpdateSubscriptionPlanResponse, ModelsResponse, UsageItem } from './types/billing_types';
|
||||
import { ObjectId } from 'mongodb';
|
||||
import { projectsCollection, usersCollection } from './mongodb';
|
||||
import { redirect } from 'next/navigation';
|
||||
|
|
@ -23,6 +23,20 @@ const GUEST_BILLING_CUSTOMER = {
|
|||
updatedAt: new Date().toISOString(),
|
||||
};
|
||||
|
||||
export class UsageTracker{
|
||||
private items: z.infer<typeof UsageItem>[] = [];
|
||||
|
||||
track(item: z.infer<typeof UsageItem>) {
|
||||
this.items.push(item);
|
||||
}
|
||||
|
||||
flush(): z.infer<typeof UsageItem>[] {
|
||||
const items = this.items;
|
||||
this.items = [];
|
||||
return items;
|
||||
}
|
||||
}
|
||||
|
||||
export async function getCustomerIdForProject(projectId: string): Promise<string> {
|
||||
const project = await projectsCollection.findOne({ _id: projectId });
|
||||
if (!project) {
|
||||
|
|
@ -111,6 +125,7 @@ export async function authorize(customerId: string, request: z.infer<typeof Auth
|
|||
}
|
||||
|
||||
export async function logUsage(customerId: string, request: z.infer<typeof LogUsageRequest>) {
|
||||
console.log(`logging billing usage for customer ${customerId}`, JSON.stringify(request));
|
||||
const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/log-usage`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import { COPILOT_MULTI_AGENT_EXAMPLE_1 } from "./example_multi_agent_1";
|
|||
import { CURRENT_WORKFLOW_PROMPT } from "./current_workflow";
|
||||
import { USE_COMPOSIO_TOOLS } from "../feature_flags";
|
||||
import { composio, getTool } from "../composio/composio";
|
||||
import { UsageTracker } from "../billing";
|
||||
|
||||
const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
|
||||
const PROVIDER_BASE_URL = process.env.PROVIDER_BASE_URL || undefined;
|
||||
|
|
@ -119,7 +120,7 @@ ${JSON.stringify(simplifiedDataSources)}
|
|||
return prompt;
|
||||
}
|
||||
|
||||
async function searchRelevantTools(query: string): Promise<string> {
|
||||
async function searchRelevantTools(usageTracker: UsageTracker, query: string): Promise<string> {
|
||||
const logger = new PrefixLogger("copilot-search-tools");
|
||||
console.log("🔧 TOOL CALL: searchRelevantTools", { query });
|
||||
|
||||
|
|
@ -142,6 +143,13 @@ async function searchRelevantTools(query: string): Promise<string> {
|
|||
return 'No tools found!';
|
||||
}
|
||||
|
||||
// track composio search tool usage
|
||||
usageTracker.track({
|
||||
type: "COMPOSIO_TOOL_USAGE",
|
||||
toolSlug: "COMPOSIO_SEARCH_TOOLS",
|
||||
context: "copilot.search_relevant_tools",
|
||||
});
|
||||
|
||||
// parse results
|
||||
const result = composioToolSearchResponseSchema.safeParse(searchResult.data);
|
||||
if (!result.success) {
|
||||
|
|
@ -208,6 +216,7 @@ function updateLastUserMessage(
|
|||
}
|
||||
|
||||
export async function getEditAgentInstructionsResponse(
|
||||
usageTracker: UsageTracker,
|
||||
projectId: string,
|
||||
context: z.infer<typeof CopilotChatContext> | null,
|
||||
messages: z.infer<typeof CopilotMessage>[],
|
||||
|
|
@ -232,7 +241,7 @@ export async function getEditAgentInstructionsResponse(
|
|||
system: COPILOT_INSTRUCTIONS_EDIT_AGENT,
|
||||
messages: messages,
|
||||
}));
|
||||
const { object } = await generateObject({
|
||||
const { object, usage } = await generateObject({
|
||||
model: openai(COPILOT_MODEL),
|
||||
messages: [
|
||||
{
|
||||
|
|
@ -246,10 +255,20 @@ export async function getEditAgentInstructionsResponse(
|
|||
}),
|
||||
});
|
||||
|
||||
// log usage
|
||||
usageTracker.track({
|
||||
type: "LLM_USAGE",
|
||||
modelName: COPILOT_MODEL,
|
||||
inputTokens: usage.promptTokens,
|
||||
outputTokens: usage.completionTokens,
|
||||
context: "copilot.llm_usage",
|
||||
});
|
||||
|
||||
return object.agent_instructions;
|
||||
}
|
||||
|
||||
export async function* streamMultiAgentResponse(
|
||||
usageTracker: UsageTracker,
|
||||
projectId: string,
|
||||
context: z.infer<typeof CopilotChatContext> | null,
|
||||
messages: z.infer<typeof CopilotMessage>[],
|
||||
|
|
@ -297,7 +316,7 @@ export async function* streamMultiAgentResponse(
|
|||
}),
|
||||
execute: async ({ query }: { query: string }) => {
|
||||
console.log("🎯 AI TOOL CALL: search_relevant_tools", { query });
|
||||
const result = await searchRelevantTools(query);
|
||||
const result = await searchRelevantTools(usageTracker, query);
|
||||
console.log("✅ AI TOOL CALL COMPLETED: search_relevant_tools", {
|
||||
query,
|
||||
resultLength: result.length
|
||||
|
|
@ -341,6 +360,15 @@ export async function* streamMultiAgentResponse(
|
|||
toolCallId: event.toolCallId,
|
||||
result: event.result,
|
||||
};
|
||||
} else if (event.type === "step-finish") {
|
||||
// log usage
|
||||
usageTracker.track({
|
||||
type: "LLM_USAGE",
|
||||
modelName: COPILOT_MODEL,
|
||||
inputTokens: event.usage.promptTokens,
|
||||
outputTokens: event.usage.completionTokens,
|
||||
context: "copilot.llm_usage",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,12 +2,52 @@ import { z } from "zod";
|
|||
|
||||
export const SubscriptionPlan = z.enum(["free", "starter", "pro"]);
|
||||
|
||||
export const UsageType = z.enum([
|
||||
"copilot_requests",
|
||||
"agent_messages",
|
||||
"rag_tokens",
|
||||
export const UsageTypeKey = z.enum([
|
||||
"LLM_USAGE",
|
||||
"EMBEDDING_MODEL_USAGE",
|
||||
"COMPOSIO_TOOL_USAGE",
|
||||
"FIRECRAWL_SCRAPE_USAGE",
|
||||
]);
|
||||
|
||||
export const LLMUsage = z.object({
|
||||
type: z.literal(UsageTypeKey.Enum.LLM_USAGE),
|
||||
modelName: z.string(),
|
||||
inputTokens: z.number().positive(),
|
||||
outputTokens: z.number().positive(),
|
||||
context: z.string(),
|
||||
});
|
||||
|
||||
export const EmbeddingModelUsage = z.object({
|
||||
type: z.literal(UsageTypeKey.Enum.EMBEDDING_MODEL_USAGE),
|
||||
modelName: z.string(),
|
||||
tokens: z.number().positive(),
|
||||
context: z.string(),
|
||||
});
|
||||
|
||||
export const ComposioToolUsage = z.object({
|
||||
type: z.literal(UsageTypeKey.Enum.COMPOSIO_TOOL_USAGE),
|
||||
toolSlug: z.string(),
|
||||
context: z.string(),
|
||||
});
|
||||
|
||||
export const FirecrawlScrapeUsage = z.object({
|
||||
type: z.literal(UsageTypeKey.Enum.FIRECRAWL_SCRAPE_USAGE),
|
||||
context: z.string(),
|
||||
});
|
||||
|
||||
export const UsageItem = z.discriminatedUnion("type", [
|
||||
LLMUsage,
|
||||
EmbeddingModelUsage,
|
||||
ComposioToolUsage,
|
||||
FirecrawlScrapeUsage,
|
||||
]);
|
||||
|
||||
export const LogUsageRequest = z.object({
|
||||
items: z.array(UsageItem),
|
||||
});
|
||||
|
||||
export const CustomerUsageData = z.record(z.string(), z.number());
|
||||
|
||||
export const Customer = z.object({
|
||||
_id: z.string(),
|
||||
userId: z.string(),
|
||||
|
|
@ -19,36 +59,23 @@ export const Customer = z.object({
|
|||
createdAt: z.string().datetime(),
|
||||
updatedAt: z.string().datetime(),
|
||||
subscriptionPlanUpdatedAt: z.string().datetime().optional(),
|
||||
usage: z.record(UsageType, z.number()).optional(),
|
||||
usage: CustomerUsageData.optional(),
|
||||
usageUpdatedAt: z.string().datetime().optional(),
|
||||
});
|
||||
|
||||
export const LogUsageRequest = z.object({
|
||||
type: UsageType,
|
||||
amount: z.number().int().positive(),
|
||||
});
|
||||
creditsOverride: z.number().optional(),
|
||||
maxProjectsOverride: z.number().optional(),
|
||||
agentModelsOverride: z.array(z.string()).optional(),
|
||||
});
|
||||
|
||||
export const AuthorizeRequest = z.discriminatedUnion("type", [
|
||||
z.object({
|
||||
"type": z.literal("use_credits"),
|
||||
}),
|
||||
z.object({
|
||||
"type": z.literal("create_project"),
|
||||
"data": z.object({
|
||||
"existingProjectCount": z.number(),
|
||||
}),
|
||||
}),
|
||||
z.object({
|
||||
"type": z.literal("enable_hosted_tool_server"),
|
||||
"data": z.object({
|
||||
"existingServerCount": z.number(),
|
||||
}),
|
||||
}),
|
||||
z.object({
|
||||
"type": z.literal("process_rag"),
|
||||
"data": z.object({}),
|
||||
}),
|
||||
z.object({
|
||||
"type": z.literal("copilot_request"),
|
||||
"data": z.object({}),
|
||||
}),
|
||||
z.object({
|
||||
"type": z.literal("agent_response"),
|
||||
"data": z.object({
|
||||
|
|
@ -63,10 +90,9 @@ export const AuthorizeResponse = z.object({
|
|||
});
|
||||
|
||||
export const UsageResponse = z.object({
|
||||
usage: z.record(UsageType, z.object({
|
||||
usage: z.number(),
|
||||
total: z.number(),
|
||||
})),
|
||||
sanctionedCredits: z.number(),
|
||||
availableCredits: z.number(),
|
||||
usage: CustomerUsageData,
|
||||
});
|
||||
|
||||
export const CustomerPortalSessionRequest = z.object({
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import crypto from 'crypto';
|
|||
import path from 'path';
|
||||
import { createOpenAI } from '@ai-sdk/openai';
|
||||
import { USE_BILLING, USE_GEMINI_FILE_PARSING } from '../lib/feature_flags';
|
||||
import { authorize, getCustomerIdForProject, logUsage } from '../lib/billing';
|
||||
import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from '../lib/billing';
|
||||
import { BillingError } from '@/src/entities/errors/common';
|
||||
|
||||
const FILE_PARSING_PROVIDER_API_KEY = process.env.FILE_PARSING_PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
|
||||
|
|
@ -75,7 +75,7 @@ async function retryable<T>(fn: () => Promise<T>, maxAttempts: number = 3): Prom
|
|||
}
|
||||
}
|
||||
|
||||
async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } }): Promise<number> {
|
||||
async function runProcessPipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } }) {
|
||||
const logger = _logger
|
||||
.child(doc._id.toString())
|
||||
.child(doc.name);
|
||||
|
|
@ -95,7 +95,7 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
|||
if (!USE_GEMINI_FILE_PARSING) {
|
||||
// Use OpenAI to extract text content
|
||||
logger.log("Extracting content using OpenAI");
|
||||
const { text } = await generateText({
|
||||
const { text, usage } = await generateText({
|
||||
model: openai(FILE_PARSING_MODEL),
|
||||
system: extractPrompt,
|
||||
messages: [
|
||||
|
|
@ -112,6 +112,13 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
|||
],
|
||||
});
|
||||
markdown = text;
|
||||
usageTracker.track({
|
||||
type: "LLM_USAGE",
|
||||
modelName: FILE_PARSING_MODEL,
|
||||
inputTokens: usage.promptTokens,
|
||||
outputTokens: usage.completionTokens,
|
||||
context: "rag.files.llm_usage",
|
||||
});
|
||||
} else {
|
||||
// Use Gemini to extract text content
|
||||
logger.log("Extracting content using Gemini");
|
||||
|
|
@ -127,6 +134,13 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
|||
extractPrompt,
|
||||
]);
|
||||
markdown = result.response.text();
|
||||
usageTracker.track({
|
||||
type: "LLM_USAGE",
|
||||
modelName: FILE_PARSING_MODEL,
|
||||
inputTokens: result.response.usageMetadata?.promptTokenCount || 0,
|
||||
outputTokens: result.response.usageMetadata?.candidatesTokenCount || 0,
|
||||
context: "rag.files.llm_usage",
|
||||
});
|
||||
}
|
||||
|
||||
// split into chunks
|
||||
|
|
@ -139,6 +153,12 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
|||
model: embeddingModel,
|
||||
values: splits.map((split) => split.pageContent)
|
||||
});
|
||||
usageTracker.track({
|
||||
type: "EMBEDDING_MODEL_USAGE",
|
||||
modelName: embeddingModel.modelId,
|
||||
tokens: usage.tokens,
|
||||
context: "rag.files.embedding_usage",
|
||||
});
|
||||
|
||||
// store embeddings in qdrant
|
||||
logger.log("Storing embeddings in Qdrant");
|
||||
|
|
@ -170,8 +190,6 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
|||
lastUpdatedAt: new Date().toISOString(),
|
||||
}
|
||||
});
|
||||
|
||||
return usage.tokens;
|
||||
}
|
||||
|
||||
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
|
||||
|
|
@ -339,8 +357,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
|||
// authorize with billing
|
||||
if (USE_BILLING && billingCustomerId) {
|
||||
const authResponse = await authorize(billingCustomerId, {
|
||||
type: "process_rag",
|
||||
data: {},
|
||||
type: "use_credits",
|
||||
});
|
||||
|
||||
if ('error' in authResponse) {
|
||||
|
|
@ -349,16 +366,9 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
|||
}
|
||||
|
||||
const ldoc = doc as WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } };
|
||||
const usageTracker = new UsageTracker();
|
||||
try {
|
||||
const usedTokens = await runProcessPipeline(logger, job, ldoc);
|
||||
|
||||
// log usage in billing
|
||||
if (USE_BILLING && billingCustomerId) {
|
||||
await logUsage(billingCustomerId, {
|
||||
type: "rag_tokens",
|
||||
amount: usedTokens,
|
||||
});
|
||||
}
|
||||
await runProcessPipeline(logger, usageTracker, job, ldoc);
|
||||
} catch (e: any) {
|
||||
errors = true;
|
||||
logger.log("Error processing doc:", e);
|
||||
|
|
@ -371,6 +381,13 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
|||
error: e.message,
|
||||
}
|
||||
});
|
||||
} finally {
|
||||
// log usage in billing
|
||||
if (USE_BILLING && billingCustomerId) {
|
||||
await logUsage(billingCustomerId, {
|
||||
items: usageTracker.flush(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import { qdrantClient } from '../lib/qdrant';
|
|||
import { PrefixLogger } from "../lib/utils";
|
||||
import crypto from 'crypto';
|
||||
import { USE_BILLING } from '../lib/feature_flags';
|
||||
import { authorize, getCustomerIdForProject, logUsage } from '../lib/billing';
|
||||
import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from '../lib/billing';
|
||||
import { BillingError } from '@/src/entities/errors/common';
|
||||
|
||||
const splitter = new RecursiveCharacterTextSplitter({
|
||||
|
|
@ -23,7 +23,7 @@ const second = 1000;
|
|||
const minute = 60 * second;
|
||||
const hour = 60 * minute;
|
||||
|
||||
async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<number> {
|
||||
async function runProcessPipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>) {
|
||||
const logger = _logger
|
||||
.child(doc._id.toString())
|
||||
.child(doc.name);
|
||||
|
|
@ -42,6 +42,12 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
|||
model: embeddingModel,
|
||||
values: splits.map((split) => split.pageContent)
|
||||
});
|
||||
usageTracker.track({
|
||||
type: "EMBEDDING_MODEL_USAGE",
|
||||
modelName: embeddingModel.modelId,
|
||||
tokens: usage.tokens,
|
||||
context: "rag.text.embedding_usage",
|
||||
});
|
||||
|
||||
// store embeddings in qdrant
|
||||
logger.log("Storing embeddings in Qdrant");
|
||||
|
|
@ -73,8 +79,6 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
|||
lastUpdatedAt: new Date().toISOString(),
|
||||
}
|
||||
});
|
||||
|
||||
return usage.tokens;
|
||||
}
|
||||
|
||||
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
|
||||
|
|
@ -241,8 +245,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
|||
// authorize with billing
|
||||
if (USE_BILLING && billingCustomerId) {
|
||||
const authResponse = await authorize(billingCustomerId, {
|
||||
type: "process_rag",
|
||||
data: {}
|
||||
type: "use_credits",
|
||||
});
|
||||
|
||||
if ('error' in authResponse) {
|
||||
|
|
@ -250,16 +253,9 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
|||
}
|
||||
}
|
||||
|
||||
const usageTracker = new UsageTracker();
|
||||
try {
|
||||
const usedTokens = await runProcessPipeline(logger, job, doc);
|
||||
|
||||
// log usage in billing
|
||||
if (USE_BILLING && billingCustomerId) {
|
||||
await logUsage(billingCustomerId, {
|
||||
type: "rag_tokens",
|
||||
amount: usedTokens,
|
||||
});
|
||||
}
|
||||
await runProcessPipeline(logger, usageTracker, job, doc);
|
||||
} catch (e: any) {
|
||||
errors = true;
|
||||
logger.log("Error processing doc:", e);
|
||||
|
|
@ -272,6 +268,13 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
|||
error: e.message,
|
||||
}
|
||||
});
|
||||
} finally {
|
||||
// log usage in billing
|
||||
if (USE_BILLING && billingCustomerId) {
|
||||
await logUsage(billingCustomerId, {
|
||||
items: usageTracker.flush(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import { qdrantClient } from '../lib/qdrant';
|
|||
import { PrefixLogger } from "../lib/utils";
|
||||
import crypto from 'crypto';
|
||||
import { USE_BILLING } from '../lib/feature_flags';
|
||||
import { authorize, getCustomerIdForProject, logUsage } from '../lib/billing';
|
||||
import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from '../lib/billing';
|
||||
import { BillingError } from '@/src/entities/errors/common';
|
||||
|
||||
const firecrawl = new FirecrawlApp({ apiKey: process.env.FIRECRAWL_API_KEY });
|
||||
|
|
@ -41,7 +41,7 @@ async function retryable<T>(fn: () => Promise<T>, maxAttempts: number = 3): Prom
|
|||
}
|
||||
}
|
||||
|
||||
async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<number> {
|
||||
async function runScrapePipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>) {
|
||||
const logger = _logger
|
||||
.child(doc._id.toString())
|
||||
.child(doc.name);
|
||||
|
|
@ -62,6 +62,10 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<type
|
|||
}
|
||||
return scrapeResult;
|
||||
}, 3); // Retry up to 3 times
|
||||
usageTracker.track({
|
||||
type: "FIRECRAWL_SCRAPE_USAGE",
|
||||
context: "rag.urls.firecrawl_scrape",
|
||||
});
|
||||
|
||||
// split into chunks
|
||||
logger.log("Splitting into chunks");
|
||||
|
|
@ -73,6 +77,12 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<type
|
|||
model: embeddingModel,
|
||||
values: splits.map((split) => split.pageContent)
|
||||
});
|
||||
usageTracker.track({
|
||||
type: "EMBEDDING_MODEL_USAGE",
|
||||
modelName: embeddingModel.modelId,
|
||||
tokens: usage.tokens,
|
||||
context: "rag.urls.embedding_usage",
|
||||
});
|
||||
|
||||
// store embeddings in qdrant
|
||||
logger.log("Storing embeddings in Qdrant");
|
||||
|
|
@ -104,8 +114,6 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<type
|
|||
lastUpdatedAt: new Date().toISOString(),
|
||||
}
|
||||
});
|
||||
|
||||
return usage.tokens;
|
||||
}
|
||||
|
||||
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
|
||||
|
|
@ -273,8 +281,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
|||
// authorize with billing
|
||||
if (USE_BILLING && billingCustomerId) {
|
||||
const authResponse = await authorize(billingCustomerId, {
|
||||
type: "process_rag",
|
||||
data: {}
|
||||
type: "use_credits",
|
||||
});
|
||||
|
||||
if ('error' in authResponse) {
|
||||
|
|
@ -282,16 +289,9 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
|||
}
|
||||
}
|
||||
|
||||
const usageTracker = new UsageTracker();
|
||||
try {
|
||||
const usedTokens = await runScrapePipeline(logger, job, doc);
|
||||
|
||||
// log usage in billing
|
||||
if (USE_BILLING && billingCustomerId) {
|
||||
await logUsage(billingCustomerId, {
|
||||
type: "rag_tokens",
|
||||
amount: usedTokens,
|
||||
});
|
||||
}
|
||||
await runScrapePipeline(logger, usageTracker, job, doc);
|
||||
} catch (e: any) {
|
||||
errors = true;
|
||||
logger.log("Error processing doc:", e);
|
||||
|
|
@ -304,6 +304,13 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
|||
error: e.message,
|
||||
}
|
||||
});
|
||||
} finally {
|
||||
// log usage in billing
|
||||
if (USE_BILLING && billingCustomerId) {
|
||||
await logUsage(billingCustomerId, {
|
||||
items: usageTracker.flush(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import { Reason, Turn, TurnEvent } from "@/src/entities/models/turn";
|
||||
import { USE_BILLING } from "@/app/lib/feature_flags";
|
||||
import { authorize, getCustomerIdForProject, logUsage } from "@/app/lib/billing";
|
||||
import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from "@/app/lib/billing";
|
||||
import { NotFoundError } from '@/src/entities/errors/common';
|
||||
import { IConversationsRepository } from "@/src/application/repositories/conversations.repository.interface";
|
||||
import { streamResponse } from "@/app/lib/agents";
|
||||
|
|
@ -69,6 +69,21 @@ export class RunConversationTurnUseCase implements IRunConversationTurnUseCase {
|
|||
if (USE_BILLING) {
|
||||
// get billing customer id for project
|
||||
billingCustomerId = await getCustomerIdForProject(projectId);
|
||||
|
||||
// validate enough credits
|
||||
const result = await authorize(billingCustomerId, {
|
||||
type: "use_credits"
|
||||
});
|
||||
if (!result.success) {
|
||||
yield {
|
||||
type: "error",
|
||||
error: result.error || 'Billing error',
|
||||
isBillingError: true,
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
// validate model usage
|
||||
const agentModels = conversation.workflow.agents.reduce((acc, agent) => {
|
||||
acc.push(agent.model);
|
||||
return acc;
|
||||
|
|
@ -111,46 +126,50 @@ export class RunConversationTurnUseCase implements IRunConversationTurnUseCase {
|
|||
conversation.workflow.mockTools = data.input.mockTools;
|
||||
}
|
||||
|
||||
// init usage tracker
|
||||
const usageTracker = new UsageTracker();
|
||||
|
||||
// call agents runtime and handle generated messages
|
||||
const outputMessages: z.infer<typeof Message>[] = [];
|
||||
for await (const event of streamResponse(projectId, conversation.workflow, inputMessages)) {
|
||||
// handle msg events
|
||||
if ("role" in event) {
|
||||
// collect generated message
|
||||
const msg = {
|
||||
...event,
|
||||
timestamp: new Date().toISOString(),
|
||||
};
|
||||
outputMessages.push(msg);
|
||||
try {
|
||||
const outputMessages: z.infer<typeof Message>[] = [];
|
||||
for await (const event of streamResponse(projectId, conversation.workflow, inputMessages, usageTracker)) {
|
||||
// handle msg events
|
||||
if ("role" in event) {
|
||||
// collect generated message
|
||||
const msg = {
|
||||
...event,
|
||||
timestamp: new Date().toISOString(),
|
||||
};
|
||||
outputMessages.push(msg);
|
||||
|
||||
// yield event
|
||||
yield {
|
||||
type: "message",
|
||||
data: msg,
|
||||
};
|
||||
} else {
|
||||
// save turn data
|
||||
const turn = await this.conversationsRepository.addTurn(data.conversationId, {
|
||||
reason: data.reason,
|
||||
input: data.input,
|
||||
output: outputMessages,
|
||||
});
|
||||
|
||||
// yield event
|
||||
yield {
|
||||
type: "done",
|
||||
turn,
|
||||
conversationId,
|
||||
// yield event
|
||||
yield {
|
||||
type: "message",
|
||||
data: msg,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Log billing usage
|
||||
if (USE_BILLING && billingCustomerId) {
|
||||
await logUsage(billingCustomerId, {
|
||||
type: "agent_messages",
|
||||
amount: outputMessages.length,
|
||||
// save turn data
|
||||
const turn = await this.conversationsRepository.addTurn(data.conversationId, {
|
||||
reason: data.reason,
|
||||
input: data.input,
|
||||
output: outputMessages,
|
||||
});
|
||||
|
||||
// yield event
|
||||
yield {
|
||||
type: "done",
|
||||
turn,
|
||||
conversationId,
|
||||
}
|
||||
} finally {
|
||||
// Log billing usage
|
||||
if (USE_BILLING && billingCustomerId) {
|
||||
await logUsage(billingCustomerId, {
|
||||
items: usageTracker.flush(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue