diff --git a/apps/python-sdk/pyproject.toml b/apps/python-sdk/pyproject.toml index 70478f25..7a550f8e 100644 --- a/apps/python-sdk/pyproject.toml +++ b/apps/python-sdk/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "rowboat" -version = "4.0.0" +version = "5.0.0" authors = [ { name = "Ramnique Singh", email = "ramnique@rowboatlabs.com" }, ] diff --git a/apps/python-sdk/src/rowboat/__init__.py b/apps/python-sdk/src/rowboat/__init__.py index 6eff958d..b431d95f 100644 --- a/apps/python-sdk/src/rowboat/__init__.py +++ b/apps/python-sdk/src/rowboat/__init__.py @@ -1,4 +1,4 @@ -from .client import Client, StatefulChat +from .client import Client from .schema import ( ApiMessage, UserMessage, @@ -8,21 +8,4 @@ from .schema import ( ToolMessage, ApiRequest, ApiResponse -) - -__version__ = "0.1.0" - -__all__ = [ - "Client", - "StatefulChat", - # Message types - "ApiMessage", - "UserMessage", - "SystemMessage", - "AssistantMessage", - "AssistantMessageWithToolCalls", - "ToolMessage", - # Request/Response types - "ApiRequest", - "ApiResponse", -] \ No newline at end of file +) \ No newline at end of file diff --git a/apps/python-sdk/src/rowboat/client.py b/apps/python-sdk/src/rowboat/client.py index 69703ff1..a958f871 100644 --- a/apps/python-sdk/src/rowboat/client.py +++ b/apps/python-sdk/src/rowboat/client.py @@ -1,36 +1,30 @@ -from typing import Dict, List, Optional, Any, Union +from typing import Dict, List, Optional import requests from .schema import ( ApiRequest, ApiResponse, ApiMessage, UserMessage, - AssistantMessage, - AssistantMessageWithToolCalls ) class Client: - def __init__(self, host: str, project_id: str, api_key: str) -> None: - self.base_url: str = f'{host}/api/v1/{project_id}/chat' + def __init__(self, host: str, projectId: str, apiKey: str) -> None: + self.base_url: str = f'{host}/api/v1/{projectId}/chat' self.headers: Dict[str, str] = { 'Content-Type': 'application/json', - 'Authorization': f'Bearer {api_key}' + 'Authorization': f'Bearer {apiKey}' } def _call_api( self, messages: List[ApiMessage], - state: Optional[Dict[str, Any]] = None, - workflow_id: Optional[str] = None, - test_profile_id: Optional[str] = None, - mock_tools: Optional[Dict[str, str]] = None + conversationId: Optional[str] = None, + mockTools: Optional[Dict[str, str]] = None ) -> ApiResponse: request = ApiRequest( messages=messages, - state=state, - workflowId=workflow_id, - testProfileId=test_profile_id, - mockTools=mock_tools + conversationId=conversationId, + mockTools=mockTools ) json_data = request.model_dump() response = requests.post(self.base_url, headers=self.headers, json=json_data) @@ -38,86 +32,23 @@ class Client: if not response.status_code == 200: raise ValueError(f"Error: {response.status_code} - {response.text}") - response_data = ApiResponse.model_validate(response.json()) - - if not response_data.messages: - raise ValueError("No response") - - last_message = response_data.messages[-1] - if not isinstance(last_message, (AssistantMessage, AssistantMessageWithToolCalls)): - raise ValueError("Last message was not an assistant message") + return ApiResponse.model_validate(response.json()) - return response_data - - def chat( + def run_turn( self, messages: List[ApiMessage], - state: Optional[Dict[str, Any]] = None, - workflow_id: Optional[str] = None, - test_profile_id: Optional[str] = None, - mock_tools: Optional[Dict[str, str]] = None, + conversationId: Optional[str] = None, + mockTools: Optional[Dict[str, str]] = None, ) -> ApiResponse: """Stateless chat method that handles a single conversation turn""" # call api - response_data = self._call_api( + return self._call_api( messages=messages, - state=state, - workflow_id=workflow_id, - test_profile_id=test_profile_id, - mock_tools=mock_tools, + conversationId=conversationId, + mockTools=mockTools, ) - if not response_data.messages[-1].responseType == 'external': - raise ValueError("Last message was not an external message") - - return response_data - -class StatefulChat: - """Maintains conversation state across multiple turns""" - - def __init__( - self, - client: Client, - workflow_id: Optional[str] = None, - test_profile_id: Optional[str] = None, - mock_tools: Optional[Dict[str, str]] = None, - ) -> None: - self.client = client - self.messages: List[ApiMessage] = [] - self.state: Optional[Dict[str, Any]] = None - self.workflow_id = workflow_id - self.test_profile_id = test_profile_id - self.mock_tools = mock_tools - - def run(self, message: Union[str]) -> str: - """Handle a single user turn in the conversation""" - - # Process the message - user_msg = UserMessage(role='user', content=message) - self.messages.append(user_msg) - - # Get response using the client's chat method - response_data = self.client.chat( - messages=self.messages, - state=self.state, - workflow_id=self.workflow_id, - test_profile_id=self.test_profile_id, - mock_tools=self.mock_tools, - ) - - # Update internal state - self.messages.extend(response_data.messages) - self.state = response_data.state - - # Return only the final message content - last_message = self.messages[-1] - return last_message.content - - -def weather_lookup_tool(city_name: str) -> str: - return f"The weather in {city_name} is 22°C." - if __name__ == "__main__": host: str = "" @@ -125,13 +56,18 @@ if __name__ == "__main__": api_key: str = "" client = Client(host, project_id, api_key) - result = client.chat( + result = client.run_turn( messages=[ - UserMessage(role='user', content="Hello") + UserMessage(role='user', content="list my github repos") ] ) - print(result.messages[-1].content) + print(result.turn.output[-1].content) + print(result.conversationId) - chat_session = StatefulChat(client) - resp = chat_session.run("Hello") - print(resp) \ No newline at end of file + result = client.run_turn( + messages=[ + UserMessage(role='user', content="how many did you find?") + ], + conversationId=result.conversationId + ) + print(result.turn.output[-1].content) \ No newline at end of file diff --git a/apps/python-sdk/src/rowboat/schema.py b/apps/python-sdk/src/rowboat/schema.py index 62c07fc2..bd53ebee 100644 --- a/apps/python-sdk/src/rowboat/schema.py +++ b/apps/python-sdk/src/rowboat/schema.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union, Any, Literal, Dict +from typing import List, Optional, Union, Literal, Dict from pydantic import BaseModel class SystemMessage(BaseModel): @@ -44,13 +44,15 @@ ApiMessage = Union[ ToolMessage ] +class Turn(BaseModel): + id: str + output: List[ApiMessage] + class ApiRequest(BaseModel): + conversationId: Optional[str] = None messages: List[ApiMessage] - state: Any - workflowId: Optional[str] = None - testProfileId: Optional[str] = None mockTools: Optional[Dict[str, str]] = None class ApiResponse(BaseModel): - messages: List[ApiMessage] - state: Optional[Any] = None \ No newline at end of file + conversationId: str + turn: Turn \ No newline at end of file diff --git a/apps/rowboat/app/actions/actions.ts b/apps/rowboat/app/actions/actions.ts deleted file mode 100644 index d0d2fea3..00000000 --- a/apps/rowboat/app/actions/actions.ts +++ /dev/null @@ -1,38 +0,0 @@ -'use server'; -import { z } from 'zod'; -import { getAgenticResponseStreamId } from "../lib/utils"; -import { check_query_limit } from "../lib/rate_limiting"; -import { QueryLimitError } from "../lib/client_utils"; -import { projectAuthCheck } from "./project_actions"; -import { authorizeUserAction } from "./billing_actions"; -import { Workflow } from "../lib/types/workflow_types"; -import { Message } from "@/app/lib/types/types"; - -export async function getAssistantResponseStreamId( - projectId: string, - workflow: z.infer, - messages: z.infer[], -): Promise<{ streamId: string } | { billingError: string }> { - await projectAuthCheck(projectId); - if (!await check_query_limit(projectId)) { - throw new QueryLimitError(); - } - - // Check billing authorization - const agentModels = workflow.agents.reduce((acc, agent) => { - acc.push(agent.model); - return acc; - }, [] as string[]); - const { success, error } = await authorizeUserAction({ - type: 'agent_response', - data: { - agentModels, - }, - }); - if (!success) { - return { billingError: error || 'Billing error' }; - } - - const response = await getAgenticResponseStreamId(projectId, workflow, messages); - return response; -} \ No newline at end of file diff --git a/apps/rowboat/app/actions/copilot_actions.ts b/apps/rowboat/app/actions/copilot_actions.ts index 84bd0828..91a9c28c 100644 --- a/apps/rowboat/app/actions/copilot_actions.ts +++ b/apps/rowboat/app/actions/copilot_actions.ts @@ -8,7 +8,7 @@ import { import { DataSource } from "../lib/types/datasource_types"; import { z } from 'zod'; import { check_query_limit } from "../lib/rate_limiting"; -import { QueryLimitError } from "../lib/client_utils"; +import { QueryLimitError } from "@/src/entities/errors/common"; import { projectAuthCheck } from "./project_actions"; import { redisClient } from "../lib/redis"; import { authorizeUserAction, logUsage } from "./billing_actions"; diff --git a/apps/rowboat/app/actions/playground-chat.actions.ts b/apps/rowboat/app/actions/playground-chat.actions.ts new file mode 100644 index 00000000..8297c2d4 --- /dev/null +++ b/apps/rowboat/app/actions/playground-chat.actions.ts @@ -0,0 +1,54 @@ +'use server'; +import { z } from 'zod'; +import { Workflow } from "../lib/types/workflow_types"; +import { Message } from "@/app/lib/types/types"; +import { authCheck } from './auth_actions'; +import { container } from '@/di/container'; +import { Conversation } from '@/src/entities/models/conversation'; +import { ICreatePlaygroundConversationController } from '@/src/interface-adapters/controllers/conversations/create-playground-conversation.controller'; +import { ICreateCachedTurnController } from '@/src/interface-adapters/controllers/conversations/create-cached-turn.controller'; + +export async function createConversation({ + projectId, + workflow, + isLiveWorkflow, +}: { + projectId: string; + workflow: z.infer; + isLiveWorkflow: boolean; +}): Promise> { + const user = await authCheck(); + + const controller = container.resolve("createPlaygroundConversationController"); + + return await controller.execute({ + userId: user._id, + projectId, + workflow, + isLiveWorkflow, + }); +} + +export async function createCachedTurn({ + conversationId, + messages, +}: { + conversationId: string; + messages: z.infer[]; +}): Promise<{ key: string }> { + const user = await authCheck(); + const createCachedTurnController = container.resolve("createCachedTurnController"); + + const { key } = await createCachedTurnController.execute({ + caller: "user", + userId: user._id, + conversationId, + input: { + messages, + }, + }); + + return { + key, + }; +} \ No newline at end of file diff --git a/apps/rowboat/app/api/stream-response/[streamId]/route.ts b/apps/rowboat/app/api/stream-response/[streamId]/route.ts index 743acf53..2f4f670a 100644 --- a/apps/rowboat/app/api/stream-response/[streamId]/route.ts +++ b/apps/rowboat/app/api/stream-response/[streamId]/route.ts @@ -1,67 +1,41 @@ -import { getCustomerIdForProject, logUsage } from "@/app/lib/billing"; -import { USE_BILLING } from "@/app/lib/feature_flags"; -import { redisClient } from "@/app/lib/redis"; -import { streamResponse } from "@/app/lib/agents"; -import { ZStreamAgentResponsePayload } from "@/app/lib/types/types"; +import { container } from "@/di/container"; +import { IRunCachedTurnController } from "@/src/interface-adapters/controllers/conversations/run-cached-turn.controller"; +import { requireAuth } from "@/app/lib/auth"; export async function GET(request: Request, props: { params: Promise<{ streamId: string }> }) { - const params = await props.params; - // get the payload from redis - const payload = await redisClient.get(`chat-stream-${params.streamId}`); - if (!payload) { - return new Response("Stream not found", { status: 404 }); - } - - // parse the payload - const { projectId, workflow, messages } = ZStreamAgentResponsePayload.parse(JSON.parse(payload)); - console.log('payload', payload); - - // fetch billing customer id - let billingCustomerId: string | null = null; - if (USE_BILLING) { - billingCustomerId = await getCustomerIdForProject(projectId); - } - - const encoder = new TextEncoder(); - let messageCount = 0; - - const stream = new ReadableStream({ - async start(controller) { - try { - // Iterate over the generator - for await (const event of streamResponse(projectId, workflow, messages)) { - // Check if this is a message event (has role property) - if ('role' in event) { - if (event.role === 'assistant') { - messageCount++; + const params = await props.params; + + // get user data + const user = await requireAuth(); + + const runCachedTurnController = container.resolve("runCachedTurnController"); + + const encoder = new TextEncoder(); + + const stream = new ReadableStream({ + async start(controller) { + try { + // Iterate over the generator + for await (const event of runCachedTurnController.execute({ + caller: "user", + userId: user._id, + cachedTurnKey: params.streamId, + })) { + controller.enqueue(encoder.encode(`event: message\ndata: ${JSON.stringify(event)}\n\n`)); + } + controller.close(); + } catch (error) { + console.error('Error processing stream:', error); + controller.error(error); } - controller.enqueue(encoder.encode(`event: message\ndata: ${JSON.stringify(event)}\n\n`)); - } else { - controller.enqueue(encoder.encode(`event: done\ndata: ${JSON.stringify(event)}\n\n`)); - } - } - - controller.close(); - - // Log billing usage - if (USE_BILLING && billingCustomerId) { - await logUsage(billingCustomerId, { - type: "agent_messages", - amount: messageCount, - }); - } - } catch (error) { - console.error('Error processing stream:', error); - controller.error(error); - } - }, - }); - - return new Response(stream, { - headers: { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, - }); + }, + }); + + return new Response(stream, { + headers: { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + }); } \ No newline at end of file diff --git a/apps/rowboat/app/api/v1/[projectId]/chat/route.ts b/apps/rowboat/app/api/v1/[projectId]/chat/route.ts index 8008fae3..73a48348 100644 --- a/apps/rowboat/app/api/v1/[projectId]/chat/route.ts +++ b/apps/rowboat/app/api/v1/[projectId]/chat/route.ts @@ -1,14 +1,10 @@ import { NextRequest } from "next/server"; -import { projectsCollection } from "../../../../lib/mongodb"; import { z } from "zod"; -import { ObjectId } from "mongodb"; -import { authCheck } from "../../utils"; -import { ApiRequest, ApiResponse } from "../../../../lib/types/types"; -import { check_query_limit } from "../../../../lib/rate_limiting"; +import { ApiResponse } from "@/app/lib/types/api_types"; +import { ApiRequest } from "@/app/lib/types/api_types"; import { PrefixLogger } from "../../../../lib/utils"; -import { authorize, getCustomerIdForProject, logUsage } from "@/app/lib/billing"; -import { USE_BILLING } from "@/app/lib/feature_flags"; -import { getResponse } from "@/app/lib/agents"; +import { container } from "@/di/container"; +import { IRunTurnController } from "@/src/interface-adapters/controllers/conversations/run-turn.controller"; // get next turn / agent response export async function POST( @@ -19,91 +15,69 @@ export async function POST( const requestId = crypto.randomUUID(); const logger = new PrefixLogger(`${requestId}`); - logger.log(`Got chat request for project ${projectId}`); + // parse and validate the request body + let data; + try { + const body = await req.json(); + data = ApiRequest.parse(body); + } catch (e) { + logger.log(`Invalid JSON in request body: ${e}`); + return Response.json({ error: "Invalid request" }, { status: 400 }); + } + const { conversationId, messages, mockTools, stream } = data; - // check query limit - if (!await check_query_limit(projectId)) { - logger.log(`Query limit exceeded for project ${projectId}`); - return Response.json({ error: "Query limit exceeded" }, { status: 429 }); + const runTurnController = container.resolve("runTurnController"); + + // get assistant response + const response = await runTurnController.execute({ + caller: "api", + apiKey: req.headers.get("Authorization")?.split(" ")[1], + projectId, + input: { + messages, + mockTools, + }, + conversationId: conversationId || undefined, + stream: Boolean(stream), + }); + + // if streaming is requested, return SSE stream + if (stream && 'stream' in response) { + const encoder = new TextEncoder(); + + const readableStream = new ReadableStream({ + async start(controller) { + try { + // Iterate over the generator + for await (const event of response.stream) { + controller.enqueue(encoder.encode(`event: message\ndata: ${JSON.stringify(event)}\n\n`)); + } + controller.close(); + } catch (error) { + logger.log(`Error processing stream: ${error}`); + controller.error(error); + } + }, + }); + + return new Response(readableStream, { + headers: { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + }); } - return await authCheck(projectId, req, async () => { - // fetch billing customer id - let billingCustomerId: string | null = null; - if (USE_BILLING) { - billingCustomerId = await getCustomerIdForProject(projectId); - } + // non-streaming response (existing behavior) + if (!('turn' in response)) { + logger.log(`No turn data found in response`); + return Response.json({ error: "No turn data found in response" }, { status: 500 }); + } - // parse and validate the request body - let body; - try { - body = await req.json(); - } catch (e) { - logger.log(`Invalid JSON in request body: ${e}`); - return Response.json({ error: "Invalid JSON in request body" }, { status: 400 }); - } - logger.log(`Request json: ${JSON.stringify(body, null, 2)}`); - const result = ApiRequest.safeParse(body); - if (!result.success) { - logger.log(`Invalid request body: ${result.error.message}`); - return Response.json({ error: `Invalid request body: ${result.error.message}` }, { status: 400 }); - } - const reqMessages = result.data.messages; - const mockToolOverrides = result.data.mockTools; - - // fetch published workflow id - const project = await projectsCollection.findOne({ - _id: projectId, - }); - if (!project) { - logger.log(`Project ${projectId} not found`); - return Response.json({ error: "Project not found" }, { status: 404 }); - } - - // fetch workflow - const workflow = project.liveWorkflow; - if (!workflow) { - logger.log(`Workflow not found for project ${projectId}`); - return Response.json({ error: "Workflow not found" }, { status: 404 }); - } - - // override mock instructions - if (mockToolOverrides) { - workflow.mockTools = mockToolOverrides; - } - - // check billing authorization - if (USE_BILLING && billingCustomerId) { - const agentModels = workflow.agents.reduce((acc, agent) => { - acc.push(agent.model); - return acc; - }, [] as string[]); - const response = await authorize(billingCustomerId, { - type: 'agent_response', - data: { - agentModels, - }, - }); - if (!response.success) { - return Response.json({ error: response.error || 'Billing error' }, { status: 402 }); - } - } - - // get assistant response - const { messages } = await getResponse(projectId, workflow, reqMessages); - - // log billing usage - if (USE_BILLING && billingCustomerId) { - const agentMessageCount = messages.filter(m => m.role === 'assistant').length; - await logUsage(billingCustomerId, { - type: 'agent_messages', - amount: agentMessageCount, - }); - } - - const responseBody: z.infer = { - messages, - }; - return Response.json(responseBody); - }); + const responseBody: z.infer = { + conversationId: response.conversationId, + turn: response.turn, + }; + return Response.json(responseBody); } diff --git a/apps/rowboat/app/lib/billing.ts b/apps/rowboat/app/lib/billing.ts index dd7140ac..d6d880ba 100644 --- a/apps/rowboat/app/lib/billing.ts +++ b/apps/rowboat/app/lib/billing.ts @@ -23,13 +23,6 @@ const GUEST_BILLING_CUSTOMER = { updatedAt: new Date().toISOString(), }; -export class BillingError extends Error { - constructor(message: string) { - super(message); - this.name = 'BillingError'; - } -} - export async function getCustomerIdForProject(projectId: string): Promise { const project = await projectsCollection.findOne({ _id: projectId }); if (!project) { diff --git a/apps/rowboat/app/lib/client_utils.ts b/apps/rowboat/app/lib/client_utils.ts index b190a914..0e369a5c 100644 --- a/apps/rowboat/app/lib/client_utils.ts +++ b/apps/rowboat/app/lib/client_utils.ts @@ -1,13 +1,6 @@ import { WorkflowTool, WorkflowAgent, WorkflowPrompt } from "./types/workflow_types"; import { z } from "zod"; -export class QueryLimitError extends Error { - constructor(message: string = 'Query limit exceeded') { - super(message); - this.name = 'QueryLimitError'; - } -} - export function validateConfigChanges(configType: string, configChanges: Record, name: string) { let testObject: any; let schema: z.ZodType; diff --git a/apps/rowboat/app/lib/types/api_types.ts b/apps/rowboat/app/lib/types/api_types.ts new file mode 100644 index 00000000..b16faebc --- /dev/null +++ b/apps/rowboat/app/lib/types/api_types.ts @@ -0,0 +1,14 @@ +import { Message } from "./types"; +import { Turn } from "@/src/entities/models/turn"; +import { z } from "zod"; + +export const ApiRequest = z.object({ + messages: z.array(Message), + conversationId: z.string().nullable().optional(), + mockTools: z.record(z.string(), z.string()).nullable().optional(), + stream: z.boolean().optional().nullable().default(false), +});export const ApiResponse = z.object({ + turn: Turn, + conversationId: z.string().optional(), +}); + diff --git a/apps/rowboat/app/lib/types/types.ts b/apps/rowboat/app/lib/types/types.ts index c03e9ffd..ab9de57a 100644 --- a/apps/rowboat/app/lib/types/types.ts +++ b/apps/rowboat/app/lib/types/types.ts @@ -1,24 +1,28 @@ import { z } from "zod"; -import { Workflow, WorkflowTool } from "./workflow_types"; +import { WorkflowTool } from "./workflow_types"; -export const SystemMessage = z.object({ +export const BaseMessage = z.object({ + timestamp: z.string().datetime().optional(), +}); + +export const SystemMessage = BaseMessage.extend({ role: z.literal("system"), content: z.string(), }); -export const UserMessage = z.object({ +export const UserMessage = BaseMessage.extend({ role: z.literal("user"), content: z.string(), }); -export const AssistantMessage = z.object({ +export const AssistantMessage = BaseMessage.extend({ role: z.literal("assistant"), content: z.string(), agentName: z.string().nullable(), responseType: z.enum(['internal', 'external']), }); -export const AssistantMessageWithToolCalls = z.object({ +export const AssistantMessageWithToolCalls = BaseMessage.extend({ role: z.literal("assistant"), content: z.null(), toolCalls: z.array(z.object({ @@ -32,7 +36,7 @@ export const AssistantMessageWithToolCalls = z.object({ agentName: z.string().nullable(), }); -export const ToolMessage = z.object({ +export const ToolMessage = BaseMessage.extend({ role: z.literal("tool"), content: z.string(), toolCallId: z.string(), @@ -143,18 +147,6 @@ export const ChatClientId = z.object({ export type WithStringId = T & { _id: string }; -export const ApiRequest = z.object({ - messages: z.array(Message), - state: z.unknown(), - testProfileId: z.string().nullable().optional(), - mockTools: z.record(z.string(), z.string()).nullable().optional(), -}); - -export const ApiResponse = z.object({ - messages: z.array(Message), - state: z.unknown(), -}); - // Helper function to convert MCP server tool to WorkflowTool export function convertMcpServerToolToWorkflowTool( mcpTool: z.infer, @@ -194,9 +186,4 @@ export function convertMcpServerToolToWorkflowTool( }; return converted; -} -export const ZStreamAgentResponsePayload = z.object({ - projectId: z.string(), - workflow: Workflow, - messages: z.array(Message), -}); +} \ No newline at end of file diff --git a/apps/rowboat/app/lib/utils.ts b/apps/rowboat/app/lib/utils.ts index 4d6ee2ff..3fe752e6 100644 --- a/apps/rowboat/app/lib/utils.ts +++ b/apps/rowboat/app/lib/utils.ts @@ -1,35 +1,7 @@ import { z } from "zod"; import { generateObject } from "ai"; import { openai } from "@ai-sdk/openai"; -import { redisClient } from "./redis"; -import { Workflow, WorkflowTool } from "./types/workflow_types"; -import { Message, ZStreamAgentResponsePayload } from "./types/types"; - -export async function getAgenticResponseStreamId( - projectId: string, - workflow: z.infer, - messages: z.infer[], -): Promise<{ - streamId: string, -}> { - const payload: z.infer = { - projectId, - workflow, - messages, - } - // serialize the request - const serialized = JSON.stringify(payload); - - // create a uuid for the stream - const streamId = crypto.randomUUID(); - - // store payload in redis - await redisClient.set(`chat-stream-${streamId}`, serialized, 'EX', 60 * 10); // expire in 10 minutes - - return { - streamId, - }; -} +import { Message } from "./types/types"; // create a PrefixLogger class that wraps console.log with a prefix // and allows chaining with a parent logger diff --git a/apps/rowboat/app/projects/[projectId]/playground/app.tsx b/apps/rowboat/app/projects/[projectId]/playground/app.tsx index 369cf980..5366555b 100644 --- a/apps/rowboat/app/projects/[projectId]/playground/app.tsx +++ b/apps/rowboat/app/projects/[projectId]/playground/app.tsx @@ -16,6 +16,7 @@ export function App({ messageSubscriber, onPanelClick, triggerCopilotChat, + isLiveWorkflow, }: { hidden?: boolean; projectId: string; @@ -23,6 +24,7 @@ export function App({ messageSubscriber?: (messages: z.infer[]) => void; onPanelClick?: () => void; triggerCopilotChat?: (message: string) => void; + isLiveWorkflow: boolean; }) { const [counter, setCounter] = useState(0); const [showDebugMessages, setShowDebugMessages] = useState(true); @@ -118,6 +120,7 @@ export function App({ onCopyClick={(fn) => { getCopyContentRef.current = fn; }} showDebugMessages={showDebugMessages} triggerCopilotChat={triggerCopilotChat} + isLiveWorkflow={isLiveWorkflow} /> diff --git a/apps/rowboat/app/projects/[projectId]/playground/components/chat.tsx b/apps/rowboat/app/projects/[projectId]/playground/components/chat.tsx index 89553c06..43ab3502 100644 --- a/apps/rowboat/app/projects/[projectId]/playground/components/chat.tsx +++ b/apps/rowboat/app/projects/[projectId]/playground/components/chat.tsx @@ -1,8 +1,8 @@ 'use client'; import { useEffect, useRef, useState, useCallback } from "react"; -import { getAssistantResponseStreamId } from "@/app/actions/actions"; +import { createCachedTurn, createConversation } from "@/app/actions/playground-chat.actions"; import { Messages } from "./messages"; -import z from "zod"; +import { z } from "zod"; import { Message, ToolMessage } from "@/app/lib/types/types"; import { Workflow } from "@/app/lib/types/workflow_types"; import { ComposeBoxPlayground } from "@/components/common/compose-box-playground"; @@ -11,6 +11,7 @@ import { BillingUpgradeModal } from "@/components/common/billing-upgrade-modal"; import { ChevronDownIcon } from "@heroicons/react/24/outline"; import { FeedbackModal } from "./feedback-modal"; import { FIX_WORKFLOW_PROMPT, FIX_WORKFLOW_PROMPT_WITH_FEEDBACK, EXPLAIN_WORKFLOW_PROMPT_ASSISTANT, EXPLAIN_WORKFLOW_PROMPT_TOOL, EXPLAIN_WORKFLOW_PROMPT_TRANSITION } from "../copilot-prompts"; +import { TurnEvent } from "@/src/entities/models/turn"; export function Chat({ projectId, @@ -20,6 +21,7 @@ export function Chat({ showDebugMessages = true, showJsonMode = false, triggerCopilotChat, + isLiveWorkflow, }: { projectId: string; workflow: z.infer; @@ -28,10 +30,12 @@ export function Chat({ showDebugMessages?: boolean; showJsonMode?: boolean; triggerCopilotChat?: (message: string) => void; + isLiveWorkflow: boolean; }) { + const conversationId = useRef(null); const [messages, setMessages] = useState[]>([]); - const [loadingAssistantResponse, setLoadingAssistantResponse] = useState(false); - const [fetchResponseError, setFetchResponseError] = useState(null); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); const [billingError, setBillingError] = useState(null); const [lastAgenticRequest, setLastAgenticRequest] = useState(null); const [lastAgenticResponse, setLastAgenticResponse] = useState(null); @@ -142,7 +146,7 @@ export function Chat({ if (eventSourceRef.current) { eventSourceRef.current.close(); eventSourceRef.current = null; - setLoadingAssistantResponse(false); + setLoading(false); } }, []); @@ -152,7 +156,7 @@ export function Chat({ content: prompt, }]; setMessages(updatedMessages); - setFetchResponseError(null); + setError(null); setIsLastInteracted(true); } @@ -165,7 +169,7 @@ export function Chat({ } else { setShowUnreadBubble(true); } - }, [optimisticMessages, loadingAssistantResponse, autoScroll]); + }, [optimisticMessages, loading, autoScroll]); // Expose copy function to parent useEffect(() => { @@ -190,148 +194,175 @@ export function Chat({ } }, [messages, messageSubscriber]); - // get assistant response + // get agent response useEffect(() => { let ignore = false; let eventSource: EventSource | null = null; - let msgs: z.infer[] = []; async function process() { - setLoadingAssistantResponse(true); - setFetchResponseError(null); - - // Reset request/response state before making new request - setLastAgenticRequest(null); - setLastAgenticResponse(null); - - let streamId: string | null = null; try { - const response = await getAssistantResponseStreamId( - projectId, - workflow, - messages, - ); + // first, if there is no conversation id, create it + if (!conversationId.current) { + const response = await createConversation({ + projectId, + workflow, + isLiveWorkflow, + }); + conversationId.current = response.id; + } + + // set up a cached turn + const response = await createCachedTurn({ + conversationId: conversationId.current, + messages: messages.slice(-1), // only send the last message + }); if (ignore) { return; } - if ('billingError' in response) { - setBillingError(response.billingError); - setFetchResponseError(response.billingError); - setLoadingAssistantResponse(false); - console.log('returning from getAssistantResponseStreamId due to billing error'); - return; - } - streamId = response.streamId; - } catch (err) { - if (!ignore) { - setFetchResponseError(`Failed to get assistant response: ${err instanceof Error ? err.message : 'Unknown error'}`); - setLoadingAssistantResponse(false); - } - } + // if ('billingError' in response) { + // setBillingError(response.billingError); + // setError(response.billingError); + // setLoading(false); + // console.log('returning from createRun due to billing error'); + // return; + // } - if (ignore || !streamId) { - return; - } + // stream events + eventSource = new EventSource(`/api/stream-response/${response.key}`); + eventSourceRef.current = eventSource; - console.log(`chat.tsx: got streamid: ${streamId}`); - eventSource = new EventSource(`/api/stream-response/${streamId}`); - eventSourceRef.current = eventSource; + // handle events + eventSource.addEventListener("message", (event) => { + console.log(`chat.tsx: got message: ${JSON.stringify(event.data)}`); + if (ignore) { + return; + } - eventSource.addEventListener("message", (event) => { - console.log(`chat.tsx: got message: ${event.data}`); - if (ignore) { - return; - } + try { + const data = JSON.parse(event.data); + const turnEvent = TurnEvent.parse(data); + console.log(`chat.tsx: got event: ${turnEvent}`); - try { - const data = JSON.parse(event.data); - const parsedMsg = Message.parse(data); - msgs.push(parsedMsg); - // Update optimistic messages immediately for real-time streaming UX - setOptimisticMessages(prev => [...prev, parsedMsg]); - } catch (err) { - console.error('Failed to parse SSE message:', err); - setFetchResponseError(`Failed to parse SSE message: ${err instanceof Error ? err.message : 'Unknown error'}`); - // Rollback to last known good state on parsing errors - setOptimisticMessages(messages); - } - }); + switch (turnEvent.type) { + case "message": { + // Handle regular message events + const generatedMessage = turnEvent.data; + // Update optimistic messages immediately for real-time streaming UX + setOptimisticMessages(prev => [...prev, generatedMessage]); + break; + } + case "done": { + // Handle completion event + if (eventSource) { + eventSource.close(); + eventSourceRef.current = null; + } - eventSource.addEventListener('done', (event) => { - console.log(`chat.tsx: got done event: ${event.data}`); - if (eventSource) { - eventSource.close(); - eventSourceRef.current = null; - } + // Combine state and collected messages in the response + setLastAgenticResponse({ + turn: turnEvent.turn, + messages: turnEvent.turn.output, + }); - const parsed = JSON.parse(event.data); + // Commit all streamed messages atomically to the source of truth + setMessages([...messages, ...turnEvent.turn.output]); + setLoading(false); + break; + } + case "error": { + // Handle error event + if (eventSource) { + eventSource.close(); + eventSourceRef.current = null; + } - // Combine state and collected messages in the response - setLastAgenticResponse({ - ...parsed, - messages: msgs + console.error('Turn Error:', turnEvent.error); + if (!ignore) { + setLoading(false); + setError('Error: ' + turnEvent.error); + // Rollback to last known good state on stream errors + setOptimisticMessages(messages); + + // check if billing error + if (turnEvent.isBillingError) { + setBillingError(turnEvent.error); + } + } + break; + } + } + } catch (err) { + console.error('Failed to parse SSE message:', err); + setError(`Failed to parse SSE message: ${err instanceof Error ? err.message : 'Unknown error'}`); + // Rollback to last known good state on parsing errors + setOptimisticMessages(messages); + } }); - // Commit all streamed messages atomically to the source of truth - setMessages([...messages, ...msgs]); - setLoadingAssistantResponse(false); - }); + eventSource.addEventListener('stream_error', (event) => { + console.log(`chat.tsx: got stream_error event: ${event.data}`); + if (eventSource) { + eventSource.close(); + eventSourceRef.current = null; + } + + console.error('SSE Error:', event); + if (!ignore) { + setLoading(false); + setError('Error: ' + JSON.parse(event.data).error); + // Rollback to last known good state on stream errors + setOptimisticMessages(messages); + } + }); - eventSource.addEventListener('stream_error', (event) => { - console.log(`chat.tsx: got stream_error event: ${event.data}`); - if (eventSource) { - eventSource.close(); - eventSourceRef.current = null; - } - - console.error('SSE Error:', event); + eventSource.onerror = (error) => { + console.error('SSE Error:', error); + if (!ignore) { + setLoading(false); + setError('Stream connection failed'); + // Rollback to last known good state on connection errors + setOptimisticMessages(messages); + } + }; + } catch (err) { if (!ignore) { - setLoadingAssistantResponse(false); - setFetchResponseError('Error: ' + JSON.parse(event.data).error); - // Rollback to last known good state on stream errors - setOptimisticMessages(messages); + setError(`Failed to create run: ${err instanceof Error ? err.message : 'Unknown error'}`); + setLoading(false); } - }); - - eventSource.onerror = (error) => { - console.error('SSE Error:', error); - if (!ignore) { - setLoadingAssistantResponse(false); - setFetchResponseError('Stream connection failed'); - // Rollback to last known good state on connection errors - setOptimisticMessages(messages); - } - }; - } - - // if last message is not a user message, return - if (messages.length > 0) { - const last = messages[messages.length - 1]; - if (last.role !== 'user') { - return; } } - // if there is an error, return - if (fetchResponseError) { + // if there are no messages yet, return + if (messages.length === 0) { return; } - console.log(`executing response process: fetchresponseerr: ${fetchResponseError}`); + // if last message is not a user message, return + const last = messages[messages.length - 1]; + if (last.role !== 'user') { + return; + } + + // if there is an error, return + if (error) { + return; + } + + console.log(`chat.tsx: fetching agent response`); + setLoading(true); + setError(null); process(); return () => { ignore = true; - if (eventSource) { - eventSource.close(); - eventSourceRef.current = null; - } }; }, [ + conversationId, messages, projectId, workflow, - fetchResponseError, + isLiveWorkflow, + error, ]); return ( @@ -349,9 +380,17 @@ export function Chat({ > )} - {fetchResponseError && ( + {error && (
-

{fetchResponseError}

+

{error}