mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-05-09 15:22:39 +02:00
billing + credits updates (#202)
This commit is contained in:
parent
852e02e49e
commit
eccfb4748f
13 changed files with 497 additions and 229 deletions
|
|
@ -15,6 +15,7 @@ import { WithStringId } from "../lib/types/types";
|
||||||
import { getEditAgentInstructionsResponse } from "../lib/copilot/copilot";
|
import { getEditAgentInstructionsResponse } from "../lib/copilot/copilot";
|
||||||
import { container } from "@/di/container";
|
import { container } from "@/di/container";
|
||||||
import { IUsageQuotaPolicy } from "@/src/application/policies/usage-quota.policy.interface";
|
import { IUsageQuotaPolicy } from "@/src/application/policies/usage-quota.policy.interface";
|
||||||
|
import { UsageTracker } from "../lib/billing";
|
||||||
|
|
||||||
const usageQuotaPolicy = container.resolve<IUsageQuotaPolicy>('usageQuotaPolicy');
|
const usageQuotaPolicy = container.resolve<IUsageQuotaPolicy>('usageQuotaPolicy');
|
||||||
|
|
||||||
|
|
@ -32,8 +33,7 @@ export async function getCopilotResponseStream(
|
||||||
|
|
||||||
// Check billing authorization
|
// Check billing authorization
|
||||||
const authResponse = await authorizeUserAction({
|
const authResponse = await authorizeUserAction({
|
||||||
type: 'copilot_request',
|
type: 'use_credits',
|
||||||
data: {},
|
|
||||||
});
|
});
|
||||||
if (!authResponse.success) {
|
if (!authResponse.success) {
|
||||||
return { billingError: authResponse.error || 'Billing error' };
|
return { billingError: authResponse.error || 'Billing error' };
|
||||||
|
|
@ -75,8 +75,7 @@ export async function getCopilotAgentInstructions(
|
||||||
|
|
||||||
// Check billing authorization
|
// Check billing authorization
|
||||||
const authResponse = await authorizeUserAction({
|
const authResponse = await authorizeUserAction({
|
||||||
type: 'copilot_request',
|
type: 'use_credits',
|
||||||
data: {},
|
|
||||||
});
|
});
|
||||||
if (!authResponse.success) {
|
if (!authResponse.success) {
|
||||||
return { billingError: authResponse.error || 'Billing error' };
|
return { billingError: authResponse.error || 'Billing error' };
|
||||||
|
|
@ -93,8 +92,11 @@ export async function getCopilotAgentInstructions(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const usageTracker = new UsageTracker();
|
||||||
|
|
||||||
// call copilot api
|
// call copilot api
|
||||||
const agent_instructions = await getEditAgentInstructionsResponse(
|
const agent_instructions = await getEditAgentInstructionsResponse(
|
||||||
|
usageTracker,
|
||||||
projectId,
|
projectId,
|
||||||
request.context,
|
request.context,
|
||||||
request.messages,
|
request.messages,
|
||||||
|
|
@ -104,8 +106,7 @@ export async function getCopilotAgentInstructions(
|
||||||
// log the billing usage
|
// log the billing usage
|
||||||
if (USE_BILLING) {
|
if (USE_BILLING) {
|
||||||
await logUsage({
|
await logUsage({
|
||||||
type: 'copilot_requests',
|
items: usageTracker.flush(),
|
||||||
amount: 1,
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import { getCustomerIdForProject, logUsage } from "@/app/lib/billing";
|
import { getCustomerIdForProject, logUsage, UsageTracker } from "@/app/lib/billing";
|
||||||
import { USE_BILLING } from "@/app/lib/feature_flags";
|
import { USE_BILLING } from "@/app/lib/feature_flags";
|
||||||
import { redisClient } from "@/app/lib/redis";
|
import { redisClient } from "@/app/lib/redis";
|
||||||
import { CopilotAPIRequest } from "@/app/lib/types/copilot_types";
|
import { CopilotAPIRequest } from "@/app/lib/types/copilot_types";
|
||||||
|
|
@ -21,6 +21,7 @@ export async function GET(request: Request, props: { params: Promise<{ streamId:
|
||||||
billingCustomerId = await getCustomerIdForProject(projectId);
|
billingCustomerId = await getCustomerIdForProject(projectId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const usageTracker = new UsageTracker();
|
||||||
const encoder = new TextEncoder();
|
const encoder = new TextEncoder();
|
||||||
let messageCount = 0;
|
let messageCount = 0;
|
||||||
|
|
||||||
|
|
@ -29,6 +30,7 @@ export async function GET(request: Request, props: { params: Promise<{ streamId:
|
||||||
try {
|
try {
|
||||||
// Iterate over the copilot stream generator
|
// Iterate over the copilot stream generator
|
||||||
for await (const event of streamMultiAgentResponse(
|
for await (const event of streamMultiAgentResponse(
|
||||||
|
usageTracker,
|
||||||
projectId,
|
projectId,
|
||||||
context,
|
context,
|
||||||
messages,
|
messages,
|
||||||
|
|
@ -49,21 +51,20 @@ export async function GET(request: Request, props: { params: Promise<{ streamId:
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.close();
|
controller.close();
|
||||||
|
} catch (error) {
|
||||||
// increment copilot request count in billing
|
console.error('Error processing copilot stream:', error);
|
||||||
|
controller.error(error);
|
||||||
|
} finally {
|
||||||
|
// log copilot usage
|
||||||
if (USE_BILLING && billingCustomerId) {
|
if (USE_BILLING && billingCustomerId) {
|
||||||
try {
|
try {
|
||||||
await logUsage(billingCustomerId, {
|
await logUsage(billingCustomerId, {
|
||||||
type: "copilot_requests",
|
items: usageTracker.flush(),
|
||||||
amount: 1,
|
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Error logging usage", error);
|
console.error("Error logging usage", error);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (error) {
|
|
||||||
console.error('Error processing copilot stream:', error);
|
|
||||||
controller.error(error);
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -220,10 +220,10 @@ export async function POST(
|
||||||
// log billing usage
|
// log billing usage
|
||||||
if (USE_BILLING && billingCustomerId) {
|
if (USE_BILLING && billingCustomerId) {
|
||||||
const agentMessageCount = convertedResponseMessages.filter(m => m.role === 'assistant').length;
|
const agentMessageCount = convertedResponseMessages.filter(m => m.role === 'assistant').length;
|
||||||
await logUsage(billingCustomerId, {
|
// await logUsage(billingCustomerId, {
|
||||||
type: 'agent_messages',
|
// type: 'agent_messages',
|
||||||
amount: agentMessageCount,
|
// amount: agentMessageCount,
|
||||||
});
|
// });
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.log(`Turn processing completed successfully`);
|
logger.log(`Turn processing completed successfully`);
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
import { Progress, Badge, Chip } from "@heroui/react";
|
import { Progress, Badge, Chip } from "@heroui/react";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { Label } from "@/app/lib/components/label";
|
import { Label } from "@/app/lib/components/label";
|
||||||
import { Customer, UsageResponse, UsageType } from "@/app/lib/types/billing_types";
|
import { Customer, UsageResponse } from "@/app/lib/types/billing_types";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { tokens } from "@/app/styles/design-tokens";
|
import { tokens } from "@/app/styles/design-tokens";
|
||||||
import { SectionHeading } from "@/components/ui/section-heading";
|
import { SectionHeading } from "@/components/ui/section-heading";
|
||||||
|
|
@ -47,6 +47,15 @@ export function BillingPage({ customer, usage }: BillingPageProps) {
|
||||||
const displayStatus = getDisplayStatus(customer.subscriptionStatus);
|
const displayStatus = getDisplayStatus(customer.subscriptionStatus);
|
||||||
const planInfo = planDetails[plan];
|
const planInfo = planDetails[plan];
|
||||||
|
|
||||||
|
// Prepare usage metrics data
|
||||||
|
const usageData = Object.entries(usage.usage)
|
||||||
|
.map(([type, credits]) => ({
|
||||||
|
type,
|
||||||
|
credits,
|
||||||
|
totalUsedCredits: usage.sanctionedCredits - usage.availableCredits
|
||||||
|
}))
|
||||||
|
.sort((a, b) => b.credits - a.credits);
|
||||||
|
|
||||||
async function handleManageSubscription() {
|
async function handleManageSubscription() {
|
||||||
const returnUrl = new URL('/billing/callback', window.location.origin);
|
const returnUrl = new URL('/billing/callback', window.location.origin);
|
||||||
returnUrl.searchParams.set('redirect', window.location.href);
|
returnUrl.searchParams.set('redirect', window.location.href);
|
||||||
|
|
@ -109,19 +118,143 @@ export function BillingPage({ customer, usage }: BillingPageProps) {
|
||||||
</div>
|
</div>
|
||||||
</section>
|
</section>
|
||||||
|
|
||||||
{/* Usage Metrics Panel */}
|
{/* Credits Overview Panel */}
|
||||||
<section className="card">
|
<section className="card">
|
||||||
<div className="px-4 pt-4 pb-6">
|
<div className="px-4 pt-4 pb-6">
|
||||||
<SectionHeading>
|
<SectionHeading>
|
||||||
Usage Metrics
|
Credits Overview
|
||||||
</SectionHeading>
|
</SectionHeading>
|
||||||
</div>
|
</div>
|
||||||
<HorizontalDivider />
|
<HorizontalDivider />
|
||||||
<div className="p-6 space-y-6">
|
<div className="p-6 space-y-6">
|
||||||
{Object.entries(usage.usage).map(([type, { usage: used, total }]) => {
|
<div className="grid grid-cols-1 md:grid-cols-3 gap-6">
|
||||||
const usageType = type as z.infer<typeof UsageType>;
|
<div className="space-y-2">
|
||||||
const percentage = Math.min((used / total) * 100, 100);
|
<Label label="Sanctioned Credits" />
|
||||||
const isOverLimit = used > total;
|
<p className={clsx(
|
||||||
|
tokens.typography.sizes.lg,
|
||||||
|
tokens.typography.weights.semibold,
|
||||||
|
tokens.colors.light.text.primary,
|
||||||
|
tokens.colors.dark.text.primary
|
||||||
|
)}>
|
||||||
|
{usage.sanctionedCredits.toLocaleString()}
|
||||||
|
</p>
|
||||||
|
<p className={clsx(
|
||||||
|
tokens.typography.sizes.sm,
|
||||||
|
tokens.colors.light.text.secondary,
|
||||||
|
tokens.colors.dark.text.secondary
|
||||||
|
)}>
|
||||||
|
Total credits allocated to your plan
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label label="Used Credits" />
|
||||||
|
<p className={clsx(
|
||||||
|
tokens.typography.sizes.lg,
|
||||||
|
tokens.typography.weights.semibold,
|
||||||
|
tokens.colors.light.text.primary,
|
||||||
|
tokens.colors.dark.text.primary
|
||||||
|
)}>
|
||||||
|
{(usage.sanctionedCredits - usage.availableCredits).toLocaleString()}
|
||||||
|
</p>
|
||||||
|
<p className={clsx(
|
||||||
|
tokens.typography.sizes.sm,
|
||||||
|
tokens.colors.light.text.secondary,
|
||||||
|
tokens.colors.dark.text.secondary
|
||||||
|
)}>
|
||||||
|
Credits consumed so far
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label label="Available Credits" />
|
||||||
|
<p className={clsx(
|
||||||
|
tokens.typography.sizes.lg,
|
||||||
|
tokens.typography.weights.semibold,
|
||||||
|
usage.availableCredits < 0 ? "text-red-500" : clsx(
|
||||||
|
tokens.colors.light.text.primary,
|
||||||
|
tokens.colors.dark.text.primary
|
||||||
|
)
|
||||||
|
)}>
|
||||||
|
{usage.availableCredits.toLocaleString()}
|
||||||
|
</p>
|
||||||
|
<p className={clsx(
|
||||||
|
tokens.typography.sizes.sm,
|
||||||
|
tokens.colors.light.text.secondary,
|
||||||
|
tokens.colors.dark.text.secondary
|
||||||
|
)}>
|
||||||
|
Credits remaining for use
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Warning for negative credits */}
|
||||||
|
{usage.availableCredits < 0 && (
|
||||||
|
<div className="p-4 bg-red-50 dark:bg-red-900/20 border border-red-200 dark:border-red-800 rounded-lg">
|
||||||
|
<p className={clsx(
|
||||||
|
tokens.typography.sizes.sm,
|
||||||
|
"text-red-700 dark:text-red-300"
|
||||||
|
)}>
|
||||||
|
⚠️ You have exceeded your credit limit. Please upgrade your plan or contact support to avoid service interruptions.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Warning for high credit usage (>80%) */}
|
||||||
|
{usage.availableCredits >= 0 && ((usage.sanctionedCredits - usage.availableCredits) / usage.sanctionedCredits) > 0.8 && (
|
||||||
|
<div className="p-4 bg-yellow-50 dark:bg-yellow-900/20 border border-yellow-200 dark:border-yellow-800 rounded-lg">
|
||||||
|
<p className={clsx(
|
||||||
|
tokens.typography.sizes.sm,
|
||||||
|
"text-yellow-700 dark:text-yellow-300"
|
||||||
|
)}>
|
||||||
|
⚠️ You have used more than 80% of your credits. Consider upgrading your plan to avoid interruptions.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Credits Progress Bar */}
|
||||||
|
<div className="space-y-2">
|
||||||
|
<div className="flex justify-between items-center">
|
||||||
|
<Label label="Credits Usage" />
|
||||||
|
<span className={clsx(
|
||||||
|
tokens.typography.sizes.sm,
|
||||||
|
tokens.colors.light.text.secondary,
|
||||||
|
tokens.colors.dark.text.secondary
|
||||||
|
)}>
|
||||||
|
{Math.round(((usage.sanctionedCredits - usage.availableCredits) / usage.sanctionedCredits) * 100)}%
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<Progress
|
||||||
|
size="lg"
|
||||||
|
value={((usage.sanctionedCredits - usage.availableCredits) / usage.sanctionedCredits) * 100}
|
||||||
|
color={usage.availableCredits < 0 ? "danger" : "primary"}
|
||||||
|
className="h-4"
|
||||||
|
aria-label="Credits usage"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
{/* Usage Metrics Panel */}
|
||||||
|
<section className="card">
|
||||||
|
<div className="px-4 pt-4 pb-6">
|
||||||
|
<SectionHeading>
|
||||||
|
Usage data
|
||||||
|
</SectionHeading>
|
||||||
|
</div>
|
||||||
|
<HorizontalDivider />
|
||||||
|
<div className="p-6 space-y-6">
|
||||||
|
{usageData.length === 0 ? (
|
||||||
|
<div className="text-center py-8">
|
||||||
|
<p className={clsx(
|
||||||
|
tokens.typography.sizes.sm,
|
||||||
|
tokens.colors.light.text.secondary,
|
||||||
|
tokens.colors.dark.text.secondary
|
||||||
|
)}>
|
||||||
|
No usage data yet
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
usageData.map(({ type, credits, totalUsedCredits }) => {
|
||||||
|
const percentage = totalUsedCredits > 0 ? (credits / totalUsedCredits) * 100 : 0;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div key={type} className="space-y-2">
|
<div key={type} className="space-y-2">
|
||||||
|
|
@ -133,24 +266,27 @@ export function BillingPage({ customer, usage }: BillingPageProps) {
|
||||||
tokens.colors.light.text.secondary,
|
tokens.colors.light.text.secondary,
|
||||||
tokens.colors.dark.text.secondary
|
tokens.colors.dark.text.secondary
|
||||||
)}>
|
)}>
|
||||||
{used.toLocaleString()} / {total.toLocaleString()}
|
{credits.toLocaleString()} credits
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
{isOverLimit && (
|
<span className={clsx(
|
||||||
<Badge color="danger" variant="flat">
|
tokens.typography.sizes.sm,
|
||||||
Over Limit
|
tokens.colors.light.text.secondary,
|
||||||
</Badge>
|
tokens.colors.dark.text.secondary
|
||||||
)}
|
)}>
|
||||||
|
{Math.round(percentage)}%
|
||||||
|
</span>
|
||||||
</div>
|
</div>
|
||||||
<Progress
|
<Progress
|
||||||
value={percentage}
|
value={percentage}
|
||||||
color={isOverLimit ? "danger" : "primary"}
|
color="default"
|
||||||
className="h-2"
|
className="h-2"
|
||||||
aria-label={`${type} usage`}
|
aria-label={`${type} credits usage`}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
})}
|
})
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</section>
|
</section>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ import { qdrantClient } from '../lib/qdrant';
|
||||||
import { EmbeddingRecord } from "./types/datasource_types";
|
import { EmbeddingRecord } from "./types/datasource_types";
|
||||||
import { WorkflowAgent, WorkflowTool } from "./types/workflow_types";
|
import { WorkflowAgent, WorkflowTool } from "./types/workflow_types";
|
||||||
import { PrefixLogger } from "./utils";
|
import { PrefixLogger } from "./utils";
|
||||||
|
import { UsageTracker } from "./billing";
|
||||||
|
|
||||||
// Provider configuration
|
// Provider configuration
|
||||||
const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
|
const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
|
||||||
|
|
@ -30,6 +31,7 @@ const openai = createOpenAI({
|
||||||
// Helper to handle mock tool responses
|
// Helper to handle mock tool responses
|
||||||
export async function invokeMockTool(
|
export async function invokeMockTool(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
toolName: string,
|
toolName: string,
|
||||||
args: string,
|
args: string,
|
||||||
description: string,
|
description: string,
|
||||||
|
|
@ -49,18 +51,28 @@ export async function invokeMockTool(
|
||||||
content: `Generate a realistic response for the tool '${toolName}' with these parameters: ${args}. The response should be concise and focused on what the tool would actually return.`
|
content: `Generate a realistic response for the tool '${toolName}' with these parameters: ${args}. The response should be concise and focused on what the tool would actually return.`
|
||||||
}];
|
}];
|
||||||
|
|
||||||
const { text } = await generateText({
|
const { text, usage } = await generateText({
|
||||||
model: openai(MODEL),
|
model: openai(MODEL),
|
||||||
messages,
|
messages,
|
||||||
});
|
});
|
||||||
logger.log(`generated text: ${text}`);
|
logger.log(`generated text: ${text}`);
|
||||||
|
|
||||||
|
// track usage
|
||||||
|
usageTracker.track({
|
||||||
|
type: "LLM_USAGE",
|
||||||
|
modelName: MODEL,
|
||||||
|
inputTokens: usage.promptTokens,
|
||||||
|
outputTokens: usage.completionTokens,
|
||||||
|
context: "agents_runtime.mock_tool",
|
||||||
|
});
|
||||||
|
|
||||||
return text;
|
return text;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper to handle RAG tool calls
|
// Helper to handle RAG tool calls
|
||||||
export async function invokeRagTool(
|
export async function invokeRagTool(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
projectId: string,
|
projectId: string,
|
||||||
query: string,
|
query: string,
|
||||||
sourceIds: string[],
|
sourceIds: string[],
|
||||||
|
|
@ -81,11 +93,21 @@ export async function invokeRagTool(
|
||||||
logger.log(`k: ${k}`);
|
logger.log(`k: ${k}`);
|
||||||
|
|
||||||
// Create embedding for question
|
// Create embedding for question
|
||||||
const { embedding } = await embed({
|
const { embedding, usage } = await embed({
|
||||||
model: embeddingModel,
|
model: embeddingModel,
|
||||||
value: query,
|
value: query,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// track usage
|
||||||
|
|
||||||
|
// track usage
|
||||||
|
usageTracker.track({
|
||||||
|
type: "EMBEDDING_MODEL_USAGE",
|
||||||
|
modelName: embeddingModel.modelId,
|
||||||
|
tokens: usage.tokens,
|
||||||
|
context: "agents_runtime.rag_tool.embedding_usage",
|
||||||
|
});
|
||||||
|
|
||||||
// Fetch all data sources for this project
|
// Fetch all data sources for this project
|
||||||
const sources = await dataSourcesCollection.find({
|
const sources = await dataSourcesCollection.find({
|
||||||
projectId: projectId,
|
projectId: projectId,
|
||||||
|
|
@ -154,6 +176,7 @@ export async function invokeRagTool(
|
||||||
|
|
||||||
export async function invokeWebhookTool(
|
export async function invokeWebhookTool(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
projectId: string,
|
projectId: string,
|
||||||
name: string,
|
name: string,
|
||||||
input: any,
|
input: any,
|
||||||
|
|
@ -233,6 +256,7 @@ export async function invokeWebhookTool(
|
||||||
// Helper to handle MCP tool calls
|
// Helper to handle MCP tool calls
|
||||||
export async function invokeMcpTool(
|
export async function invokeMcpTool(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
projectId: string,
|
projectId: string,
|
||||||
name: string,
|
name: string,
|
||||||
input: any,
|
input: any,
|
||||||
|
|
@ -269,6 +293,7 @@ export async function invokeMcpTool(
|
||||||
// Helper to handle composio tool calls
|
// Helper to handle composio tool calls
|
||||||
export async function invokeComposioTool(
|
export async function invokeComposioTool(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
projectId: string,
|
projectId: string,
|
||||||
name: string,
|
name: string,
|
||||||
composioData: z.infer<typeof WorkflowTool>['composioData'] & {},
|
composioData: z.infer<typeof WorkflowTool>['composioData'] & {},
|
||||||
|
|
@ -299,12 +324,21 @@ export async function invokeComposioTool(
|
||||||
connectedAccountId: connectedAccountId,
|
connectedAccountId: connectedAccountId,
|
||||||
});
|
});
|
||||||
logger.log(`composio tool result: ${JSON.stringify(result)}`);
|
logger.log(`composio tool result: ${JSON.stringify(result)}`);
|
||||||
|
|
||||||
|
// track usage
|
||||||
|
usageTracker.track({
|
||||||
|
type: "COMPOSIO_TOOL_USAGE",
|
||||||
|
toolSlug: slug,
|
||||||
|
context: "agents_runtime.composio_tool",
|
||||||
|
});
|
||||||
|
|
||||||
return result.data;
|
return result.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper to create RAG tool
|
// Helper to create RAG tool
|
||||||
export function createRagTool(
|
export function createRagTool(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
config: z.infer<typeof WorkflowAgent>,
|
config: z.infer<typeof WorkflowAgent>,
|
||||||
projectId: string
|
projectId: string
|
||||||
): Tool {
|
): Tool {
|
||||||
|
|
@ -321,6 +355,7 @@ export function createRagTool(
|
||||||
async execute(input: { query: string }) {
|
async execute(input: { query: string }) {
|
||||||
const results = await invokeRagTool(
|
const results = await invokeRagTool(
|
||||||
logger,
|
logger,
|
||||||
|
usageTracker,
|
||||||
projectId,
|
projectId,
|
||||||
input.query,
|
input.query,
|
||||||
config.ragDataSources || [],
|
config.ragDataSources || [],
|
||||||
|
|
@ -337,6 +372,7 @@ export function createRagTool(
|
||||||
// Helper to create a mock tool
|
// Helper to create a mock tool
|
||||||
export function createMockTool(
|
export function createMockTool(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
config: z.infer<typeof WorkflowTool>,
|
config: z.infer<typeof WorkflowTool>,
|
||||||
): Tool {
|
): Tool {
|
||||||
return tool({
|
return tool({
|
||||||
|
|
@ -353,6 +389,7 @@ export function createMockTool(
|
||||||
try {
|
try {
|
||||||
const result = await invokeMockTool(
|
const result = await invokeMockTool(
|
||||||
logger,
|
logger,
|
||||||
|
usageTracker,
|
||||||
config.name,
|
config.name,
|
||||||
JSON.stringify(input),
|
JSON.stringify(input),
|
||||||
config.description,
|
config.description,
|
||||||
|
|
@ -374,6 +411,7 @@ export function createMockTool(
|
||||||
// Helper to create a webhook tool
|
// Helper to create a webhook tool
|
||||||
export function createWebhookTool(
|
export function createWebhookTool(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
config: z.infer<typeof WorkflowTool>,
|
config: z.infer<typeof WorkflowTool>,
|
||||||
projectId: string,
|
projectId: string,
|
||||||
): Tool {
|
): Tool {
|
||||||
|
|
@ -391,7 +429,7 @@ export function createWebhookTool(
|
||||||
},
|
},
|
||||||
async execute(input: any) {
|
async execute(input: any) {
|
||||||
try {
|
try {
|
||||||
const result = await invokeWebhookTool(logger, projectId, name, input);
|
const result = await invokeWebhookTool(logger, usageTracker, projectId, name, input);
|
||||||
return JSON.stringify({
|
return JSON.stringify({
|
||||||
result,
|
result,
|
||||||
});
|
});
|
||||||
|
|
@ -408,6 +446,7 @@ export function createWebhookTool(
|
||||||
// Helper to create an mcp tool
|
// Helper to create an mcp tool
|
||||||
export function createMcpTool(
|
export function createMcpTool(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
config: z.infer<typeof WorkflowTool>,
|
config: z.infer<typeof WorkflowTool>,
|
||||||
projectId: string
|
projectId: string
|
||||||
): Tool {
|
): Tool {
|
||||||
|
|
@ -425,7 +464,7 @@ export function createMcpTool(
|
||||||
},
|
},
|
||||||
async execute(input: any) {
|
async execute(input: any) {
|
||||||
try {
|
try {
|
||||||
const result = await invokeMcpTool(logger, projectId, name, input, mcpServerName || '');
|
const result = await invokeMcpTool(logger, usageTracker, projectId, name, input, mcpServerName || '');
|
||||||
return JSON.stringify({
|
return JSON.stringify({
|
||||||
result,
|
result,
|
||||||
});
|
});
|
||||||
|
|
@ -442,6 +481,7 @@ export function createMcpTool(
|
||||||
// Helper to create a composio tool
|
// Helper to create a composio tool
|
||||||
export function createComposioTool(
|
export function createComposioTool(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
config: z.infer<typeof WorkflowTool>,
|
config: z.infer<typeof WorkflowTool>,
|
||||||
projectId: string
|
projectId: string
|
||||||
): Tool {
|
): Tool {
|
||||||
|
|
@ -463,7 +503,7 @@ export function createComposioTool(
|
||||||
},
|
},
|
||||||
async execute(input: any) {
|
async execute(input: any) {
|
||||||
try {
|
try {
|
||||||
const result = await invokeComposioTool(logger, projectId, name, composioData, input);
|
const result = await invokeComposioTool(logger, usageTracker, projectId, name, composioData, input);
|
||||||
return JSON.stringify({
|
return JSON.stringify({
|
||||||
result,
|
result,
|
||||||
});
|
});
|
||||||
|
|
@ -479,6 +519,7 @@ export function createComposioTool(
|
||||||
|
|
||||||
export function createTools(
|
export function createTools(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
projectId: string,
|
projectId: string,
|
||||||
workflow: { tools: z.infer<typeof WorkflowTool>[] },
|
workflow: { tools: z.infer<typeof WorkflowTool>[] },
|
||||||
toolConfig: Record<string, z.infer<typeof WorkflowTool>>,
|
toolConfig: Record<string, z.infer<typeof WorkflowTool>>,
|
||||||
|
|
@ -492,16 +533,16 @@ export function createTools(
|
||||||
toolLogger.log(`creating tool: ${toolName} (type: ${config.mockTool ? 'mock' : config.isMcp ? 'mcp' : config.isComposio ? 'composio' : 'webhook'})`);
|
toolLogger.log(`creating tool: ${toolName} (type: ${config.mockTool ? 'mock' : config.isMcp ? 'mcp' : config.isComposio ? 'composio' : 'webhook'})`);
|
||||||
|
|
||||||
if (config.mockTool) {
|
if (config.mockTool) {
|
||||||
tools[toolName] = createMockTool(logger, config);
|
tools[toolName] = createMockTool(logger, usageTracker, config);
|
||||||
toolLogger.log(`✓ created mock tool: ${toolName}`);
|
toolLogger.log(`✓ created mock tool: ${toolName}`);
|
||||||
} else if (config.isMcp) {
|
} else if (config.isMcp) {
|
||||||
tools[toolName] = createMcpTool(logger, config, projectId);
|
tools[toolName] = createMcpTool(logger, usageTracker, config, projectId);
|
||||||
toolLogger.log(`✓ created mcp tool: ${toolName} (server: ${config.mcpServerName || 'unknown'})`);
|
toolLogger.log(`✓ created mcp tool: ${toolName} (server: ${config.mcpServerName || 'unknown'})`);
|
||||||
} else if (config.isComposio) {
|
} else if (config.isComposio) {
|
||||||
tools[toolName] = createComposioTool(logger, config, projectId);
|
tools[toolName] = createComposioTool(logger, usageTracker, config, projectId);
|
||||||
toolLogger.log(`✓ created composio tool: ${toolName}`);
|
toolLogger.log(`✓ created composio tool: ${toolName}`);
|
||||||
} else {
|
} else {
|
||||||
tools[toolName] = createWebhookTool(logger, config, projectId);
|
tools[toolName] = createWebhookTool(logger, usageTracker, config, projectId);
|
||||||
toolLogger.log(`✓ created webhook tool: ${toolName} (fallback)`);
|
toolLogger.log(`✓ created webhook tool: ${toolName} (fallback)`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,8 @@ import { ConnectedEntity, sanitizeTextWithMentions, Workflow, WorkflowAgent, Wor
|
||||||
import { CHILD_TRANSFER_RELATED_INSTRUCTIONS, CONVERSATION_TYPE_INSTRUCTIONS, PIPELINE_TYPE_INSTRUCTIONS, RAG_INSTRUCTIONS, TASK_TYPE_INSTRUCTIONS } from "./agent_instructions";
|
import { CHILD_TRANSFER_RELATED_INSTRUCTIONS, CONVERSATION_TYPE_INSTRUCTIONS, PIPELINE_TYPE_INSTRUCTIONS, RAG_INSTRUCTIONS, TASK_TYPE_INSTRUCTIONS } from "./agent_instructions";
|
||||||
import { PrefixLogger } from "./utils";
|
import { PrefixLogger } from "./utils";
|
||||||
import { Message, AssistantMessage, AssistantMessageWithToolCalls, ToolMessage } from "./types/types";
|
import { Message, AssistantMessage, AssistantMessageWithToolCalls, ToolMessage } from "./types/types";
|
||||||
|
import { UsageTracker } from "./billing";
|
||||||
|
|
||||||
// Native handoff support
|
// Native handoff support
|
||||||
import { createAgentHandoff, getSchemaForAgent, createContextFilterForAgent } from "./agent-handoffs";
|
import { createAgentHandoff, getSchemaForAgent, createContextFilterForAgent } from "./agent-handoffs";
|
||||||
import { PipelineStateManager } from "./pipeline-state-manager";
|
import { PipelineStateManager } from "./pipeline-state-manager";
|
||||||
|
|
@ -78,14 +80,6 @@ const openai = createOpenAI({
|
||||||
baseURL: PROVIDER_BASE_URL,
|
baseURL: PROVIDER_BASE_URL,
|
||||||
});
|
});
|
||||||
|
|
||||||
const ZUsage = z.object({
|
|
||||||
tokens: z.object({
|
|
||||||
total: z.number(),
|
|
||||||
prompt: z.number(),
|
|
||||||
completion: z.number(),
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
const ZOutMessage = z.union([
|
const ZOutMessage = z.union([
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
AssistantMessageWithToolCalls,
|
AssistantMessageWithToolCalls,
|
||||||
|
|
@ -95,6 +89,7 @@ const ZOutMessage = z.union([
|
||||||
// Helper to create an agent
|
// Helper to create an agent
|
||||||
function createAgent(
|
function createAgent(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
projectId: string,
|
projectId: string,
|
||||||
config: z.infer<typeof WorkflowAgent>,
|
config: z.infer<typeof WorkflowAgent>,
|
||||||
tools: Record<string, Tool>,
|
tools: Record<string, Tool>,
|
||||||
|
|
@ -145,7 +140,7 @@ ${CHILD_TRANSFER_RELATED_INSTRUCTIONS}
|
||||||
|
|
||||||
// Add RAG tool if needed
|
// Add RAG tool if needed
|
||||||
if (config.ragDataSources?.length) {
|
if (config.ragDataSources?.length) {
|
||||||
const ragTool = createRagTool(logger, config, projectId);
|
const ragTool = createRagTool(logger, usageTracker, config, projectId);
|
||||||
agentTools.push(ragTool);
|
agentTools.push(ragTool);
|
||||||
|
|
||||||
// update instructions to include RAG instructions
|
// update instructions to include RAG instructions
|
||||||
|
|
@ -269,8 +264,8 @@ function getStartOfTurnAgentName(
|
||||||
// Logs an event and then yields it
|
// Logs an event and then yields it
|
||||||
async function* emitEvent(
|
async function* emitEvent(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
event: z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>,
|
event: z.infer<typeof ZOutMessage>,
|
||||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
|
): AsyncIterable<z.infer<typeof ZOutMessage>> {
|
||||||
logger.log(`-> emitting event: ${JSON.stringify(event)}`);
|
logger.log(`-> emitting event: ${JSON.stringify(event)}`);
|
||||||
yield event;
|
yield event;
|
||||||
return;
|
return;
|
||||||
|
|
@ -321,30 +316,6 @@ class AgentTransferCounter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class UsageTracker {
|
|
||||||
private usage: {
|
|
||||||
total: number;
|
|
||||||
prompt: number;
|
|
||||||
completion: number;
|
|
||||||
} = { total: 0, prompt: 0, completion: 0 };
|
|
||||||
|
|
||||||
increment(total: number, prompt: number, completion: number): void {
|
|
||||||
this.usage.total += total;
|
|
||||||
this.usage.prompt += prompt;
|
|
||||||
this.usage.completion += completion;
|
|
||||||
}
|
|
||||||
|
|
||||||
get(): { total: number, prompt: number, completion: number } {
|
|
||||||
return this.usage;
|
|
||||||
}
|
|
||||||
|
|
||||||
asEvent(): z.infer<typeof ZUsage> {
|
|
||||||
return {
|
|
||||||
tokens: this.usage,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function ensureSystemMessage(logger: PrefixLogger, messages: z.infer<typeof Message>[]) {
|
function ensureSystemMessage(logger: PrefixLogger, messages: z.infer<typeof Message>[]) {
|
||||||
logger = logger.child(`ensureSystemMessage`);
|
logger = logger.child(`ensureSystemMessage`);
|
||||||
|
|
||||||
|
|
@ -396,7 +367,7 @@ function mapConfig(workflow: z.infer<typeof Workflow>): {
|
||||||
return { agentConfig, toolConfig, promptConfig, pipelineConfig };
|
return { agentConfig, toolConfig, promptConfig, pipelineConfig };
|
||||||
}
|
}
|
||||||
|
|
||||||
async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer<typeof Workflow>): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
|
async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer<typeof Workflow>): AsyncIterable<z.infer<typeof ZOutMessage>> {
|
||||||
// find the greeting prompt
|
// find the greeting prompt
|
||||||
const prompt = workflow.prompts.find(p => p.type === 'greeting')?.prompt || 'How can I help you today?';
|
const prompt = workflow.prompts.find(p => p.type === 'greeting')?.prompt || 'How can I help you today?';
|
||||||
logger.log(`greeting turn: ${prompt}`);
|
logger.log(`greeting turn: ${prompt}`);
|
||||||
|
|
@ -408,15 +379,13 @@ async function* emitGreetingTurn(logger: PrefixLogger, workflow: z.infer<typeof
|
||||||
agentName: workflow.startAgent,
|
agentName: workflow.startAgent,
|
||||||
responseType: 'external',
|
responseType: 'external',
|
||||||
});
|
});
|
||||||
|
|
||||||
// emit final usage information
|
|
||||||
yield* emitEvent(logger, new UsageTracker().asEvent());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Enhanced agent creation with native handoff support
|
// Enhanced agent creation with native handoff support
|
||||||
function createAgentsWithNativeHandoffs(
|
function createAgentsWithNativeHandoffs(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
projectId: string,
|
projectId: string,
|
||||||
workflow: z.infer<typeof Workflow>,
|
workflow: z.infer<typeof Workflow>,
|
||||||
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
|
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
|
||||||
|
|
@ -447,6 +416,7 @@ function createAgentsWithNativeHandoffs(
|
||||||
|
|
||||||
const { agent, entities } = createAgent(
|
const { agent, entities } = createAgent(
|
||||||
logger,
|
logger,
|
||||||
|
usageTracker,
|
||||||
projectId,
|
projectId,
|
||||||
config,
|
config,
|
||||||
tools,
|
tools,
|
||||||
|
|
@ -560,6 +530,7 @@ function createAgentsWithNativeHandoffs(
|
||||||
// Legacy agent creation (existing implementation)
|
// Legacy agent creation (existing implementation)
|
||||||
function createAgentsLegacy(
|
function createAgentsLegacy(
|
||||||
logger: PrefixLogger,
|
logger: PrefixLogger,
|
||||||
|
usageTracker: UsageTracker,
|
||||||
projectId: string,
|
projectId: string,
|
||||||
workflow: z.infer<typeof Workflow>,
|
workflow: z.infer<typeof Workflow>,
|
||||||
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
|
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
|
||||||
|
|
@ -595,6 +566,7 @@ function createAgentsLegacy(
|
||||||
|
|
||||||
const { agent, entities } = createAgent(
|
const { agent, entities } = createAgent(
|
||||||
logger,
|
logger,
|
||||||
|
usageTracker,
|
||||||
projectId,
|
projectId,
|
||||||
config,
|
config,
|
||||||
tools,
|
tools,
|
||||||
|
|
@ -763,12 +735,13 @@ function maybeInjectGiveUpControlInstructions(
|
||||||
// Handle raw model stream events
|
// Handle raw model stream events
|
||||||
async function* handleRawModelStreamEvent(
|
async function* handleRawModelStreamEvent(
|
||||||
event: any,
|
event: any,
|
||||||
|
agentConfig: Record<string, z.infer<typeof WorkflowAgent>>,
|
||||||
agentName: string,
|
agentName: string,
|
||||||
turnMsgs: z.infer<typeof Message>[],
|
turnMsgs: z.infer<typeof Message>[],
|
||||||
usageTracker: UsageTracker,
|
usageTracker: UsageTracker,
|
||||||
eventLogger: PrefixLogger,
|
eventLogger: PrefixLogger,
|
||||||
getAgentState?: (agentName: string) => AgentState
|
getAgentState?: (agentName: string) => AgentState
|
||||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
|
): AsyncIterable<z.infer<typeof ZOutMessage>> {
|
||||||
if (event.data.type === 'response_done') {
|
if (event.data.type === 'response_done') {
|
||||||
// Count tool calls (excluding transfer_to_* calls)
|
// Count tool calls (excluding transfer_to_* calls)
|
||||||
const toolCallCount = event.data.response.output.filter(
|
const toolCallCount = event.data.response.output.filter(
|
||||||
|
|
@ -809,12 +782,13 @@ async function* handleRawModelStreamEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
// update usage information
|
// update usage information
|
||||||
usageTracker.increment(
|
usageTracker.track({
|
||||||
event.data.response.usage.totalTokens,
|
type: "LLM_USAGE",
|
||||||
event.data.response.usage.inputTokens,
|
modelName: agentConfig[agentName]?.model || "unknown",
|
||||||
event.data.response.usage.outputTokens
|
inputTokens: event.data.response.usage.inputTokens,
|
||||||
);
|
outputTokens: event.data.response.usage.outputTokens,
|
||||||
eventLogger.log(`updated usage information: ${JSON.stringify(usageTracker.get())}`);
|
context: "agents_runtime.llm_usage",
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -833,7 +807,7 @@ async function* handleNativeHandoffEvent(
|
||||||
originalHandoffs: Record<string, any[]>,
|
originalHandoffs: Record<string, any[]>,
|
||||||
eventLogger: PrefixLogger,
|
eventLogger: PrefixLogger,
|
||||||
loopLogger: PrefixLogger
|
loopLogger: PrefixLogger
|
||||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage> | { newAgentName: string; shouldContinue?: boolean }> {
|
): AsyncIterable<z.infer<typeof ZOutMessage> | { newAgentName: string; shouldContinue?: boolean }> {
|
||||||
eventLogger.log(`🔄 NATIVE HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`);
|
eventLogger.log(`🔄 NATIVE HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`);
|
||||||
|
|
||||||
// skip if its the same agent
|
// skip if its the same agent
|
||||||
|
|
@ -943,7 +917,7 @@ async function* handleHandoffEvent(
|
||||||
originalHandoffs: Record<string, Agent[]>,
|
originalHandoffs: Record<string, Agent[]>,
|
||||||
eventLogger: PrefixLogger,
|
eventLogger: PrefixLogger,
|
||||||
loopLogger: PrefixLogger
|
loopLogger: PrefixLogger
|
||||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage> | { newAgentName: string }> {
|
): AsyncIterable<z.infer<typeof ZOutMessage> | { newAgentName: string }> {
|
||||||
eventLogger.log(`🔄 HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`);
|
eventLogger.log(`🔄 HANDOFF EVENT: ${agentName} -> ${event.item.targetAgent.name}`);
|
||||||
|
|
||||||
// skip if its the same agent
|
// skip if its the same agent
|
||||||
|
|
@ -1010,7 +984,7 @@ async function* handleToolCallResult(
|
||||||
event: any,
|
event: any,
|
||||||
turnMsgs: z.infer<typeof Message>[],
|
turnMsgs: z.infer<typeof Message>[],
|
||||||
eventLogger: PrefixLogger
|
eventLogger: PrefixLogger
|
||||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
|
): AsyncIterable<z.infer<typeof ZOutMessage>> {
|
||||||
const m: z.infer<typeof Message> = {
|
const m: z.infer<typeof Message> = {
|
||||||
role: 'tool',
|
role: 'tool',
|
||||||
content: event.item.rawItem.output.text,
|
content: event.item.rawItem.output.text,
|
||||||
|
|
@ -1039,7 +1013,7 @@ async function* handleMessageOutput(
|
||||||
eventLogger: PrefixLogger,
|
eventLogger: PrefixLogger,
|
||||||
loopLogger: PrefixLogger,
|
loopLogger: PrefixLogger,
|
||||||
getAgentState: (agentName: string) => AgentState
|
getAgentState: (agentName: string) => AgentState
|
||||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage> | { newAgentName: string | null; shouldContinue: boolean }> {
|
): AsyncIterable<z.infer<typeof ZOutMessage> | { newAgentName: string | null; shouldContinue: boolean }> {
|
||||||
// check response visibility - could be an agent or pipeline
|
// check response visibility - could be an agent or pipeline
|
||||||
const agentConfigObj = agentConfig[agentName];
|
const agentConfigObj = agentConfig[agentName];
|
||||||
const pipelineConfigObj = pipelineConfig[agentName];
|
const pipelineConfigObj = pipelineConfig[agentName];
|
||||||
|
|
@ -1243,7 +1217,8 @@ export async function* streamResponse(
|
||||||
projectId: string,
|
projectId: string,
|
||||||
workflow: z.infer<typeof Workflow>,
|
workflow: z.infer<typeof Workflow>,
|
||||||
messages: z.infer<typeof Message>[],
|
messages: z.infer<typeof Message>[],
|
||||||
): AsyncIterable<z.infer<typeof ZOutMessage> | z.infer<typeof ZUsage>> {
|
usageTracker: UsageTracker,
|
||||||
|
): AsyncIterable<z.infer<typeof ZOutMessage>> {
|
||||||
// Divider log for tracking agent loop start
|
// Divider log for tracking agent loop start
|
||||||
console.log('-------------------- AGENT LOOP START --------------------');
|
console.log('-------------------- AGENT LOOP START --------------------');
|
||||||
// set up logging
|
// set up logging
|
||||||
|
|
@ -1275,11 +1250,11 @@ export async function* streamResponse(
|
||||||
logger.log(`initialized stack: ${JSON.stringify(stack)}`);
|
logger.log(`initialized stack: ${JSON.stringify(stack)}`);
|
||||||
|
|
||||||
// create tools
|
// create tools
|
||||||
const tools = createTools(logger, projectId, workflow, toolConfig);
|
const tools = createTools(logger, usageTracker, projectId, workflow, toolConfig);
|
||||||
|
|
||||||
// create agents with feature flag support
|
// create agents with feature flag support
|
||||||
const createAgentsFunction = USE_NATIVE_HANDOFFS ? createAgentsWithNativeHandoffs : createAgentsLegacy;
|
const createAgentsFunction = USE_NATIVE_HANDOFFS ? createAgentsWithNativeHandoffs : createAgentsLegacy;
|
||||||
const { agents, originalInstructions, originalHandoffs } = createAgentsFunction(logger, projectId, workflow, agentConfig, tools, promptConfig, pipelineConfig);
|
const { agents, originalInstructions, originalHandoffs } = createAgentsFunction(logger, usageTracker, projectId, workflow, agentConfig, tools, promptConfig, pipelineConfig);
|
||||||
|
|
||||||
logger.log(`Using ${USE_NATIVE_HANDOFFS ? 'NATIVE SDK' : 'LEGACY'} handoffs`);
|
logger.log(`Using ${USE_NATIVE_HANDOFFS ? 'NATIVE SDK' : 'LEGACY'} handoffs`);
|
||||||
|
|
||||||
|
|
@ -1296,7 +1271,6 @@ export async function* streamResponse(
|
||||||
let agentName: string | null = startOfTurnAgentName;
|
let agentName: string | null = startOfTurnAgentName;
|
||||||
|
|
||||||
// start the turn loop
|
// start the turn loop
|
||||||
const usageTracker = new UsageTracker();
|
|
||||||
const turnMsgs: z.infer<typeof Message>[] = [...messages];
|
const turnMsgs: z.infer<typeof Message>[] = [...messages];
|
||||||
|
|
||||||
// Initialize agent state tracking for tool call completion
|
// Initialize agent state tracking for tool call completion
|
||||||
|
|
@ -1390,7 +1364,7 @@ export async function* streamResponse(
|
||||||
|
|
||||||
switch (event.type) {
|
switch (event.type) {
|
||||||
case 'raw_model_stream_event':
|
case 'raw_model_stream_event':
|
||||||
yield* handleRawModelStreamEvent(event, agentName!, turnMsgs, usageTracker, eventLogger, getAgentState);
|
yield* handleRawModelStreamEvent(event, agentConfig, agentName!, turnMsgs, usageTracker, eventLogger, getAgentState);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case 'run_item_stream_event':
|
case 'run_item_stream_event':
|
||||||
|
|
@ -1523,9 +1497,6 @@ export async function* streamResponse(
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// emit usage information
|
|
||||||
yield* emitEvent(logger, usageTracker.asEvent());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is a sync version of streamResponse
|
// this is a sync version of streamResponse
|
||||||
|
|
@ -1535,8 +1506,10 @@ export async function getResponse(
|
||||||
messages: z.infer<typeof Message>[],
|
messages: z.infer<typeof Message>[],
|
||||||
): Promise<{
|
): Promise<{
|
||||||
messages: z.infer<typeof ZOutMessage>[],
|
messages: z.infer<typeof ZOutMessage>[],
|
||||||
usage: z.infer<typeof ZUsage>,
|
usage: any,
|
||||||
}> {
|
}> {
|
||||||
|
throw new Error("Not implemented!");
|
||||||
|
/*
|
||||||
const out: z.infer<typeof ZOutMessage>[] = [];
|
const out: z.infer<typeof ZOutMessage>[] = [];
|
||||||
let usage: z.infer<typeof ZUsage> = {
|
let usage: z.infer<typeof ZUsage> = {
|
||||||
tokens: {
|
tokens: {
|
||||||
|
|
@ -1554,4 +1527,5 @@ export async function getResponse(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return { messages: out, usage };
|
return { messages: out, usage };
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import { WithStringId } from './types/types';
|
import { WithStringId } from './types/types';
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
import { Customer, AuthorizeRequest, AuthorizeResponse, LogUsageRequest, UsageResponse, CustomerPortalSessionResponse, PricesResponse, UpdateSubscriptionPlanRequest, UpdateSubscriptionPlanResponse, ModelsResponse } from './types/billing_types';
|
import { Customer, AuthorizeRequest, AuthorizeResponse, LogUsageRequest, UsageResponse, CustomerPortalSessionResponse, PricesResponse, UpdateSubscriptionPlanRequest, UpdateSubscriptionPlanResponse, ModelsResponse, UsageItem } from './types/billing_types';
|
||||||
import { ObjectId } from 'mongodb';
|
import { ObjectId } from 'mongodb';
|
||||||
import { projectsCollection, usersCollection } from './mongodb';
|
import { projectsCollection, usersCollection } from './mongodb';
|
||||||
import { redirect } from 'next/navigation';
|
import { redirect } from 'next/navigation';
|
||||||
|
|
@ -23,6 +23,20 @@ const GUEST_BILLING_CUSTOMER = {
|
||||||
updatedAt: new Date().toISOString(),
|
updatedAt: new Date().toISOString(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export class UsageTracker{
|
||||||
|
private items: z.infer<typeof UsageItem>[] = [];
|
||||||
|
|
||||||
|
track(item: z.infer<typeof UsageItem>) {
|
||||||
|
this.items.push(item);
|
||||||
|
}
|
||||||
|
|
||||||
|
flush(): z.infer<typeof UsageItem>[] {
|
||||||
|
const items = this.items;
|
||||||
|
this.items = [];
|
||||||
|
return items;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export async function getCustomerIdForProject(projectId: string): Promise<string> {
|
export async function getCustomerIdForProject(projectId: string): Promise<string> {
|
||||||
const project = await projectsCollection.findOne({ _id: projectId });
|
const project = await projectsCollection.findOne({ _id: projectId });
|
||||||
if (!project) {
|
if (!project) {
|
||||||
|
|
@ -111,6 +125,7 @@ export async function authorize(customerId: string, request: z.infer<typeof Auth
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function logUsage(customerId: string, request: z.infer<typeof LogUsageRequest>) {
|
export async function logUsage(customerId: string, request: z.infer<typeof LogUsageRequest>) {
|
||||||
|
console.log(`logging billing usage for customer ${customerId}`, JSON.stringify(request));
|
||||||
const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/log-usage`, {
|
const response = await fetch(`${BILLING_API_URL}/api/customers/${customerId}/log-usage`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ import { COPILOT_MULTI_AGENT_EXAMPLE_1 } from "./example_multi_agent_1";
|
||||||
import { CURRENT_WORKFLOW_PROMPT } from "./current_workflow";
|
import { CURRENT_WORKFLOW_PROMPT } from "./current_workflow";
|
||||||
import { USE_COMPOSIO_TOOLS } from "../feature_flags";
|
import { USE_COMPOSIO_TOOLS } from "../feature_flags";
|
||||||
import { composio, getTool } from "../composio/composio";
|
import { composio, getTool } from "../composio/composio";
|
||||||
|
import { UsageTracker } from "../billing";
|
||||||
|
|
||||||
const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
|
const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
|
||||||
const PROVIDER_BASE_URL = process.env.PROVIDER_BASE_URL || undefined;
|
const PROVIDER_BASE_URL = process.env.PROVIDER_BASE_URL || undefined;
|
||||||
|
|
@ -119,7 +120,7 @@ ${JSON.stringify(simplifiedDataSources)}
|
||||||
return prompt;
|
return prompt;
|
||||||
}
|
}
|
||||||
|
|
||||||
async function searchRelevantTools(query: string): Promise<string> {
|
async function searchRelevantTools(usageTracker: UsageTracker, query: string): Promise<string> {
|
||||||
const logger = new PrefixLogger("copilot-search-tools");
|
const logger = new PrefixLogger("copilot-search-tools");
|
||||||
console.log("🔧 TOOL CALL: searchRelevantTools", { query });
|
console.log("🔧 TOOL CALL: searchRelevantTools", { query });
|
||||||
|
|
||||||
|
|
@ -142,6 +143,13 @@ async function searchRelevantTools(query: string): Promise<string> {
|
||||||
return 'No tools found!';
|
return 'No tools found!';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// track composio search tool usage
|
||||||
|
usageTracker.track({
|
||||||
|
type: "COMPOSIO_TOOL_USAGE",
|
||||||
|
toolSlug: "COMPOSIO_SEARCH_TOOLS",
|
||||||
|
context: "copilot.search_relevant_tools",
|
||||||
|
});
|
||||||
|
|
||||||
// parse results
|
// parse results
|
||||||
const result = composioToolSearchResponseSchema.safeParse(searchResult.data);
|
const result = composioToolSearchResponseSchema.safeParse(searchResult.data);
|
||||||
if (!result.success) {
|
if (!result.success) {
|
||||||
|
|
@ -208,6 +216,7 @@ function updateLastUserMessage(
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function getEditAgentInstructionsResponse(
|
export async function getEditAgentInstructionsResponse(
|
||||||
|
usageTracker: UsageTracker,
|
||||||
projectId: string,
|
projectId: string,
|
||||||
context: z.infer<typeof CopilotChatContext> | null,
|
context: z.infer<typeof CopilotChatContext> | null,
|
||||||
messages: z.infer<typeof CopilotMessage>[],
|
messages: z.infer<typeof CopilotMessage>[],
|
||||||
|
|
@ -232,7 +241,7 @@ export async function getEditAgentInstructionsResponse(
|
||||||
system: COPILOT_INSTRUCTIONS_EDIT_AGENT,
|
system: COPILOT_INSTRUCTIONS_EDIT_AGENT,
|
||||||
messages: messages,
|
messages: messages,
|
||||||
}));
|
}));
|
||||||
const { object } = await generateObject({
|
const { object, usage } = await generateObject({
|
||||||
model: openai(COPILOT_MODEL),
|
model: openai(COPILOT_MODEL),
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
|
|
@ -246,10 +255,20 @@ export async function getEditAgentInstructionsResponse(
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// log usage
|
||||||
|
usageTracker.track({
|
||||||
|
type: "LLM_USAGE",
|
||||||
|
modelName: COPILOT_MODEL,
|
||||||
|
inputTokens: usage.promptTokens,
|
||||||
|
outputTokens: usage.completionTokens,
|
||||||
|
context: "copilot.llm_usage",
|
||||||
|
});
|
||||||
|
|
||||||
return object.agent_instructions;
|
return object.agent_instructions;
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function* streamMultiAgentResponse(
|
export async function* streamMultiAgentResponse(
|
||||||
|
usageTracker: UsageTracker,
|
||||||
projectId: string,
|
projectId: string,
|
||||||
context: z.infer<typeof CopilotChatContext> | null,
|
context: z.infer<typeof CopilotChatContext> | null,
|
||||||
messages: z.infer<typeof CopilotMessage>[],
|
messages: z.infer<typeof CopilotMessage>[],
|
||||||
|
|
@ -297,7 +316,7 @@ export async function* streamMultiAgentResponse(
|
||||||
}),
|
}),
|
||||||
execute: async ({ query }: { query: string }) => {
|
execute: async ({ query }: { query: string }) => {
|
||||||
console.log("🎯 AI TOOL CALL: search_relevant_tools", { query });
|
console.log("🎯 AI TOOL CALL: search_relevant_tools", { query });
|
||||||
const result = await searchRelevantTools(query);
|
const result = await searchRelevantTools(usageTracker, query);
|
||||||
console.log("✅ AI TOOL CALL COMPLETED: search_relevant_tools", {
|
console.log("✅ AI TOOL CALL COMPLETED: search_relevant_tools", {
|
||||||
query,
|
query,
|
||||||
resultLength: result.length
|
resultLength: result.length
|
||||||
|
|
@ -341,6 +360,15 @@ export async function* streamMultiAgentResponse(
|
||||||
toolCallId: event.toolCallId,
|
toolCallId: event.toolCallId,
|
||||||
result: event.result,
|
result: event.result,
|
||||||
};
|
};
|
||||||
|
} else if (event.type === "step-finish") {
|
||||||
|
// log usage
|
||||||
|
usageTracker.track({
|
||||||
|
type: "LLM_USAGE",
|
||||||
|
modelName: COPILOT_MODEL,
|
||||||
|
inputTokens: event.usage.promptTokens,
|
||||||
|
outputTokens: event.usage.completionTokens,
|
||||||
|
context: "copilot.llm_usage",
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,52 @@ import { z } from "zod";
|
||||||
|
|
||||||
export const SubscriptionPlan = z.enum(["free", "starter", "pro"]);
|
export const SubscriptionPlan = z.enum(["free", "starter", "pro"]);
|
||||||
|
|
||||||
export const UsageType = z.enum([
|
export const UsageTypeKey = z.enum([
|
||||||
"copilot_requests",
|
"LLM_USAGE",
|
||||||
"agent_messages",
|
"EMBEDDING_MODEL_USAGE",
|
||||||
"rag_tokens",
|
"COMPOSIO_TOOL_USAGE",
|
||||||
|
"FIRECRAWL_SCRAPE_USAGE",
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
export const LLMUsage = z.object({
|
||||||
|
type: z.literal(UsageTypeKey.Enum.LLM_USAGE),
|
||||||
|
modelName: z.string(),
|
||||||
|
inputTokens: z.number().positive(),
|
||||||
|
outputTokens: z.number().positive(),
|
||||||
|
context: z.string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const EmbeddingModelUsage = z.object({
|
||||||
|
type: z.literal(UsageTypeKey.Enum.EMBEDDING_MODEL_USAGE),
|
||||||
|
modelName: z.string(),
|
||||||
|
tokens: z.number().positive(),
|
||||||
|
context: z.string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const ComposioToolUsage = z.object({
|
||||||
|
type: z.literal(UsageTypeKey.Enum.COMPOSIO_TOOL_USAGE),
|
||||||
|
toolSlug: z.string(),
|
||||||
|
context: z.string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const FirecrawlScrapeUsage = z.object({
|
||||||
|
type: z.literal(UsageTypeKey.Enum.FIRECRAWL_SCRAPE_USAGE),
|
||||||
|
context: z.string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const UsageItem = z.discriminatedUnion("type", [
|
||||||
|
LLMUsage,
|
||||||
|
EmbeddingModelUsage,
|
||||||
|
ComposioToolUsage,
|
||||||
|
FirecrawlScrapeUsage,
|
||||||
|
]);
|
||||||
|
|
||||||
|
export const LogUsageRequest = z.object({
|
||||||
|
items: z.array(UsageItem),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const CustomerUsageData = z.record(z.string(), z.number());
|
||||||
|
|
||||||
export const Customer = z.object({
|
export const Customer = z.object({
|
||||||
_id: z.string(),
|
_id: z.string(),
|
||||||
userId: z.string(),
|
userId: z.string(),
|
||||||
|
|
@ -19,36 +59,23 @@ export const Customer = z.object({
|
||||||
createdAt: z.string().datetime(),
|
createdAt: z.string().datetime(),
|
||||||
updatedAt: z.string().datetime(),
|
updatedAt: z.string().datetime(),
|
||||||
subscriptionPlanUpdatedAt: z.string().datetime().optional(),
|
subscriptionPlanUpdatedAt: z.string().datetime().optional(),
|
||||||
usage: z.record(UsageType, z.number()).optional(),
|
usage: CustomerUsageData.optional(),
|
||||||
usageUpdatedAt: z.string().datetime().optional(),
|
usageUpdatedAt: z.string().datetime().optional(),
|
||||||
});
|
creditsOverride: z.number().optional(),
|
||||||
|
maxProjectsOverride: z.number().optional(),
|
||||||
export const LogUsageRequest = z.object({
|
agentModelsOverride: z.array(z.string()).optional(),
|
||||||
type: UsageType,
|
|
||||||
amount: z.number().int().positive(),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
export const AuthorizeRequest = z.discriminatedUnion("type", [
|
export const AuthorizeRequest = z.discriminatedUnion("type", [
|
||||||
|
z.object({
|
||||||
|
"type": z.literal("use_credits"),
|
||||||
|
}),
|
||||||
z.object({
|
z.object({
|
||||||
"type": z.literal("create_project"),
|
"type": z.literal("create_project"),
|
||||||
"data": z.object({
|
"data": z.object({
|
||||||
"existingProjectCount": z.number(),
|
"existingProjectCount": z.number(),
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
z.object({
|
|
||||||
"type": z.literal("enable_hosted_tool_server"),
|
|
||||||
"data": z.object({
|
|
||||||
"existingServerCount": z.number(),
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
z.object({
|
|
||||||
"type": z.literal("process_rag"),
|
|
||||||
"data": z.object({}),
|
|
||||||
}),
|
|
||||||
z.object({
|
|
||||||
"type": z.literal("copilot_request"),
|
|
||||||
"data": z.object({}),
|
|
||||||
}),
|
|
||||||
z.object({
|
z.object({
|
||||||
"type": z.literal("agent_response"),
|
"type": z.literal("agent_response"),
|
||||||
"data": z.object({
|
"data": z.object({
|
||||||
|
|
@ -63,10 +90,9 @@ export const AuthorizeResponse = z.object({
|
||||||
});
|
});
|
||||||
|
|
||||||
export const UsageResponse = z.object({
|
export const UsageResponse = z.object({
|
||||||
usage: z.record(UsageType, z.object({
|
sanctionedCredits: z.number(),
|
||||||
usage: z.number(),
|
availableCredits: z.number(),
|
||||||
total: z.number(),
|
usage: CustomerUsageData,
|
||||||
})),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
export const CustomerPortalSessionRequest = z.object({
|
export const CustomerPortalSessionRequest = z.object({
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ import crypto from 'crypto';
|
||||||
import path from 'path';
|
import path from 'path';
|
||||||
import { createOpenAI } from '@ai-sdk/openai';
|
import { createOpenAI } from '@ai-sdk/openai';
|
||||||
import { USE_BILLING, USE_GEMINI_FILE_PARSING } from '../lib/feature_flags';
|
import { USE_BILLING, USE_GEMINI_FILE_PARSING } from '../lib/feature_flags';
|
||||||
import { authorize, getCustomerIdForProject, logUsage } from '../lib/billing';
|
import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from '../lib/billing';
|
||||||
import { BillingError } from '@/src/entities/errors/common';
|
import { BillingError } from '@/src/entities/errors/common';
|
||||||
|
|
||||||
const FILE_PARSING_PROVIDER_API_KEY = process.env.FILE_PARSING_PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
|
const FILE_PARSING_PROVIDER_API_KEY = process.env.FILE_PARSING_PROVIDER_API_KEY || process.env.OPENAI_API_KEY || '';
|
||||||
|
|
@ -75,7 +75,7 @@ async function retryable<T>(fn: () => Promise<T>, maxAttempts: number = 3): Prom
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } }): Promise<number> {
|
async function runProcessPipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } }) {
|
||||||
const logger = _logger
|
const logger = _logger
|
||||||
.child(doc._id.toString())
|
.child(doc._id.toString())
|
||||||
.child(doc.name);
|
.child(doc.name);
|
||||||
|
|
@ -95,7 +95,7 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
||||||
if (!USE_GEMINI_FILE_PARSING) {
|
if (!USE_GEMINI_FILE_PARSING) {
|
||||||
// Use OpenAI to extract text content
|
// Use OpenAI to extract text content
|
||||||
logger.log("Extracting content using OpenAI");
|
logger.log("Extracting content using OpenAI");
|
||||||
const { text } = await generateText({
|
const { text, usage } = await generateText({
|
||||||
model: openai(FILE_PARSING_MODEL),
|
model: openai(FILE_PARSING_MODEL),
|
||||||
system: extractPrompt,
|
system: extractPrompt,
|
||||||
messages: [
|
messages: [
|
||||||
|
|
@ -112,6 +112,13 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
||||||
],
|
],
|
||||||
});
|
});
|
||||||
markdown = text;
|
markdown = text;
|
||||||
|
usageTracker.track({
|
||||||
|
type: "LLM_USAGE",
|
||||||
|
modelName: FILE_PARSING_MODEL,
|
||||||
|
inputTokens: usage.promptTokens,
|
||||||
|
outputTokens: usage.completionTokens,
|
||||||
|
context: "rag.files.llm_usage",
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
// Use Gemini to extract text content
|
// Use Gemini to extract text content
|
||||||
logger.log("Extracting content using Gemini");
|
logger.log("Extracting content using Gemini");
|
||||||
|
|
@ -127,6 +134,13 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
||||||
extractPrompt,
|
extractPrompt,
|
||||||
]);
|
]);
|
||||||
markdown = result.response.text();
|
markdown = result.response.text();
|
||||||
|
usageTracker.track({
|
||||||
|
type: "LLM_USAGE",
|
||||||
|
modelName: FILE_PARSING_MODEL,
|
||||||
|
inputTokens: result.response.usageMetadata?.promptTokenCount || 0,
|
||||||
|
outputTokens: result.response.usageMetadata?.candidatesTokenCount || 0,
|
||||||
|
context: "rag.files.llm_usage",
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// split into chunks
|
// split into chunks
|
||||||
|
|
@ -139,6 +153,12 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
||||||
model: embeddingModel,
|
model: embeddingModel,
|
||||||
values: splits.map((split) => split.pageContent)
|
values: splits.map((split) => split.pageContent)
|
||||||
});
|
});
|
||||||
|
usageTracker.track({
|
||||||
|
type: "EMBEDDING_MODEL_USAGE",
|
||||||
|
modelName: embeddingModel.modelId,
|
||||||
|
tokens: usage.tokens,
|
||||||
|
context: "rag.files.embedding_usage",
|
||||||
|
});
|
||||||
|
|
||||||
// store embeddings in qdrant
|
// store embeddings in qdrant
|
||||||
logger.log("Storing embeddings in Qdrant");
|
logger.log("Storing embeddings in Qdrant");
|
||||||
|
|
@ -170,8 +190,6 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
||||||
lastUpdatedAt: new Date().toISOString(),
|
lastUpdatedAt: new Date().toISOString(),
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return usage.tokens;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
|
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
|
||||||
|
|
@ -339,8 +357,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
||||||
// authorize with billing
|
// authorize with billing
|
||||||
if (USE_BILLING && billingCustomerId) {
|
if (USE_BILLING && billingCustomerId) {
|
||||||
const authResponse = await authorize(billingCustomerId, {
|
const authResponse = await authorize(billingCustomerId, {
|
||||||
type: "process_rag",
|
type: "use_credits",
|
||||||
data: {},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
if ('error' in authResponse) {
|
if ('error' in authResponse) {
|
||||||
|
|
@ -349,16 +366,9 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
||||||
}
|
}
|
||||||
|
|
||||||
const ldoc = doc as WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } };
|
const ldoc = doc as WithId<z.infer<typeof DataSourceDoc>> & { data: { type: "file_local" | "file_s3" } };
|
||||||
|
const usageTracker = new UsageTracker();
|
||||||
try {
|
try {
|
||||||
const usedTokens = await runProcessPipeline(logger, job, ldoc);
|
await runProcessPipeline(logger, usageTracker, job, ldoc);
|
||||||
|
|
||||||
// log usage in billing
|
|
||||||
if (USE_BILLING && billingCustomerId) {
|
|
||||||
await logUsage(billingCustomerId, {
|
|
||||||
type: "rag_tokens",
|
|
||||||
amount: usedTokens,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} catch (e: any) {
|
} catch (e: any) {
|
||||||
errors = true;
|
errors = true;
|
||||||
logger.log("Error processing doc:", e);
|
logger.log("Error processing doc:", e);
|
||||||
|
|
@ -371,6 +381,13 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
||||||
error: e.message,
|
error: e.message,
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
} finally {
|
||||||
|
// log usage in billing
|
||||||
|
if (USE_BILLING && billingCustomerId) {
|
||||||
|
await logUsage(billingCustomerId, {
|
||||||
|
items: usageTracker.flush(),
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import { qdrantClient } from '../lib/qdrant';
|
||||||
import { PrefixLogger } from "../lib/utils";
|
import { PrefixLogger } from "../lib/utils";
|
||||||
import crypto from 'crypto';
|
import crypto from 'crypto';
|
||||||
import { USE_BILLING } from '../lib/feature_flags';
|
import { USE_BILLING } from '../lib/feature_flags';
|
||||||
import { authorize, getCustomerIdForProject, logUsage } from '../lib/billing';
|
import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from '../lib/billing';
|
||||||
import { BillingError } from '@/src/entities/errors/common';
|
import { BillingError } from '@/src/entities/errors/common';
|
||||||
|
|
||||||
const splitter = new RecursiveCharacterTextSplitter({
|
const splitter = new RecursiveCharacterTextSplitter({
|
||||||
|
|
@ -23,7 +23,7 @@ const second = 1000;
|
||||||
const minute = 60 * second;
|
const minute = 60 * second;
|
||||||
const hour = 60 * minute;
|
const hour = 60 * minute;
|
||||||
|
|
||||||
async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<number> {
|
async function runProcessPipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>) {
|
||||||
const logger = _logger
|
const logger = _logger
|
||||||
.child(doc._id.toString())
|
.child(doc._id.toString())
|
||||||
.child(doc.name);
|
.child(doc.name);
|
||||||
|
|
@ -42,6 +42,12 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
||||||
model: embeddingModel,
|
model: embeddingModel,
|
||||||
values: splits.map((split) => split.pageContent)
|
values: splits.map((split) => split.pageContent)
|
||||||
});
|
});
|
||||||
|
usageTracker.track({
|
||||||
|
type: "EMBEDDING_MODEL_USAGE",
|
||||||
|
modelName: embeddingModel.modelId,
|
||||||
|
tokens: usage.tokens,
|
||||||
|
context: "rag.text.embedding_usage",
|
||||||
|
});
|
||||||
|
|
||||||
// store embeddings in qdrant
|
// store embeddings in qdrant
|
||||||
logger.log("Storing embeddings in Qdrant");
|
logger.log("Storing embeddings in Qdrant");
|
||||||
|
|
@ -73,8 +79,6 @@ async function runProcessPipeline(_logger: PrefixLogger, job: WithId<z.infer<typ
|
||||||
lastUpdatedAt: new Date().toISOString(),
|
lastUpdatedAt: new Date().toISOString(),
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return usage.tokens;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
|
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
|
||||||
|
|
@ -241,8 +245,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
||||||
// authorize with billing
|
// authorize with billing
|
||||||
if (USE_BILLING && billingCustomerId) {
|
if (USE_BILLING && billingCustomerId) {
|
||||||
const authResponse = await authorize(billingCustomerId, {
|
const authResponse = await authorize(billingCustomerId, {
|
||||||
type: "process_rag",
|
type: "use_credits",
|
||||||
data: {}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
if ('error' in authResponse) {
|
if ('error' in authResponse) {
|
||||||
|
|
@ -250,16 +253,9 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const usageTracker = new UsageTracker();
|
||||||
try {
|
try {
|
||||||
const usedTokens = await runProcessPipeline(logger, job, doc);
|
await runProcessPipeline(logger, usageTracker, job, doc);
|
||||||
|
|
||||||
// log usage in billing
|
|
||||||
if (USE_BILLING && billingCustomerId) {
|
|
||||||
await logUsage(billingCustomerId, {
|
|
||||||
type: "rag_tokens",
|
|
||||||
amount: usedTokens,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} catch (e: any) {
|
} catch (e: any) {
|
||||||
errors = true;
|
errors = true;
|
||||||
logger.log("Error processing doc:", e);
|
logger.log("Error processing doc:", e);
|
||||||
|
|
@ -272,6 +268,13 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
||||||
error: e.message,
|
error: e.message,
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
} finally {
|
||||||
|
// log usage in billing
|
||||||
|
if (USE_BILLING && billingCustomerId) {
|
||||||
|
await logUsage(billingCustomerId, {
|
||||||
|
items: usageTracker.flush(),
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import { qdrantClient } from '../lib/qdrant';
|
||||||
import { PrefixLogger } from "../lib/utils";
|
import { PrefixLogger } from "../lib/utils";
|
||||||
import crypto from 'crypto';
|
import crypto from 'crypto';
|
||||||
import { USE_BILLING } from '../lib/feature_flags';
|
import { USE_BILLING } from '../lib/feature_flags';
|
||||||
import { authorize, getCustomerIdForProject, logUsage } from '../lib/billing';
|
import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from '../lib/billing';
|
||||||
import { BillingError } from '@/src/entities/errors/common';
|
import { BillingError } from '@/src/entities/errors/common';
|
||||||
|
|
||||||
const firecrawl = new FirecrawlApp({ apiKey: process.env.FIRECRAWL_API_KEY });
|
const firecrawl = new FirecrawlApp({ apiKey: process.env.FIRECRAWL_API_KEY });
|
||||||
|
|
@ -41,7 +41,7 @@ async function retryable<T>(fn: () => Promise<T>, maxAttempts: number = 3): Prom
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<number> {
|
async function runScrapePipeline(_logger: PrefixLogger, usageTracker: UsageTracker, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>) {
|
||||||
const logger = _logger
|
const logger = _logger
|
||||||
.child(doc._id.toString())
|
.child(doc._id.toString())
|
||||||
.child(doc.name);
|
.child(doc.name);
|
||||||
|
|
@ -62,6 +62,10 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<type
|
||||||
}
|
}
|
||||||
return scrapeResult;
|
return scrapeResult;
|
||||||
}, 3); // Retry up to 3 times
|
}, 3); // Retry up to 3 times
|
||||||
|
usageTracker.track({
|
||||||
|
type: "FIRECRAWL_SCRAPE_USAGE",
|
||||||
|
context: "rag.urls.firecrawl_scrape",
|
||||||
|
});
|
||||||
|
|
||||||
// split into chunks
|
// split into chunks
|
||||||
logger.log("Splitting into chunks");
|
logger.log("Splitting into chunks");
|
||||||
|
|
@ -73,6 +77,12 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<type
|
||||||
model: embeddingModel,
|
model: embeddingModel,
|
||||||
values: splits.map((split) => split.pageContent)
|
values: splits.map((split) => split.pageContent)
|
||||||
});
|
});
|
||||||
|
usageTracker.track({
|
||||||
|
type: "EMBEDDING_MODEL_USAGE",
|
||||||
|
modelName: embeddingModel.modelId,
|
||||||
|
tokens: usage.tokens,
|
||||||
|
context: "rag.urls.embedding_usage",
|
||||||
|
});
|
||||||
|
|
||||||
// store embeddings in qdrant
|
// store embeddings in qdrant
|
||||||
logger.log("Storing embeddings in Qdrant");
|
logger.log("Storing embeddings in Qdrant");
|
||||||
|
|
@ -104,8 +114,6 @@ async function runScrapePipeline(_logger: PrefixLogger, job: WithId<z.infer<type
|
||||||
lastUpdatedAt: new Date().toISOString(),
|
lastUpdatedAt: new Date().toISOString(),
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return usage.tokens;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
|
async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<typeof DataSource>>, doc: WithId<z.infer<typeof DataSourceDoc>>): Promise<void> {
|
||||||
|
|
@ -273,8 +281,7 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
||||||
// authorize with billing
|
// authorize with billing
|
||||||
if (USE_BILLING && billingCustomerId) {
|
if (USE_BILLING && billingCustomerId) {
|
||||||
const authResponse = await authorize(billingCustomerId, {
|
const authResponse = await authorize(billingCustomerId, {
|
||||||
type: "process_rag",
|
type: "use_credits",
|
||||||
data: {}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
if ('error' in authResponse) {
|
if ('error' in authResponse) {
|
||||||
|
|
@ -282,16 +289,9 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const usageTracker = new UsageTracker();
|
||||||
try {
|
try {
|
||||||
const usedTokens = await runScrapePipeline(logger, job, doc);
|
await runScrapePipeline(logger, usageTracker, job, doc);
|
||||||
|
|
||||||
// log usage in billing
|
|
||||||
if (USE_BILLING && billingCustomerId) {
|
|
||||||
await logUsage(billingCustomerId, {
|
|
||||||
type: "rag_tokens",
|
|
||||||
amount: usedTokens,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} catch (e: any) {
|
} catch (e: any) {
|
||||||
errors = true;
|
errors = true;
|
||||||
logger.log("Error processing doc:", e);
|
logger.log("Error processing doc:", e);
|
||||||
|
|
@ -304,6 +304,13 @@ async function runDeletionPipeline(_logger: PrefixLogger, job: WithId<z.infer<ty
|
||||||
error: e.message,
|
error: e.message,
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
} finally {
|
||||||
|
// log usage in billing
|
||||||
|
if (USE_BILLING && billingCustomerId) {
|
||||||
|
await logUsage(billingCustomerId, {
|
||||||
|
items: usageTracker.flush(),
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import { Reason, Turn, TurnEvent } from "@/src/entities/models/turn";
|
import { Reason, Turn, TurnEvent } from "@/src/entities/models/turn";
|
||||||
import { USE_BILLING } from "@/app/lib/feature_flags";
|
import { USE_BILLING } from "@/app/lib/feature_flags";
|
||||||
import { authorize, getCustomerIdForProject, logUsage } from "@/app/lib/billing";
|
import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from "@/app/lib/billing";
|
||||||
import { NotFoundError } from '@/src/entities/errors/common';
|
import { NotFoundError } from '@/src/entities/errors/common';
|
||||||
import { IConversationsRepository } from "@/src/application/repositories/conversations.repository.interface";
|
import { IConversationsRepository } from "@/src/application/repositories/conversations.repository.interface";
|
||||||
import { streamResponse } from "@/app/lib/agents";
|
import { streamResponse } from "@/app/lib/agents";
|
||||||
|
|
@ -69,6 +69,21 @@ export class RunConversationTurnUseCase implements IRunConversationTurnUseCase {
|
||||||
if (USE_BILLING) {
|
if (USE_BILLING) {
|
||||||
// get billing customer id for project
|
// get billing customer id for project
|
||||||
billingCustomerId = await getCustomerIdForProject(projectId);
|
billingCustomerId = await getCustomerIdForProject(projectId);
|
||||||
|
|
||||||
|
// validate enough credits
|
||||||
|
const result = await authorize(billingCustomerId, {
|
||||||
|
type: "use_credits"
|
||||||
|
});
|
||||||
|
if (!result.success) {
|
||||||
|
yield {
|
||||||
|
type: "error",
|
||||||
|
error: result.error || 'Billing error',
|
||||||
|
isBillingError: true,
|
||||||
|
};
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate model usage
|
||||||
const agentModels = conversation.workflow.agents.reduce((acc, agent) => {
|
const agentModels = conversation.workflow.agents.reduce((acc, agent) => {
|
||||||
acc.push(agent.model);
|
acc.push(agent.model);
|
||||||
return acc;
|
return acc;
|
||||||
|
|
@ -111,9 +126,13 @@ export class RunConversationTurnUseCase implements IRunConversationTurnUseCase {
|
||||||
conversation.workflow.mockTools = data.input.mockTools;
|
conversation.workflow.mockTools = data.input.mockTools;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init usage tracker
|
||||||
|
const usageTracker = new UsageTracker();
|
||||||
|
|
||||||
// call agents runtime and handle generated messages
|
// call agents runtime and handle generated messages
|
||||||
|
try {
|
||||||
const outputMessages: z.infer<typeof Message>[] = [];
|
const outputMessages: z.infer<typeof Message>[] = [];
|
||||||
for await (const event of streamResponse(projectId, conversation.workflow, inputMessages)) {
|
for await (const event of streamResponse(projectId, conversation.workflow, inputMessages, usageTracker)) {
|
||||||
// handle msg events
|
// handle msg events
|
||||||
if ("role" in event) {
|
if ("role" in event) {
|
||||||
// collect generated message
|
// collect generated message
|
||||||
|
|
@ -128,7 +147,9 @@ export class RunConversationTurnUseCase implements IRunConversationTurnUseCase {
|
||||||
type: "message",
|
type: "message",
|
||||||
data: msg,
|
data: msg,
|
||||||
};
|
};
|
||||||
} else {
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// save turn data
|
// save turn data
|
||||||
const turn = await this.conversationsRepository.addTurn(data.conversationId, {
|
const turn = await this.conversationsRepository.addTurn(data.conversationId, {
|
||||||
reason: data.reason,
|
reason: data.reason,
|
||||||
|
|
@ -142,15 +163,13 @@ export class RunConversationTurnUseCase implements IRunConversationTurnUseCase {
|
||||||
turn,
|
turn,
|
||||||
conversationId,
|
conversationId,
|
||||||
}
|
}
|
||||||
}
|
} finally {
|
||||||
}
|
|
||||||
|
|
||||||
// Log billing usage
|
// Log billing usage
|
||||||
if (USE_BILLING && billingCustomerId) {
|
if (USE_BILLING && billingCustomerId) {
|
||||||
await logUsage(billingCustomerId, {
|
await logUsage(billingCustomerId, {
|
||||||
type: "agent_messages",
|
items: usageTracker.flush(),
|
||||||
amount: outputMessages.length,
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue