mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-04-26 17:06:23 +02:00
207 lines
No EOL
7.6 KiB
Python
207 lines
No EOL
7.6 KiB
Python
import copy
|
|
from datetime import datetime
|
|
import json
|
|
import sys
|
|
import asyncio
|
|
import requests
|
|
import argparse
|
|
|
|
from src.utils.common import common_logger, read_json_from_file
|
|
logger = common_logger
|
|
|
|
def preprocess_messages(messages):
|
|
# Preprocess messages to handle null content and role issues
|
|
for msg in messages:
|
|
# Handle null content in assistant messages with tool calls
|
|
if (msg.get("role") == "assistant" and
|
|
msg.get("content") is None and
|
|
msg.get("tool_calls") is not None and
|
|
len(msg.get("tool_calls")) > 0):
|
|
msg["content"] = "Calling tool"
|
|
|
|
# Handle role issues
|
|
if msg.get("role") == "tool":
|
|
msg["role"] = "developer"
|
|
elif not msg.get("role"):
|
|
msg["role"] = "user"
|
|
|
|
return messages
|
|
|
|
def stream_chat(host, request_data, api_key):
|
|
start_time = datetime.now()
|
|
print("\n" + "="*80)
|
|
print(f"Starting streaming chat at {start_time}")
|
|
print(f"Host: {host}")
|
|
print("="*80 + "\n")
|
|
|
|
try:
|
|
print("\n" + "-"*80)
|
|
print("Connecting to stream...")
|
|
stream_response = requests.post(
|
|
f"{host}/chat_stream",
|
|
json=request_data,
|
|
headers={
|
|
'Authorization': f'Bearer {api_key}',
|
|
'Accept': 'text/event-stream'
|
|
},
|
|
stream=True
|
|
)
|
|
|
|
if stream_response.status_code != 200:
|
|
print(f"Error connecting to stream. Status code: {stream_response.status_code}")
|
|
print(f"Response: {stream_response.text}")
|
|
return None, None
|
|
|
|
print(f"Successfully connected to stream")
|
|
print("-"*80 + "\n")
|
|
|
|
event_count = 0
|
|
collected_messages = []
|
|
final_state = None
|
|
|
|
try:
|
|
print("\n" + "-"*80)
|
|
print("Starting to process events...")
|
|
print("-"*80 + "\n")
|
|
|
|
for line in stream_response.iter_lines(decode_unicode=True):
|
|
if line:
|
|
if line.startswith('data: '):
|
|
data = line[6:] # Remove 'data: ' prefix
|
|
try:
|
|
event_data = json.loads(data)
|
|
event_count += 1
|
|
print("\n" + "*"*80)
|
|
print(f"Event #{event_count} at {datetime.now().isoformat()}")
|
|
|
|
if isinstance(event_data, dict):
|
|
# Pretty print the event data
|
|
print("Event Data:")
|
|
print(json.dumps(event_data, indent=2))
|
|
|
|
# Special handling for message events
|
|
if 'content' in event_data:
|
|
print("\nMessage Content:", event_data['content'])
|
|
if event_data.get('tool_calls'):
|
|
print("Tool Calls:", json.dumps(event_data['tool_calls'], indent=2))
|
|
|
|
# Collect messages
|
|
collected_messages.append(event_data)
|
|
else:
|
|
print("Event Data:", event_data)
|
|
print("*"*80 + "\n")
|
|
|
|
except json.JSONDecodeError as e:
|
|
print(f"Error decoding event data: {e}")
|
|
print(f"Raw data: {data}")
|
|
|
|
except Exception as e:
|
|
print(f"Error processing stream: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
finally:
|
|
print("\n" + "-"*80)
|
|
print(f"Closing stream after processing {event_count} events")
|
|
print("-"*80 + "\n")
|
|
stream_response.close()
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
print(f"Request error during streaming: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
|
|
end_time = datetime.now()
|
|
duration = end_time - start_time
|
|
print("\n" + "="*80)
|
|
print(f"Streaming session completed at {end_time}")
|
|
print(f"Total duration: {duration}")
|
|
print("="*80 + "\n")
|
|
|
|
return collected_messages, final_state
|
|
|
|
if __name__ == "__main__":
|
|
logger.info(f"{'*'*50}Running interactive mode{'*'*50}")
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--config', type=str, required=False, default='default_config.json',
|
|
help='Config file name under configs/')
|
|
parser.add_argument('--sample_request', type=str, required=False, default='default_example.json',
|
|
help='Sample request JSON file name under tests/sample_requests/')
|
|
parser.add_argument('--api_key', type=str, required=False, default='test',
|
|
help='API key to use for authentication')
|
|
parser.add_argument('--host', type=str, default='http://localhost:4040',
|
|
help='Host to use for the request')
|
|
parser.add_argument('--load_messages', action='store_true',
|
|
help='Load messages from sample request file')
|
|
args = parser.parse_args()
|
|
|
|
print(f"Config file: {args.config}")
|
|
print(f"Sample request file: {args.sample_request}")
|
|
|
|
config = read_json_from_file(f"./configs/{args.config}")
|
|
example_request = read_json_from_file(f"./tests/sample_requests/{args.sample_request}").get("lastRequest", {})
|
|
|
|
if args.load_messages:
|
|
messages = example_request.get("messages", [])
|
|
user_input_needed = False
|
|
else:
|
|
messages = []
|
|
user_input_needed = True
|
|
|
|
state = example_request.get("state", {})
|
|
start_agent_name = example_request.get("startAgent", "")
|
|
last_agent_name = state.get("last_agent_name", "")
|
|
if not last_agent_name:
|
|
last_agent_name = start_agent_name
|
|
|
|
logger.info("Starting main conversation loop")
|
|
start_time = None
|
|
while True:
|
|
logger.info("Loading configuration files")
|
|
|
|
# To account for updates to state
|
|
complete_request = copy.deepcopy(example_request)
|
|
complete_request["messages"] = messages
|
|
complete_request["state"] = state
|
|
complete_request["startAgent"] = start_agent_name
|
|
|
|
print(f"\nUsing agent: {last_agent_name}")
|
|
|
|
if user_input_needed:
|
|
user_inp = input('\nUSER: ')
|
|
messages.append({
|
|
"role": "user",
|
|
"content": user_inp
|
|
})
|
|
if user_inp == 'exit':
|
|
logger.info("User requested exit")
|
|
break
|
|
logger.info("Added user message to conversation")
|
|
|
|
start_time = datetime.now()
|
|
|
|
# Preprocess messages
|
|
print("Preprocessing messages")
|
|
messages = preprocess_messages(messages)
|
|
complete_request["messages"] = preprocess_messages(complete_request["messages"])
|
|
|
|
# Run the streaming turn
|
|
resp_messages, resp_state = stream_chat(
|
|
host=args.host,
|
|
request_data=complete_request,
|
|
api_key=args.api_key
|
|
)
|
|
|
|
if resp_messages:
|
|
state = resp_state
|
|
if config.get("return_diff_messages", True):
|
|
messages.extend(resp_messages)
|
|
else:
|
|
messages = resp_messages
|
|
|
|
user_input_needed = True
|
|
print("Quick stats")
|
|
print(f"Turn Duration: {round((datetime.now() - start_time).total_seconds() * 10) / 10:.1f}s")
|
|
print('='*50)
|
|
|
|
print("\n" + "-" * 80) |