mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-26 08:56:22 +02:00
Clean up tools.py and interactive.py
This commit is contained in:
parent
968dfacd65
commit
770a080232
3 changed files with 135 additions and 170 deletions
|
|
@ -3,9 +3,9 @@ from datetime import datetime
|
|||
import json
|
||||
import sys
|
||||
import asyncio
|
||||
import requests
|
||||
import argparse
|
||||
|
||||
from src.graph.core import order_messages, run_turn_streamed
|
||||
from src.graph.tools import respond_to_tool_raise_error, respond_to_tool_close_chat, RAG_TOOL, CLOSE_CHAT_TOOL
|
||||
from src.utils.common import common_logger, read_json_from_file
|
||||
logger = common_logger
|
||||
|
||||
|
|
@ -26,103 +26,144 @@ def preprocess_messages(messages):
|
|||
msg["role"] = "user"
|
||||
|
||||
return messages
|
||||
|
||||
async def process_turn(messages, agent_configs, tool_configs, prompt_configs, start_agent_name, state, config, complete_request):
|
||||
"""Processes a single turn using streaming API"""
|
||||
print(f"\n{'*'*50}\nLatest Request:\n{'*'*50}")
|
||||
request_json = {
|
||||
"messages": [{k: v for k, v in msg.items() if k != 'current_turn'} for msg in messages],
|
||||
"state": state,
|
||||
"agents": agent_configs,
|
||||
"tools": tool_configs,
|
||||
"prompts": prompt_configs,
|
||||
"startAgent": start_agent_name
|
||||
}
|
||||
print(json.dumps(request_json, indent=2))
|
||||
|
||||
collected_messages = []
|
||||
def stream_chat(host, request_data, api_key):
|
||||
start_time = datetime.now()
|
||||
print("\n" + "="*80)
|
||||
print(f"Starting streaming chat at {start_time}")
|
||||
print(f"Host: {host}")
|
||||
print("="*80 + "\n")
|
||||
|
||||
try:
|
||||
print("\n" + "-"*80)
|
||||
print("Connecting to stream...")
|
||||
stream_response = requests.post(
|
||||
f"{host}/chat_stream",
|
||||
json=request_data,
|
||||
headers={
|
||||
'Authorization': f'Bearer {api_key}',
|
||||
'Accept': 'text/event-stream'
|
||||
},
|
||||
stream=True
|
||||
)
|
||||
|
||||
if stream_response.status_code != 200:
|
||||
print(f"Error connecting to stream. Status code: {stream_response.status_code}")
|
||||
print(f"Response: {stream_response.text}")
|
||||
return None, None
|
||||
|
||||
print(f"Successfully connected to stream")
|
||||
print("-"*80 + "\n")
|
||||
|
||||
event_count = 0
|
||||
collected_messages = []
|
||||
final_state = None
|
||||
|
||||
try:
|
||||
print("\n" + "-"*80)
|
||||
print("Starting to process events...")
|
||||
print("-"*80 + "\n")
|
||||
|
||||
for line in stream_response.iter_lines(decode_unicode=True):
|
||||
if line:
|
||||
if line.startswith('data: '):
|
||||
data = line[6:] # Remove 'data: ' prefix
|
||||
try:
|
||||
event_data = json.loads(data)
|
||||
event_count += 1
|
||||
print("\n" + "*"*80)
|
||||
print(f"Event #{event_count} at {datetime.now().isoformat()}")
|
||||
|
||||
if isinstance(event_data, dict):
|
||||
# Pretty print the event data
|
||||
print("Event Data:")
|
||||
print(json.dumps(event_data, indent=2))
|
||||
|
||||
# Special handling for message events
|
||||
if 'content' in event_data:
|
||||
print("\nMessage Content:", event_data['content'])
|
||||
if event_data.get('tool_calls'):
|
||||
print("Tool Calls:", json.dumps(event_data['tool_calls'], indent=2))
|
||||
|
||||
# Collect messages
|
||||
collected_messages.append(event_data)
|
||||
else:
|
||||
print("Event Data:", event_data)
|
||||
print("*"*80 + "\n")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error decoding event data: {e}")
|
||||
print(f"Raw data: {data}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing stream: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
print("\n" + "-"*80)
|
||||
print(f"Closing stream after processing {event_count} events")
|
||||
print("-"*80 + "\n")
|
||||
stream_response.close()
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"Request error during streaming: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
end_time = datetime.now()
|
||||
duration = end_time - start_time
|
||||
print("\n" + "="*80)
|
||||
print(f"Streaming session completed at {end_time}")
|
||||
print(f"Total duration: {duration}")
|
||||
print("="*80 + "\n")
|
||||
|
||||
async for event_type, event_data in run_turn_streamed(
|
||||
messages=messages,
|
||||
start_agent_name=start_agent_name,
|
||||
agent_configs=agent_configs,
|
||||
tool_configs=tool_configs,
|
||||
prompt_configs=prompt_configs,
|
||||
start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
|
||||
state=state,
|
||||
additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
|
||||
complete_request=complete_request
|
||||
):
|
||||
if event_type == "message":
|
||||
# Add each message to collected_messages
|
||||
collected_messages.append(event_data)
|
||||
|
||||
elif event_type == "done":
|
||||
print(f"\n\n{'*'*50}\nLatest Response:\n{'*'*50}")
|
||||
response_json = {
|
||||
"messages": collected_messages,
|
||||
"state": event_data.get('state', {}),
|
||||
}
|
||||
print("Turn completed. Here are the streamed messages and final state:")
|
||||
print(json.dumps(response_json, indent=2))
|
||||
print('='*50)
|
||||
|
||||
return collected_messages, event_data.get('state', {})
|
||||
|
||||
elif event_type == "error":
|
||||
print(f"\nError: {event_data.get('error', 'Unknown error')}")
|
||||
return [], state
|
||||
return collected_messages, final_state
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info(f"{'*'*50}Running interactive mode{'*'*50}")
|
||||
|
||||
def extract_request_fields(complete_request):
|
||||
agent_configs = complete_request.get("agents", [])
|
||||
tool_configs = complete_request.get("tools", [])
|
||||
prompt_configs = complete_request.get("prompts", [])
|
||||
start_agent_name = complete_request.get("startAgent", "")
|
||||
|
||||
return agent_configs, tool_configs, prompt_configs, start_agent_name
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, required=False, default='default_config.json',
|
||||
help='Config file name under configs/')
|
||||
parser.add_argument('--sample_request', type=str, required=False, default='default_example.json',
|
||||
help='Sample request JSON file name under tests/sample_requests/')
|
||||
parser.add_argument('--api_key', type=str, required=False, default='test',
|
||||
help='API key to use for authentication')
|
||||
parser.add_argument('--host', type=str, default='http://localhost:4040',
|
||||
help='Host to use for the request')
|
||||
parser.add_argument('--load_messages', action='store_true',
|
||||
help='Load messages from sample request file')
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Config file: {args.config}")
|
||||
print(f"Sample request file: {args.sample_request}")
|
||||
|
||||
external_tool_mappings = {
|
||||
"raise_error": respond_to_tool_raise_error,
|
||||
"close_chat": respond_to_tool_close_chat
|
||||
}
|
||||
config = read_json_from_file(f"./configs/{args.config}")
|
||||
example_request = read_json_from_file(f"./tests/sample_requests/{args.sample_request}").get("lastRequest", {})
|
||||
|
||||
config_file = sys.argv[sys.argv.index("--config") + 1] if "--config" in sys.argv else "default_config.json"
|
||||
sample_request_file = sys.argv[sys.argv.index("--sample_request") + 1] if "--sample_request" in sys.argv else "default_example.json"
|
||||
|
||||
print(f"Config file: {config_file}")
|
||||
print(f"Sample request file: {sample_request_file}")
|
||||
|
||||
config = read_json_from_file(f"./configs/{config_file}")
|
||||
example_request = read_json_from_file(f"./tests/sample_requests/{sample_request_file}").get("lastRequest", {})
|
||||
|
||||
if "--load_messages" in sys.argv:
|
||||
if args.load_messages:
|
||||
messages = example_request.get("messages", [])
|
||||
messages = order_messages(messages)
|
||||
user_input_needed = False
|
||||
else:
|
||||
messages = []
|
||||
user_input_needed = True
|
||||
|
||||
turn_start_time = datetime.now()
|
||||
tool_duration = 0
|
||||
|
||||
state = example_request.get("state", {})
|
||||
start_agent_name = example_request.get("startAgent", "")
|
||||
last_agent_name = state.get("last_agent_name", "")
|
||||
if not last_agent_name:
|
||||
last_agent_name = start_agent_name
|
||||
|
||||
|
||||
logger.info("Starting main conversation loop")
|
||||
start_time = None
|
||||
while True:
|
||||
logger.info("Loading configuration files")
|
||||
|
||||
# To account for updates to state
|
||||
complete_request = copy.deepcopy(example_request)
|
||||
agent_configs, tool_configs, prompt_configs, start_agent_name = extract_request_fields(complete_request)
|
||||
complete_request["messages"] = messages
|
||||
complete_request["state"] = state
|
||||
complete_request["startAgent"] = start_agent_name
|
||||
|
||||
print(f"\nUsing agent: {last_agent_name}")
|
||||
|
||||
|
|
@ -132,75 +173,35 @@ if __name__ == "__main__":
|
|||
"role": "user",
|
||||
"content": user_inp
|
||||
})
|
||||
turn_start_time = datetime.now()
|
||||
tool_duration = 0
|
||||
if user_inp == 'exit':
|
||||
logger.info("User requested exit")
|
||||
break
|
||||
logger.info("Added user message to conversation")
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
# Preprocess messages to replace role tool with role developer and add role user to empty roles
|
||||
print("Preprocessing messages to replace role tool with role developer and add role user to empty roles")
|
||||
# Preprocess messages
|
||||
print("Preprocessing messages")
|
||||
messages = preprocess_messages(messages)
|
||||
complete_request["messages"] = preprocess_messages(complete_request["messages"])
|
||||
|
||||
# Run the streaming turn
|
||||
resp_messages, resp_state = asyncio.run(process_turn(
|
||||
messages=messages,
|
||||
agent_configs=agent_configs,
|
||||
tool_configs=tool_configs,
|
||||
prompt_configs=prompt_configs,
|
||||
start_agent_name=start_agent_name,
|
||||
state=state,
|
||||
config=config,
|
||||
complete_request=complete_request
|
||||
))
|
||||
resp_messages, resp_state = stream_chat(
|
||||
host=args.host,
|
||||
request_data=complete_request,
|
||||
api_key=args.api_key
|
||||
)
|
||||
|
||||
state = resp_state
|
||||
last_msg = resp_messages[-1] if resp_messages else {}
|
||||
tool_calls = last_msg.get("tool_calls", [])
|
||||
sender = last_msg.get("sender", "")
|
||||
|
||||
if config.get("return_diff_messages", True):
|
||||
messages.extend(resp_messages)
|
||||
else:
|
||||
messages = resp_messages
|
||||
if resp_messages:
|
||||
state = resp_state
|
||||
if config.get("return_diff_messages", True):
|
||||
messages.extend(resp_messages)
|
||||
else:
|
||||
messages = resp_messages
|
||||
|
||||
if tool_calls:
|
||||
tool_start_time = datetime.now()
|
||||
user_input_needed = False
|
||||
|
||||
should_break = False
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call["function"]["name"]
|
||||
logger.info(f"Processing tool call: {tool_name}")
|
||||
|
||||
if tool_name not in external_tool_mappings:
|
||||
logger.error(f"Unknown tool call: {tool_name}")
|
||||
raise ValueError(f"Unknown tool call: {tool_name}")
|
||||
|
||||
# Call appropriate handler and process response
|
||||
tool_response = external_tool_mappings[tool_name]([tool_call], mock=True)
|
||||
messages.append(tool_response)
|
||||
logger.info(f"Added {tool_name} response to messages")
|
||||
|
||||
current_tool_duration = round((datetime.now() - tool_start_time).total_seconds() * 10) / 10
|
||||
logger.info(f"Tool response duration: {current_tool_duration:.1f}s")
|
||||
tool_duration += current_tool_duration
|
||||
|
||||
if tool_name == "close_chat":
|
||||
user_input_needed = False
|
||||
logger.info("Closing chat")
|
||||
should_break = True
|
||||
|
||||
if should_break:
|
||||
break
|
||||
|
||||
else:
|
||||
user_input_needed = True
|
||||
print("Quick stats")
|
||||
print(f"Turn Duration: {round((datetime.now() - turn_start_time).total_seconds() * 10) / 10:.1f}s")
|
||||
print(f"Tool Response Duration: {round(tool_duration * 10) / 10:.1f}s")
|
||||
print('='*50)
|
||||
user_input_needed = True
|
||||
print("Quick stats")
|
||||
print(f"Turn Duration: {round((datetime.now() - start_time).total_seconds() * 10) / 10:.1f}s")
|
||||
print('='*50)
|
||||
|
||||
print("\n" + "-" * 80)
|
||||
Loading…
Add table
Add a link
Reference in a new issue