rowboat/apps/rowboat_agents/tests/interactive.py
2025-03-26 15:51:59 +05:30

205 lines
No EOL
7.9 KiB
Python

import copy
from datetime import datetime
import json
import sys
import asyncio
from src.graph.core import order_messages, run_turn_streamed
from src.graph.tools import respond_to_tool_raise_error, respond_to_tool_close_chat, RAG_TOOL, CLOSE_CHAT_TOOL
from src.utils.common import common_logger, read_json_from_file
# Module-wide logger: an alias of the shared `common_logger` imported from src.utils.common.
logger = common_logger
def preprocess_messages(messages):
    """Normalise chat message dicts in place before they are sent to the agents.

    Two fix-ups are applied to every message:

    * An assistant message whose ``content`` is ``None`` but which carries a
      non-empty ``tool_calls`` list gets the placeholder content
      ``"Calling tool"``.
    * A ``"tool"`` role is rewritten to ``"developer"``, and a missing or
      empty role is set to ``"user"``.

    Args:
        messages: list of message dicts; each dict is mutated in place.

    Returns:
        The very same list object that was passed in.
    """
    for message in messages:
        role = message.get("role")
        pending_calls = message.get("tool_calls")

        is_bare_tool_call = (
            role == "assistant"
            and message.get("content") is None
            and pending_calls is not None
            and len(pending_calls) > 0
        )
        if is_bare_tool_call:
            message["content"] = "Calling tool"

        # The content fix-up above never touches the role, so `role` is still current.
        if role == "tool":
            message["role"] = "developer"
        elif not role:
            message["role"] = "user"

    return messages
async def process_turn(messages, agent_configs, tool_configs, prompt_configs, start_agent_name, state, config, complete_request):
    """Processes a single turn using streaming API.

    Prints the outgoing request (with any ``current_turn`` key stripped from
    each message, for display only), then consumes ``run_turn_streamed``.

    Args:
        messages: conversation history (list of message dicts), passed
            unmodified to ``run_turn_streamed``.
        agent_configs: agent definitions forwarded to the streaming call.
        tool_configs: tool definitions forwarded to the streaming call.
        prompt_configs: prompt definitions; only shown in the printed request.
        start_agent_name: name of the agent the turn starts with.
        state: conversation state carried over from the previous turn.
        config: runtime config dict; only ``start_turn_with_start_agent`` is
            read here (default ``False``).
        complete_request: full request dict forwarded to the streaming call.

    Returns:
        A ``(messages, state)`` tuple:
        * on a ``"done"`` event: the collected messages and the event's
          ``state`` (``{}`` if absent);
        * on an ``"error"`` event: ``([], state)`` with the input state;
        * if the stream ends without either event: the messages collected
          so far and the input state. Previously this case returned ``None``,
          which broke the caller's tuple unpacking.
    """
    print(f"\n{'*'*50}\nLatest Request:\n{'*'*50}")
    request_json = {
        "messages": [{k: v for k, v in msg.items() if k != 'current_turn'} for msg in messages],
        "state": state,
        "agents": agent_configs,
        "tools": tool_configs,
        "prompts": prompt_configs,
        "startAgent": start_agent_name
    }
    print(json.dumps(request_json, indent=2))

    collected_messages = []
    async for event_type, event_data in run_turn_streamed(
        messages=messages,
        start_agent_name=start_agent_name,
        agent_configs=agent_configs,
        tool_configs=tool_configs,
        start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
        state=state,
        additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
        complete_request=complete_request
    ):
        if event_type == "message":
            # Add each message to collected_messages
            collected_messages.append(event_data)
        elif event_type == "done":
            final_state = event_data.get('state', {})
            print(f"\n\n{'*'*50}\nLatest Response:\n{'*'*50}")
            response_json = {
                "messages": collected_messages,
                "state": final_state,
            }
            print("Turn completed. Here are the streamed messages and final state:")
            print(json.dumps(response_json, indent=2))
            print('='*50)
            return collected_messages, final_state
        elif event_type == "error":
            print(f"\nError: {event_data.get('error', 'Unknown error')}")
            return [], state

    # The stream finished without a terminal event. Return a well-formed tuple so
    # the caller's `resp_messages, resp_state = ...` unpacking cannot fail.
    print("\nWarning: stream ended without a 'done' event; keeping the previous state")
    return collected_messages, state
