diff --git a/apps/rowboat_agents/src/graph/execute_turn.py b/apps/rowboat_agents/src/graph/execute_turn.py index fd332315..1fa43232 100644 --- a/apps/rowboat_agents/src/graph/execute_turn.py +++ b/apps/rowboat_agents/src/graph/execute_turn.py @@ -235,7 +235,6 @@ def get_agents(agent_configs, tool_configs, complete_request): ) if tool: new_tools.append(tool) - logger.debug(f"Added tool {tool_name} to agent {agent_config['name']}") print(f"Added tool {tool_name} to agent {agent_config['name']}") else: print(f"WARNING: Tool {tool_name} not found in tool_configs") diff --git a/apps/rowboat_agents/src/graph/tools.py b/apps/rowboat_agents/src/graph/tools.py index d8bd5ae0..425282c0 100644 --- a/apps/rowboat_agents/src/graph/tools.py +++ b/apps/rowboat_agents/src/graph/tools.py @@ -1,43 +1,8 @@ import json import random -from src.utils.common import common_logger -logger = common_logger - -RAG_TOOL = { - "name": "getArticleInfo", - "type": "rag", - "description": "Fetch articles with knowledge relevant to the query", - "parameters": { - "type": "object", - "properties": { - "question": { - "type": "string", - "description": "The query to retrieve articles for" - } - }, - "required": [ - "query" - ] - } -} - -CLOSE_CHAT_TOOL = { - "name": "close_chat", - "type": "close_chat", - "description": "Close the chat", - "parameters": { - "type": "object", - "properties": { - "error_message": { - "type": "string", "description": "The error message to close the chat with" - } - } - } -} - def tool_raise_error(error_message): - logger.error(f"Raising error: {error_message}") + print(f"Raising error: {error_message}") raise ValueError(f"Raising error: {error_message}") def respond_to_tool_raise_error(tool_calls, mock=False): @@ -45,7 +10,7 @@ def respond_to_tool_raise_error(tool_calls, mock=False): return _create_tool_response(tool_calls, tool_raise_error(error_message)) def tool_close_chat(error_message): - logger.error(f"Closing chat: {error_message}") + print(f"Closing chat: {error_message}") raise ValueError(f"Closing chat: {error_message}") def respond_to_tool_close_chat(tool_calls, mock=False): diff --git a/apps/rowboat_agents/tests/interactive.py b/apps/rowboat_agents/tests/interactive.py index ea0ec816..f1c72fce 100644 --- a/apps/rowboat_agents/tests/interactive.py +++ b/apps/rowboat_agents/tests/interactive.py @@ -3,9 +3,9 @@ from datetime import datetime import json import sys import asyncio +import requests +import argparse -from src.graph.core import order_messages, run_turn_streamed -from src.graph.tools import respond_to_tool_raise_error, respond_to_tool_close_chat, RAG_TOOL, CLOSE_CHAT_TOOL from src.utils.common import common_logger, read_json_from_file logger = common_logger @@ -26,103 +26,144 @@ def preprocess_messages(messages): msg["role"] = "user" return messages - -async def process_turn(messages, agent_configs, tool_configs, prompt_configs, start_agent_name, state, config, complete_request): - """Processes a single turn using streaming API""" - print(f"\n{'*'*50}\nLatest Request:\n{'*'*50}") - request_json = { - "messages": [{k: v for k, v in msg.items() if k != 'current_turn'} for msg in messages], - "state": state, - "agents": agent_configs, - "tools": tool_configs, - "prompts": prompt_configs, - "startAgent": start_agent_name - } - print(json.dumps(request_json, indent=2)) - collected_messages = [] +def stream_chat(host, request_data, api_key): + start_time = datetime.now() + print("\n" + "="*80) + print(f"Starting streaming chat at {start_time}") + print(f"Host: {host}") + print("="*80 + "\n") + + try: + print("\n" + "-"*80) + print("Connecting to stream...") + stream_response = requests.post( + f"{host}/chat_stream", + json=request_data, + headers={ + 'Authorization': f'Bearer {api_key}', + 'Accept': 'text/event-stream' + }, + stream=True + ) + + if stream_response.status_code != 200: + print(f"Error connecting to stream. Status code: {stream_response.status_code}") + print(f"Response: {stream_response.text}") + return None, None + + print(f"Successfully connected to stream") + print("-"*80 + "\n") + + event_count = 0 + collected_messages = [] + final_state = None + + try: + print("\n" + "-"*80) + print("Starting to process events...") + print("-"*80 + "\n") + + for line in stream_response.iter_lines(decode_unicode=True): + if line: + if line.startswith('data: '): + data = line[6:] # Remove 'data: ' prefix + try: + event_data = json.loads(data) + event_count += 1 + print("\n" + "*"*80) + print(f"Event #{event_count} at {datetime.now().isoformat()}") + + if isinstance(event_data, dict): + # Pretty print the event data + print("Event Data:") + print(json.dumps(event_data, indent=2)) + + # Special handling for message events + if 'content' in event_data: + print("\nMessage Content:", event_data['content']) + if event_data.get('tool_calls'): + print("Tool Calls:", json.dumps(event_data['tool_calls'], indent=2)) + + # Collect messages + collected_messages.append(event_data) + else: + print("Event Data:", event_data) + print("*"*80 + "\n") + + except json.JSONDecodeError as e: + print(f"Error decoding event data: {e}") + print(f"Raw data: {data}") + + except Exception as e: + print(f"Error processing stream: {e}") + import traceback + traceback.print_exc() + finally: + print("\n" + "-"*80) + print(f"Closing stream after processing {event_count} events") + print("-"*80 + "\n") + stream_response.close() + + except requests.exceptions.RequestException as e: + print(f"Request error during streaming: {e}") + import traceback + traceback.print_exc() + + end_time = datetime.now() + duration = end_time - start_time + print("\n" + "="*80) + print(f"Streaming session completed at {end_time}") + print(f"Total duration: {duration}") + print("="*80 + "\n") - async for event_type, event_data in run_turn_streamed( - messages=messages, - start_agent_name=start_agent_name, - agent_configs=agent_configs, - tool_configs=tool_configs, - prompt_configs=prompt_configs, - start_turn_with_start_agent=config.get("start_turn_with_start_agent", False), - state=state, - additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL], - complete_request=complete_request - ): - if event_type == "message": - # Add each message to collected_messages - collected_messages.append(event_data) - - elif event_type == "done": - print(f"\n\n{'*'*50}\nLatest Response:\n{'*'*50}") - response_json = { - "messages": collected_messages, - "state": event_data.get('state', {}), - } - print("Turn completed. Here are the streamed messages and final state:") - print(json.dumps(response_json, indent=2)) - print('='*50) - - return collected_messages, event_data.get('state', {}) - - elif event_type == "error": - print(f"\nError: {event_data.get('error', 'Unknown error')}") - return [], state + return collected_messages, final_state if __name__ == "__main__": logger.info(f"{'*'*50}Running interactive mode{'*'*50}") - def extract_request_fields(complete_request): - agent_configs = complete_request.get("agents", []) - tool_configs = complete_request.get("tools", []) - prompt_configs = complete_request.get("prompts", []) - start_agent_name = complete_request.get("startAgent", "") - - return agent_configs, tool_configs, prompt_configs, start_agent_name + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, required=False, default='default_config.json', + help='Config file name under configs/') + parser.add_argument('--sample_request', type=str, required=False, default='default_example.json', + help='Sample request JSON file name under tests/sample_requests/') + parser.add_argument('--api_key', type=str, required=False, default='test', + help='API key to use for authentication') + parser.add_argument('--host', type=str, default='http://localhost:4040', + help='Host to use for the request') + parser.add_argument('--load_messages', action='store_true', + help='Load messages from sample request file') + args = parser.parse_args() + + print(f"Config file: {args.config}") + print(f"Sample request file: {args.sample_request}") - external_tool_mappings = { - "raise_error": respond_to_tool_raise_error, - "close_chat": respond_to_tool_close_chat - } + config = read_json_from_file(f"./configs/{args.config}") + example_request = read_json_from_file(f"./tests/sample_requests/{args.sample_request}").get("lastRequest", {}) - config_file = sys.argv[sys.argv.index("--config") + 1] if "--config" in sys.argv else "default_config.json" - sample_request_file = sys.argv[sys.argv.index("--sample_request") + 1] if "--sample_request" in sys.argv else "default_example.json" - - print(f"Config file: {config_file}") - print(f"Sample request file: {sample_request_file}") - - config = read_json_from_file(f"./configs/{config_file}") - example_request = read_json_from_file(f"./tests/sample_requests/{sample_request_file}").get("lastRequest", {}) - - if "--load_messages" in sys.argv: + if args.load_messages: messages = example_request.get("messages", []) - messages = order_messages(messages) user_input_needed = False else: messages = [] user_input_needed = True - turn_start_time = datetime.now() - tool_duration = 0 - state = example_request.get("state", {}) start_agent_name = example_request.get("startAgent", "") last_agent_name = state.get("last_agent_name", "") if not last_agent_name: last_agent_name = start_agent_name - logger.info("Starting main conversation loop") + start_time = None while True: logger.info("Loading configuration files") # To account for updates to state complete_request = copy.deepcopy(example_request) - agent_configs, tool_configs, prompt_configs, start_agent_name = extract_request_fields(complete_request) + complete_request["messages"] = messages + complete_request["state"] = state + complete_request["startAgent"] = start_agent_name print(f"\nUsing agent: {last_agent_name}") @@ -132,75 +173,35 @@ if __name__ == "__main__": "role": "user", "content": user_inp }) - turn_start_time = datetime.now() - tool_duration = 0 if user_inp == 'exit': logger.info("User requested exit") break logger.info("Added user message to conversation") + + start_time = datetime.now() - # Preprocess messages to replace role tool with role developer and add role user to empty roles - print("Preprocessing messages to replace role tool with role developer and add role user to empty roles") + # Preprocess messages + print("Preprocessing messages") messages = preprocess_messages(messages) complete_request["messages"] = preprocess_messages(complete_request["messages"]) # Run the streaming turn - resp_messages, resp_state = asyncio.run(process_turn( - messages=messages, - agent_configs=agent_configs, - tool_configs=tool_configs, - prompt_configs=prompt_configs, - start_agent_name=start_agent_name, - state=state, - config=config, - complete_request=complete_request - )) + resp_messages, resp_state = stream_chat( + host=args.host, + request_data=complete_request, + api_key=args.api_key + ) - state = resp_state - last_msg = resp_messages[-1] if resp_messages else {} - tool_calls = last_msg.get("tool_calls", []) - sender = last_msg.get("sender", "") - - if config.get("return_diff_messages", True): - messages.extend(resp_messages) - else: - messages = resp_messages + if resp_messages: + state = resp_state + if config.get("return_diff_messages", True): + messages.extend(resp_messages) + else: + messages = resp_messages - if tool_calls: - tool_start_time = datetime.now() - user_input_needed = False - - should_break = False - for tool_call in tool_calls: - tool_name = tool_call["function"]["name"] - logger.info(f"Processing tool call: {tool_name}") - - if tool_name not in external_tool_mappings: - logger.error(f"Unknown tool call: {tool_name}") - raise ValueError(f"Unknown tool call: {tool_name}") - - # Call appropriate handler and process response - tool_response = external_tool_mappings[tool_name]([tool_call], mock=True) - messages.append(tool_response) - logger.info(f"Added {tool_name} response to messages") - - current_tool_duration = round((datetime.now() - tool_start_time).total_seconds() * 10) / 10 - logger.info(f"Tool response duration: {current_tool_duration:.1f}s") - tool_duration += current_tool_duration - - if tool_name == "close_chat": - user_input_needed = False - logger.info("Closing chat") - should_break = True - - if should_break: - break - - else: - user_input_needed = True - print("Quick stats") - print(f"Turn Duration: {round((datetime.now() - turn_start_time).total_seconds() * 10) / 10:.1f}s") - print(f"Tool Response Duration: {round(tool_duration * 10) / 10:.1f}s") - print('='*50) + user_input_needed = True + print("Quick stats") + print(f"Turn Duration: {round((datetime.now() - start_time).total_seconds() * 10) / 10:.1f}s") + print('='*50) print("\n" + "-" * 80) \ No newline at end of file