Clean up tools.py and interactive.py

This commit is contained in:
akhisud3195 2025-05-09 10:17:36 +05:30
parent 968dfacd65
commit 770a080232
3 changed files with 135 additions and 170 deletions

View file

@ -3,9 +3,9 @@ from datetime import datetime
import json
import sys
import asyncio
import requests
import argparse
from src.graph.core import order_messages, run_turn_streamed
from src.graph.tools import respond_to_tool_raise_error, respond_to_tool_close_chat, RAG_TOOL, CLOSE_CHAT_TOOL
from src.utils.common import common_logger, read_json_from_file
logger = common_logger
@ -26,103 +26,144 @@ def preprocess_messages(messages):
msg["role"] = "user"
return messages
async def process_turn(messages, agent_configs, tool_configs, prompt_configs, start_agent_name, state, config, complete_request):
"""Processes a single turn using streaming API"""
print(f"\n{'*'*50}\nLatest Request:\n{'*'*50}")
request_json = {
"messages": [{k: v for k, v in msg.items() if k != 'current_turn'} for msg in messages],
"state": state,
"agents": agent_configs,
"tools": tool_configs,
"prompts": prompt_configs,
"startAgent": start_agent_name
}
print(json.dumps(request_json, indent=2))
collected_messages = []
def stream_chat(host, request_data, api_key):
start_time = datetime.now()
print("\n" + "="*80)
print(f"Starting streaming chat at {start_time}")
print(f"Host: {host}")
print("="*80 + "\n")
try:
print("\n" + "-"*80)
print("Connecting to stream...")
stream_response = requests.post(
f"{host}/chat_stream",
json=request_data,
headers={
'Authorization': f'Bearer {api_key}',
'Accept': 'text/event-stream'
},
stream=True
)
if stream_response.status_code != 200:
print(f"Error connecting to stream. Status code: {stream_response.status_code}")
print(f"Response: {stream_response.text}")
return None, None
print(f"Successfully connected to stream")
print("-"*80 + "\n")
event_count = 0
collected_messages = []
final_state = None
try:
print("\n" + "-"*80)
print("Starting to process events...")
print("-"*80 + "\n")
for line in stream_response.iter_lines(decode_unicode=True):
if line:
if line.startswith('data: '):
data = line[6:] # Remove 'data: ' prefix
try:
event_data = json.loads(data)
event_count += 1
print("\n" + "*"*80)
print(f"Event #{event_count} at {datetime.now().isoformat()}")
if isinstance(event_data, dict):
# Pretty print the event data
print("Event Data:")
print(json.dumps(event_data, indent=2))
# Special handling for message events
if 'content' in event_data:
print("\nMessage Content:", event_data['content'])
if event_data.get('tool_calls'):
print("Tool Calls:", json.dumps(event_data['tool_calls'], indent=2))
# Collect messages
collected_messages.append(event_data)
else:
print("Event Data:", event_data)
print("*"*80 + "\n")
except json.JSONDecodeError as e:
print(f"Error decoding event data: {e}")
print(f"Raw data: {data}")
except Exception as e:
print(f"Error processing stream: {e}")
import traceback
traceback.print_exc()
finally:
print("\n" + "-"*80)
print(f"Closing stream after processing {event_count} events")
print("-"*80 + "\n")
stream_response.close()
except requests.exceptions.RequestException as e:
print(f"Request error during streaming: {e}")
import traceback
traceback.print_exc()
end_time = datetime.now()
duration = end_time - start_time
print("\n" + "="*80)
print(f"Streaming session completed at {end_time}")
print(f"Total duration: {duration}")
print("="*80 + "\n")
async for event_type, event_data in run_turn_streamed(
messages=messages,
start_agent_name=start_agent_name,
agent_configs=agent_configs,
tool_configs=tool_configs,
prompt_configs=prompt_configs,
start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
state=state,
additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
complete_request=complete_request
):
if event_type == "message":
# Add each message to collected_messages
collected_messages.append(event_data)
elif event_type == "done":
print(f"\n\n{'*'*50}\nLatest Response:\n{'*'*50}")
response_json = {
"messages": collected_messages,
"state": event_data.get('state', {}),
}
print("Turn completed. Here are the streamed messages and final state:")
print(json.dumps(response_json, indent=2))
print('='*50)
return collected_messages, event_data.get('state', {})
elif event_type == "error":
print(f"\nError: {event_data.get('error', 'Unknown error')}")
return [], state
return collected_messages, final_state
if __name__ == "__main__":
logger.info(f"{'*'*50}Running interactive mode{'*'*50}")
def extract_request_fields(complete_request):
agent_configs = complete_request.get("agents", [])
tool_configs = complete_request.get("tools", [])
prompt_configs = complete_request.get("prompts", [])
start_agent_name = complete_request.get("startAgent", "")
return agent_configs, tool_configs, prompt_configs, start_agent_name
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=False, default='default_config.json',
help='Config file name under configs/')
parser.add_argument('--sample_request', type=str, required=False, default='default_example.json',
help='Sample request JSON file name under tests/sample_requests/')
parser.add_argument('--api_key', type=str, required=False, default='test',
help='API key to use for authentication')
parser.add_argument('--host', type=str, default='http://localhost:4040',
help='Host to use for the request')
parser.add_argument('--load_messages', action='store_true',
help='Load messages from sample request file')
args = parser.parse_args()
print(f"Config file: {args.config}")
print(f"Sample request file: {args.sample_request}")
external_tool_mappings = {
"raise_error": respond_to_tool_raise_error,
"close_chat": respond_to_tool_close_chat
}
config = read_json_from_file(f"./configs/{args.config}")
example_request = read_json_from_file(f"./tests/sample_requests/{args.sample_request}").get("lastRequest", {})
config_file = sys.argv[sys.argv.index("--config") + 1] if "--config" in sys.argv else "default_config.json"
sample_request_file = sys.argv[sys.argv.index("--sample_request") + 1] if "--sample_request" in sys.argv else "default_example.json"
print(f"Config file: {config_file}")
print(f"Sample request file: {sample_request_file}")
config = read_json_from_file(f"./configs/{config_file}")
example_request = read_json_from_file(f"./tests/sample_requests/{sample_request_file}").get("lastRequest", {})
if "--load_messages" in sys.argv:
if args.load_messages:
messages = example_request.get("messages", [])
messages = order_messages(messages)
user_input_needed = False
else:
messages = []
user_input_needed = True
turn_start_time = datetime.now()
tool_duration = 0
state = example_request.get("state", {})
start_agent_name = example_request.get("startAgent", "")
last_agent_name = state.get("last_agent_name", "")
if not last_agent_name:
last_agent_name = start_agent_name
logger.info("Starting main conversation loop")
start_time = None
while True:
logger.info("Loading configuration files")
# To account for updates to state
complete_request = copy.deepcopy(example_request)
agent_configs, tool_configs, prompt_configs, start_agent_name = extract_request_fields(complete_request)
complete_request["messages"] = messages
complete_request["state"] = state
complete_request["startAgent"] = start_agent_name
print(f"\nUsing agent: {last_agent_name}")
@ -132,75 +173,35 @@ if __name__ == "__main__":
"role": "user",
"content": user_inp
})
turn_start_time = datetime.now()
tool_duration = 0
if user_inp == 'exit':
logger.info("User requested exit")
break
logger.info("Added user message to conversation")
start_time = datetime.now()
# Preprocess messages to replace role tool with role developer and add role user to empty roles
print("Preprocessing messages to replace role tool with role developer and add role user to empty roles")
# Preprocess messages
print("Preprocessing messages")
messages = preprocess_messages(messages)
complete_request["messages"] = preprocess_messages(complete_request["messages"])
# Run the streaming turn
resp_messages, resp_state = asyncio.run(process_turn(
messages=messages,
agent_configs=agent_configs,
tool_configs=tool_configs,
prompt_configs=prompt_configs,
start_agent_name=start_agent_name,
state=state,
config=config,
complete_request=complete_request
))
resp_messages, resp_state = stream_chat(
host=args.host,
request_data=complete_request,
api_key=args.api_key
)
state = resp_state
last_msg = resp_messages[-1] if resp_messages else {}
tool_calls = last_msg.get("tool_calls", [])
sender = last_msg.get("sender", "")
if config.get("return_diff_messages", True):
messages.extend(resp_messages)
else:
messages = resp_messages
if resp_messages:
state = resp_state
if config.get("return_diff_messages", True):
messages.extend(resp_messages)
else:
messages = resp_messages
if tool_calls:
tool_start_time = datetime.now()
user_input_needed = False
should_break = False
for tool_call in tool_calls:
tool_name = tool_call["function"]["name"]
logger.info(f"Processing tool call: {tool_name}")
if tool_name not in external_tool_mappings:
logger.error(f"Unknown tool call: {tool_name}")
raise ValueError(f"Unknown tool call: {tool_name}")
# Call appropriate handler and process response
tool_response = external_tool_mappings[tool_name]([tool_call], mock=True)
messages.append(tool_response)
logger.info(f"Added {tool_name} response to messages")
current_tool_duration = round((datetime.now() - tool_start_time).total_seconds() * 10) / 10
logger.info(f"Tool response duration: {current_tool_duration:.1f}s")
tool_duration += current_tool_duration
if tool_name == "close_chat":
user_input_needed = False
logger.info("Closing chat")
should_break = True
if should_break:
break
else:
user_input_needed = True
print("Quick stats")
print(f"Turn Duration: {round((datetime.now() - turn_start_time).total_seconds() * 10) / 10:.1f}s")
print(f"Tool Response Duration: {round(tool_duration * 10) / 10:.1f}s")
print('='*50)
user_input_needed = True
print("Quick stats")
print(f"Turn Duration: {round((datetime.now() - start_time).total_seconds() * 10) / 10:.1f}s")
print('='*50)
print("\n" + "-" * 80)