if __name__ == "__main__":
    logger.info(f"{'*'*50}Running interactive mode{'*'*50}")

    def extract_request_fields(complete_request):
        """Return (agents, tools, prompts, startAgent) from a request dict, with empty defaults."""
        agent_configs = complete_request.get("agents", [])
        tool_configs = complete_request.get("tools", [])
        prompt_configs = complete_request.get("prompts", [])
        start_agent_name = complete_request.get("startAgent", "")
        return agent_configs, tool_configs, prompt_configs, start_agent_name

    def cli_option(flag, default):
        """Return the argv value following `flag`, or `default` when the flag is absent.

        Exits with a readable message (instead of an IndexError) when the flag
        is the last argument and therefore has no value.
        """
        if flag not in sys.argv:
            return default
        value_index = sys.argv.index(flag) + 1
        if value_index >= len(sys.argv):
            sys.exit(f"Missing value for {flag}")
        return sys.argv[value_index]

    # Tool names the script can answer locally; any other tool call aborts the run.
    external_tool_mappings = {
        "raise_error": respond_to_tool_raise_error,
        "close_chat": respond_to_tool_close_chat,
    }

    config_file = cli_option("--config", "default_config.json")
    sample_request_file = cli_option("--sample_request", "default_example.json")
    print(f"Config file: {config_file}")
    print(f"Sample request file: {sample_request_file}")

    config = read_json_from_file(f"./configs/{config_file}")
    example_request = read_json_from_file(f"./tests/sample_requests/{sample_request_file}").get("lastRequest", {})

    # --load_messages seeds the conversation from the sample request and runs a turn
    # immediately, without first prompting the user.
    if "--load_messages" in sys.argv:
        messages = order_messages(example_request.get("messages", []))
        user_input_needed = False
    else:
        messages = []
        user_input_needed = True

    turn_start_time = datetime.now()
    tool_duration = 0

    state = example_request.get("state", {})
    start_agent_name = example_request.get("startAgent", "")
    last_agent_name = state.get("last_agent_name", "") or start_agent_name

    logger.info("Starting main conversation loop")
    while True:
        logger.info("Loading configuration files")
        # Take a fresh deep copy each turn, so that the in-place preprocessing below
        # never mutates example_request.
        complete_request = copy.deepcopy(example_request)
        agent_configs, tool_configs, prompt_configs, start_agent_name = extract_request_fields(complete_request)

        print(f"\nUsing agent: {last_agent_name}")

        if user_input_needed:
            user_inp = input('\nUSER: ')
            if user_inp == 'exit':
                logger.info("User requested exit")
                break
            messages.append({
                "role": "user",
                "content": user_inp
            })
            # A new user message starts a new turn: reset the turn-level timers.
            turn_start_time = datetime.now()
            tool_duration = 0
            logger.info("Added user message to conversation")

        # Preprocess messages to replace role tool with role developer and add role user to empty roles
        print("Preprocessing messages to replace role tool with role developer and add role user to empty roles")
        messages = preprocess_messages(messages)
        # .get(): a sample request without "messages" used to raise KeyError here.
        complete_request["messages"] = preprocess_messages(complete_request.get("messages", []))

        # Run the streaming turn
        resp_messages, resp_state = asyncio.run(process_turn(
            messages=messages,
            agent_configs=agent_configs,
            tool_configs=tool_configs,
            prompt_configs=prompt_configs,
            start_agent_name=start_agent_name,
            state=state,
            config=config,
            complete_request=complete_request
        ))

        state = resp_state
        # Keep the "Using agent" banner in sync with the state returned by the turn.
        # Fall back to the previous name when the new state does not carry one.
        last_agent_name = (state or {}).get("last_agent_name") or last_agent_name

        last_msg = resp_messages[-1] if resp_messages else {}
        tool_calls = last_msg.get("tool_calls", [])

        if config.get("return_diff_messages", True):
            messages.extend(resp_messages)
        else:
            messages = resp_messages

        if tool_calls:
            # The agent asked for tools: answer them locally, then run another
            # round without prompting the user.
            user_input_needed = False
            should_break = False
            for tool_call in tool_calls:
                tool_name = tool_call["function"]["name"]
                logger.info(f"Processing tool call: {tool_name}")
                if tool_name not in external_tool_mappings:
                    logger.error(f"Unknown tool call: {tool_name}")
                    raise ValueError(f"Unknown tool call: {tool_name}")

                # Time each call from its own start. A single shared start time made
                # later calls include earlier ones, so the sum double-counted.
                tool_start_time = datetime.now()
                # Call appropriate handler and process response
                tool_response = external_tool_mappings[tool_name]([tool_call], mock=True)
                messages.append(tool_response)
                logger.info(f"Added {tool_name} response to messages")

                current_tool_duration = round((datetime.now() - tool_start_time).total_seconds() * 10) / 10
                logger.info(f"Tool response duration: {current_tool_duration:.1f}s")
                tool_duration += current_tool_duration

                if tool_name == "close_chat":
                    logger.info("Closing chat")
                    should_break = True

            if should_break:
                break
        else:
            # No tool calls: the turn is complete, so hand control back to the user.
            user_input_needed = True
            print("Quick stats")
            print(f"Turn Duration: {round((datetime.now() - turn_start_time).total_seconds() * 10) / 10:.1f}s")
            print(f"Tool Response Duration: {round(tool_duration * 10) / 10:.1f}s")
            print('='*50)

        print("\n" + "-" * 80)