diff --git a/apps/rowboat/app/actions/copilot.actions.ts b/apps/rowboat/app/actions/copilot.actions.ts index a36b3c41..de9b1bae 100644 --- a/apps/rowboat/app/actions/copilot.actions.ts +++ b/apps/rowboat/app/actions/copilot.actions.ts @@ -1,5 +1,5 @@ 'use server'; -import { +import { CopilotAPIRequest, CopilotChatContext, CopilotMessage, DataSourceSchemaForCopilot, @@ -8,15 +8,18 @@ import { Workflow} from "../lib/types/workflow_types"; import { z } from 'zod'; import { projectAuthCheck } from "./project.actions"; -import { redisClient } from "../lib/redis"; import { authorizeUserAction, logUsage } from "./billing.actions"; import { USE_BILLING } from "../lib/feature_flags"; import { getEditAgentInstructionsResponse } from "../../src/application/lib/copilot/copilot"; import { container } from "@/di/container"; import { IUsageQuotaPolicy } from "@/src/application/policies/usage-quota.policy.interface"; import { UsageTracker } from "../lib/billing"; +import { authCheck } from "./auth.actions"; +import { ICreateCopilotCachedTurnController } from "@/src/interface-adapters/controllers/copilot/create-copilot-cached-turn.controller"; +import { BillingError } from "@/src/entities/errors/common"; const usageQuotaPolicy = container.resolve('usageQuotaPolicy'); +const createCopilotCachedTurnController = container.resolve('createCopilotCachedTurnController'); export async function getCopilotResponseStream( projectId: string, @@ -27,40 +30,29 @@ export async function getCopilotResponseStream( ): Promise<{ streamId: string; } | { billingError: string }> { - await projectAuthCheck(projectId); - await usageQuotaPolicy.assertAndConsumeProjectAction(projectId); + const user = await authCheck(); - // Check billing authorization - const authResponse = await authorizeUserAction({ - type: 'use_credits', - }); - if (!authResponse.success) { - return { billingError: authResponse.error || 'Billing error' }; + try { + const { key } = await createCopilotCachedTurnController.execute({ + caller: 'user', + userId: user.id, + data: { + projectId, + messages, + workflow: current_workflow_config, + context, + dataSources, + } + }); + return { + streamId: key, + }; + } catch (err) { + if (err instanceof BillingError) { + return { billingError: err.message }; + } + throw err; } - - await usageQuotaPolicy.assertAndConsumeProjectAction(projectId); - - // prepare request - const request: z.infer = { - projectId, - messages, - workflow: current_workflow_config, - context, - dataSources: dataSources, - }; - - // serialize the request - const payload = JSON.stringify(request); - - // create a uuid for the stream - const streamId = crypto.randomUUID(); - - // store payload in redis - await redisClient.set(`copilot-stream-${streamId}`, payload, 'EX', 60 * 10); // expire in 10 minutes - - return { - streamId, - }; } export async function getCopilotAgentInstructions( diff --git a/apps/rowboat/app/api/copilot-stream-response/[streamId]/route.ts b/apps/rowboat/app/api/copilot-stream-response/[streamId]/route.ts index 80396f2c..620dfd00 100644 --- a/apps/rowboat/app/api/copilot-stream-response/[streamId]/route.ts +++ b/apps/rowboat/app/api/copilot-stream-response/[streamId]/route.ts @@ -1,70 +1,45 @@ -import { getCustomerIdForProject, logUsage, UsageTracker } from "@/app/lib/billing"; -import { USE_BILLING } from "@/app/lib/feature_flags"; -import { redisClient } from "@/app/lib/redis"; -import { CopilotAPIRequest } from "@/src/entities/models/copilot"; -import { streamMultiAgentResponse } from "@/src/application/lib/copilot/copilot"; +import { container } from "@/di/container"; +import { IRunCopilotCachedTurnController } from "@/src/interface-adapters/controllers/copilot/run-copilot-cached-turn.controller"; +import { requireAuth } from "@/app/lib/auth"; export const maxDuration = 300; export async function GET(request: Request, props: { params: Promise<{ streamId: string }> }) { const params = await props.params; - // get the payload from redis - const payload = await redisClient.get(`copilot-stream-${params.streamId}`); - if (!payload) { - return new Response("Stream not found", { status: 404 }); - } - // parse the payload - const { projectId, context, messages, workflow, dataSources } = CopilotAPIRequest.parse(JSON.parse(payload)); + // get user data + const user = await requireAuth(); - // fetch billing customer id - let billingCustomerId: string | null = null; - if (USE_BILLING) { - billingCustomerId = await getCustomerIdForProject(projectId); - } + const runCopilotCachedTurnController = container.resolve("runCopilotCachedTurnController"); - const usageTracker = new UsageTracker(); const encoder = new TextEncoder(); - let messageCount = 0; const stream = new ReadableStream({ async start(controller) { try { // Iterate over the copilot stream generator - for await (const event of streamMultiAgentResponse( - usageTracker, - projectId, - context, - messages, - workflow, - dataSources || [], - )) { + for await (const event of runCopilotCachedTurnController.execute({ + caller: "user", + userId: user.id, + apiKey: request.headers.get("Authorization")?.split(" ")[1], + key: params.streamId, + })) { // Check if this is a content event if ('content' in event) { - messageCount++; controller.enqueue(encoder.encode(`event: message\ndata: ${JSON.stringify(event)}\n\n`)); } else if ('type' in event && event.type === 'tool-call') { controller.enqueue(encoder.encode(`event: tool-call\ndata: ${JSON.stringify(event)}\n\n`)); } else if ('type' in event && event.type === 'tool-result') { controller.enqueue(encoder.encode(`event: tool-result\ndata: ${JSON.stringify(event)}\n\n`)); - } else { - controller.enqueue(encoder.encode(`event: done\ndata: ${JSON.stringify(event)}\n\n`)); } } } catch (error) { console.error('Error processing copilot stream:', error); controller.error(new Error("Something went wrong. Please try again.")); } finally { - // log copilot usage - if (USE_BILLING && billingCustomerId) { - try { - await logUsage(billingCustomerId, { - items: usageTracker.flush(), - }); - } catch (error) { - console.error("Error logging usage", error); - } - } + console.log("closing stream"); + controller.enqueue(encoder.encode(`event: done\ndata: ${JSON.stringify({ type: 'done' })}\n\n`)); + controller.enqueue(encoder.encode("event: end\n\n")); controller.close(); } }, diff --git a/apps/rowboat/di/container.ts b/apps/rowboat/di/container.ts index be9bf344..456250c8 100644 --- a/apps/rowboat/di/container.ts +++ b/apps/rowboat/di/container.ts @@ -145,6 +145,12 @@ import { UpdateLiveWorkflowController } from "@/src/interface-adapters/controlle import { RevertToLiveWorkflowUseCase } from "@/src/application/use-cases/projects/revert-to-live-workflow.use-case"; import { RevertToLiveWorkflowController } from "@/src/interface-adapters/controllers/projects/revert-to-live-workflow.controller"; +// copilot +import { CreateCopilotCachedTurnUseCase } from "@/src/application/use-cases/copilot/create-copilot-cached-turn.use-case"; +import { CreateCopilotCachedTurnController } from "@/src/interface-adapters/controllers/copilot/create-copilot-cached-turn.controller"; +import { RunCopilotCachedTurnUseCase } from "@/src/application/use-cases/copilot/run-copilot-cached-turn.use-case"; +import { RunCopilotCachedTurnController } from "@/src/interface-adapters/controllers/copilot/run-copilot-cached-turn.controller"; + // users import { MongoDBUsersRepository } from "@/src/infrastructure/repositories/mongodb.users.repository"; @@ -328,6 +334,13 @@ container.register({ listConversationsController: asClass(ListConversationsController).singleton(), fetchConversationController: asClass(FetchConversationController).singleton(), + // copilot + // --- + createCopilotCachedTurnUseCase: asClass(CreateCopilotCachedTurnUseCase).singleton(), + createCopilotCachedTurnController: asClass(CreateCopilotCachedTurnController).singleton(), + runCopilotCachedTurnUseCase: asClass(RunCopilotCachedTurnUseCase).singleton(), + runCopilotCachedTurnController: asClass(RunCopilotCachedTurnController).singleton(), + // users // --- usersRepository: asClass(MongoDBUsersRepository).singleton(), diff --git a/apps/rowboat/src/application/lib/copilot/copilot.ts b/apps/rowboat/src/application/lib/copilot/copilot.ts index 79d0ae60..d8b82614 100644 --- a/apps/rowboat/src/application/lib/copilot/copilot.ts +++ b/apps/rowboat/src/application/lib/copilot/copilot.ts @@ -12,6 +12,7 @@ import { CURRENT_WORKFLOW_PROMPT } from "./current_workflow"; import { USE_COMPOSIO_TOOLS } from "@/app/lib/feature_flags"; import { composio, getTool } from "../composio/composio"; import { UsageTracker } from "@/app/lib/billing"; +import { CopilotStreamEvent } from "@/src/entities/models/copilot"; const PROVIDER_API_KEY = process.env.PROVIDER_API_KEY || process.env.OPENAI_API_KEY || ''; const PROVIDER_BASE_URL = process.env.PROVIDER_BASE_URL || undefined; @@ -35,30 +36,6 @@ const openai = createOpenAI({ compatibility: "strict", }); -const ZTextEvent = z.object({ - content: z.string(), -}); - -const ZToolCallEvent = z.object({ - type: z.literal('tool-call'), - toolName: z.string(), - toolCallId: z.string(), - args: z.record(z.any()), - query: z.string().optional(), -}); - -const ZToolResultEvent = z.object({ - type: z.literal('tool-result'), - toolCallId: z.string(), - result: z.any(), -}); - -const ZDoneEvent = z.object({ - done: z.literal(true), -}); - -const ZEvent = z.union([ZTextEvent, ZToolCallEvent, ZToolResultEvent, ZDoneEvent]); - const composioToolSearchToolSuggestion = z.object({ toolkit: z.string(), tool_slug: z.string(), @@ -273,7 +250,7 @@ export async function* streamMultiAgentResponse( messages: z.infer[], workflow: z.infer, dataSources: z.infer[] -): AsyncIterable> { +): AsyncIterable> { const logger = new PrefixLogger('copilot /stream'); logger.log('context', context); logger.log('projectId', projectId); @@ -375,9 +352,4 @@ export async function* streamMultiAgentResponse( projectId, totalChunks: chunkCount }); - - // done - yield { - done: true, - }; } \ No newline at end of file diff --git a/apps/rowboat/src/application/use-cases/copilot/create-copilot-cached-turn.use-case.ts b/apps/rowboat/src/application/use-cases/copilot/create-copilot-cached-turn.use-case.ts new file mode 100644 index 00000000..20f70a92 --- /dev/null +++ b/apps/rowboat/src/application/use-cases/copilot/create-copilot-cached-turn.use-case.ts @@ -0,0 +1,87 @@ +import { z } from "zod"; +import { nanoid } from 'nanoid'; +import { ICacheService } from '@/src/application/services/cache.service.interface'; +import { IUsageQuotaPolicy } from '@/src/application/policies/usage-quota.policy.interface'; +import { IProjectActionAuthorizationPolicy } from '@/src/application/policies/project-action-authorization.policy'; +import { CopilotChatContext, CopilotMessage, DataSourceSchemaForCopilot } from '@/src/entities/models/copilot'; +import { Workflow } from '@/app/lib/types/workflow_types'; +import { USE_BILLING } from "@/app/lib/feature_flags"; +import { authorize, getCustomerIdForProject } from "@/app/lib/billing"; +import { BillingError } from "@/src/entities/errors/common"; + +const inputSchema = z.object({ + caller: z.enum(["user", "api"]), + userId: z.string().optional(), + apiKey: z.string().optional(), + data: z.object({ + projectId: z.string(), + messages: z.array(CopilotMessage), + workflow: Workflow, + context: CopilotChatContext.nullable(), + dataSources: z.array(DataSourceSchemaForCopilot).optional(), + }), +}); + +export interface ICreateCopilotCachedTurnUseCase { + execute(data: z.infer): Promise<{ key: string }>; +} + +export class CreateCopilotCachedTurnUseCase implements ICreateCopilotCachedTurnUseCase { + private readonly cacheService: ICacheService; + private readonly usageQuotaPolicy: IUsageQuotaPolicy; + private readonly projectActionAuthorizationPolicy: IProjectActionAuthorizationPolicy; + + constructor({ + cacheService, + usageQuotaPolicy, + projectActionAuthorizationPolicy, + }: { + cacheService: ICacheService, + usageQuotaPolicy: IUsageQuotaPolicy, + projectActionAuthorizationPolicy: IProjectActionAuthorizationPolicy, + }) { + this.cacheService = cacheService; + this.usageQuotaPolicy = usageQuotaPolicy; + this.projectActionAuthorizationPolicy = projectActionAuthorizationPolicy; + } + + async execute(data: z.infer): Promise<{ key: string }> { + const { projectId } = data.data; + + // check auth + await this.projectActionAuthorizationPolicy.authorize({ + projectId, + caller: data.caller, + userId: data.userId, + apiKey: data.apiKey, + }); + await this.usageQuotaPolicy.assertAndConsumeProjectAction(projectId); + + // check billing authorization + if (USE_BILLING) { + // get billing customer id for this project + const billingCustomerId = await getCustomerIdForProject(projectId); + + // validate enough credits + const result = await authorize(billingCustomerId, { + type: "use_credits" + }); + if (!result.success) { + throw new BillingError(result.error || 'Billing error'); + } + } + + // serialize request + const payload = JSON.stringify(data.data); + + // create unique id for stream + const key = nanoid(); + + // store in cache + await this.cacheService.set(`copilot-stream-${key}`, payload, 60 * 10); // expire in 10 minutes + + return { + key, + } + } +} \ No newline at end of file diff --git a/apps/rowboat/src/application/use-cases/copilot/run-copilot-cached-turn.use-case.ts b/apps/rowboat/src/application/use-cases/copilot/run-copilot-cached-turn.use-case.ts new file mode 100644 index 00000000..7afb0a29 --- /dev/null +++ b/apps/rowboat/src/application/use-cases/copilot/run-copilot-cached-turn.use-case.ts @@ -0,0 +1,104 @@ +import { z } from "zod"; +import { ICacheService } from '@/src/application/services/cache.service.interface'; +import { IUsageQuotaPolicy } from '@/src/application/policies/usage-quota.policy.interface'; +import { IProjectActionAuthorizationPolicy } from '@/src/application/policies/project-action-authorization.policy'; +import { CopilotAPIRequest, CopilotStreamEvent } from '@/src/entities/models/copilot'; +import { USE_BILLING } from "@/app/lib/feature_flags"; +import { authorize, getCustomerIdForProject, logUsage, UsageTracker } from "@/app/lib/billing"; +import { BillingError, NotFoundError } from "@/src/entities/errors/common"; +import { streamMultiAgentResponse } from "@/src/application/lib/copilot/copilot"; + +const inputSchema = z.object({ + caller: z.enum(["user", "api"]), + userId: z.string().optional(), + apiKey: z.string().optional(), + key: z.string(), +}); + +export interface IRunCopilotCachedTurnUseCase { + execute(data: z.infer): AsyncGenerator, void, unknown>; +} + +export class RunCopilotCachedTurnUseCase implements IRunCopilotCachedTurnUseCase { + private readonly cacheService: ICacheService; + private readonly usageQuotaPolicy: IUsageQuotaPolicy; + private readonly projectActionAuthorizationPolicy: IProjectActionAuthorizationPolicy; + + constructor({ + cacheService, + usageQuotaPolicy, + projectActionAuthorizationPolicy, + }: { + cacheService: ICacheService, + usageQuotaPolicy: IUsageQuotaPolicy, + projectActionAuthorizationPolicy: IProjectActionAuthorizationPolicy, + }) { + this.cacheService = cacheService; + this.usageQuotaPolicy = usageQuotaPolicy; + this.projectActionAuthorizationPolicy = projectActionAuthorizationPolicy; + } + + async *execute(data: z.infer): AsyncGenerator, void, unknown> { + // fetch cached turn + const lookupKey = `copilot-stream-${data.key}`; + const payload = await this.cacheService.get(lookupKey); + if (!payload) { + throw new NotFoundError('Cached turn not found'); + } + + // delete from cache + await this.cacheService.delete(lookupKey); + + // parse cached turn + const cachedTurn = CopilotAPIRequest.parse(JSON.parse(payload)); + + const { projectId } = cachedTurn; + + // check auth + await this.projectActionAuthorizationPolicy.authorize({ + projectId, + caller: data.caller, + userId: data.userId, + apiKey: data.apiKey, + }); + + await this.usageQuotaPolicy.assertAndConsumeProjectAction(projectId); + + // check billing authorization + let billingCustomerId: string | null = null; + if (USE_BILLING) { + // get billing customer id for this project + billingCustomerId = await getCustomerIdForProject(projectId); + + // validate enough credits + const result = await authorize(billingCustomerId, { + type: "use_credits" + }); + if (!result.success) { + throw new BillingError(result.error || 'Billing error'); + } + } + + // init usage tracking + const usageTracker = new UsageTracker(); + + try { + for await (const event of streamMultiAgentResponse( + usageTracker, + projectId, + cachedTurn.context, + cachedTurn.messages, + cachedTurn.workflow, + cachedTurn.dataSources || [], + )) { + yield event; + } + } finally { + if (USE_BILLING && billingCustomerId) { + await logUsage(billingCustomerId, { + items: usageTracker.flush(), + }); + } + } + } +} \ No newline at end of file diff --git a/apps/rowboat/src/entities/models/copilot.ts b/apps/rowboat/src/entities/models/copilot.ts index 5b54f9dc..eeff9742 100644 --- a/apps/rowboat/src/entities/models/copilot.ts +++ b/apps/rowboat/src/entities/models/copilot.ts @@ -68,4 +68,28 @@ export const CopilotAPIResponse = z.union([ z.object({ error: z.string(), }), +]); + +const CopilotStreamTextEvent = z.object({ + content: z.string(), +}); + +const CopilotStreamToolCallEvent = z.object({ + type: z.literal('tool-call'), + toolName: z.string(), + toolCallId: z.string(), + args: z.record(z.any()), + query: z.string().optional(), +}); + +const CopilotStreamToolResultEvent = z.object({ + type: z.literal('tool-result'), + toolCallId: z.string(), + result: z.any(), +}); + +export const CopilotStreamEvent = z.union([ + CopilotStreamTextEvent, + CopilotStreamToolCallEvent, + CopilotStreamToolResultEvent, ]); \ No newline at end of file diff --git a/apps/rowboat/src/interface-adapters/controllers/copilot/create-copilot-cached-turn.controller.ts b/apps/rowboat/src/interface-adapters/controllers/copilot/create-copilot-cached-turn.controller.ts new file mode 100644 index 00000000..bd14730e --- /dev/null +++ b/apps/rowboat/src/interface-adapters/controllers/copilot/create-copilot-cached-turn.controller.ts @@ -0,0 +1,44 @@ +import { z } from "zod"; +import { CopilotChatContext, CopilotMessage, DataSourceSchemaForCopilot } from '@/src/entities/models/copilot'; +import { Workflow } from '@/app/lib/types/workflow_types'; +import { ICreateCopilotCachedTurnUseCase } from "@/src/application/use-cases/copilot/create-copilot-cached-turn.use-case"; +import { BadRequestError } from "@/src/entities/errors/common"; + +const inputSchema = z.object({ + caller: z.enum(["user", "api"]), + userId: z.string().optional(), + apiKey: z.string().optional(), + data: z.object({ + projectId: z.string(), + messages: z.array(CopilotMessage), + workflow: Workflow, + context: CopilotChatContext.nullable(), + dataSources: z.array(DataSourceSchemaForCopilot).optional(), + }), +}); + +export interface ICreateCopilotCachedTurnController { + execute(request: z.infer): Promise<{ key: string }>; +} + +export class CreateCopilotCachedTurnController implements ICreateCopilotCachedTurnController { + private readonly createCopilotCachedTurnUseCase: ICreateCopilotCachedTurnUseCase; + + constructor({ + createCopilotCachedTurnUseCase, + }: { + createCopilotCachedTurnUseCase: ICreateCopilotCachedTurnUseCase, + }) { + this.createCopilotCachedTurnUseCase = createCopilotCachedTurnUseCase; + } + + async execute(request: z.infer): Promise<{ key: string }> { + // parse input + const result = inputSchema.safeParse(request); + if (!result.success) { + throw new BadRequestError(`Invalid request: ${JSON.stringify(result.error)}`); + } + + return await this.createCopilotCachedTurnUseCase.execute(result.data); + } +} \ No newline at end of file diff --git a/apps/rowboat/src/interface-adapters/controllers/copilot/run-copilot-cached-turn.controller.ts b/apps/rowboat/src/interface-adapters/controllers/copilot/run-copilot-cached-turn.controller.ts new file mode 100644 index 00000000..d62ad3fd --- /dev/null +++ b/apps/rowboat/src/interface-adapters/controllers/copilot/run-copilot-cached-turn.controller.ts @@ -0,0 +1,37 @@ +import { z } from "zod"; +import { CopilotStreamEvent } from '@/src/entities/models/copilot'; +import { IRunCopilotCachedTurnUseCase } from "@/src/application/use-cases/copilot/run-copilot-cached-turn.use-case"; +import { BadRequestError } from "@/src/entities/errors/common"; + +const inputSchema = z.object({ + caller: z.enum(["user", "api"]), + userId: z.string().optional(), + apiKey: z.string().optional(), + key: z.string(), +}); + +export interface IRunCopilotCachedTurnController { + execute(request: z.infer): AsyncGenerator, void, unknown>; +} + +export class RunCopilotCachedTurnController implements IRunCopilotCachedTurnController { + private readonly runCopilotCachedTurnUseCase: IRunCopilotCachedTurnUseCase; + + constructor({ + runCopilotCachedTurnUseCase, + }: { + runCopilotCachedTurnUseCase: IRunCopilotCachedTurnUseCase, + }) { + this.runCopilotCachedTurnUseCase = runCopilotCachedTurnUseCase; + } + + async *execute(request: z.infer): AsyncGenerator, void, unknown> { + // parse input + const result = inputSchema.safeParse(request); + if (!result.success) { + throw new BadRequestError(`Invalid request: ${JSON.stringify(result.error)}`); + } + + yield *this.runCopilotCachedTurnUseCase.execute(result.data); + } +} \ No newline at end of file