mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-05-01 03:16:29 +02:00
Merge changes v1
This commit is contained in:
parent
b2fd9bf877
commit
24efe0e887
45 changed files with 2940 additions and 294 deletions
0
apps/rowboat_agents/src/__init__.py
Normal file
0
apps/rowboat_agents/src/__init__.py
Normal file
0
apps/rowboat_agents/src/app/__init__.py
Normal file
0
apps/rowboat_agents/src/app/__init__.py
Normal file
200
apps/rowboat_agents/src/app/main.py
Normal file
200
apps/rowboat_agents/src/app/main.py
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
from flask import Flask, request, jsonify, Response
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
import os
|
||||
import redis
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.asyncio import serve
|
||||
|
||||
from src.graph.core import run_turn, run_turn_streamed
|
||||
from src.graph.tools import RAG_TOOL, CLOSE_CHAT_TOOL
|
||||
from src.utils.common import common_logger, read_json_from_file
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
logger = common_logger
|
||||
redis_client = redis.from_url(os.environ.get('REDIS_URL', 'redis://localhost:6379'))
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/health", methods=["GET"])
|
||||
def health():
|
||||
return jsonify({"status": "ok"})
|
||||
|
||||
@app.route("/")
|
||||
def home():
|
||||
return "Hello, World!"
|
||||
|
||||
def require_api_key(f):
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
auth_header = request.headers.get('Authorization')
|
||||
if not auth_header or not auth_header.startswith('Bearer '):
|
||||
return jsonify({'error': 'Missing or invalid authorization header'}), 401
|
||||
|
||||
token = auth_header.split('Bearer ')[1]
|
||||
actual = os.environ.get('API_KEY', '').strip()
|
||||
if actual and token != actual:
|
||||
return jsonify({'error': 'Invalid API key'}), 403
|
||||
|
||||
return f(*args, **kwargs)
|
||||
return decorated
|
||||
|
||||
@app.route("/chat", methods=["POST"])
|
||||
@require_api_key
|
||||
def chat():
|
||||
logger.info('='*100)
|
||||
logger.info(f"{'*'*100}Running server mode{'*'*100}")
|
||||
try:
|
||||
data = request.get_json()
|
||||
logger.info('Complete request:')
|
||||
logger.info(data)
|
||||
logger.info('-'*100)
|
||||
|
||||
start_time = datetime.now()
|
||||
config = read_json_from_file("./configs/default_config.json")
|
||||
|
||||
logger.info('Beginning turn')
|
||||
resp_messages, resp_tokens_used, resp_state = run_turn(
|
||||
messages=data.get("messages", []),
|
||||
start_agent_name=data.get("startAgent", ""),
|
||||
agent_configs=data.get("agents", []),
|
||||
tool_configs=data.get("tools", []),
|
||||
start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
|
||||
state=data.get("state", {}),
|
||||
additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
|
||||
complete_request=data
|
||||
)
|
||||
|
||||
logger.info('-'*100)
|
||||
logger.info('Raw output:')
|
||||
logger.info((resp_messages, resp_tokens_used, resp_state))
|
||||
|
||||
out = {
|
||||
"messages": resp_messages,
|
||||
"tokens_used": resp_tokens_used,
|
||||
"state": resp_state,
|
||||
}
|
||||
|
||||
logger.info("Output:")
|
||||
for k, v in out.items():
|
||||
logger.info(f"{k}: {v}")
|
||||
logger.info('*'*100)
|
||||
|
||||
logger.info('='*100)
|
||||
logger.info(f"Processing time: {datetime.now() - start_time}")
|
||||
|
||||
return jsonify(out)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
@app.route("/chat_stream_init", methods=["POST"])
|
||||
@require_api_key
|
||||
def chat_stream_init():
|
||||
# create a uuid for the stream
|
||||
stream_id = str(uuid.uuid4())
|
||||
|
||||
# store the request data in redis with 10 minute TTL
|
||||
data = request.get_json()
|
||||
redis_client.setex(f"stream_request_{stream_id}", 600, json.dumps(data))
|
||||
|
||||
return jsonify({"stream_id": stream_id})
|
||||
|
||||
@app.route("/chat_stream/<stream_id>", methods=["GET"])
|
||||
@require_api_key
|
||||
def chat_stream(stream_id):
|
||||
# get the request data from redis
|
||||
request_data = redis_client.get(f"stream_request_{stream_id}")
|
||||
if not request_data:
|
||||
return jsonify({"error": "Stream not found"}), 404
|
||||
|
||||
request_data = json.loads(request_data)
|
||||
config = read_json_from_file("./configs/default_config.json")
|
||||
|
||||
# Preprocess messages to handle null content and role issues
|
||||
for msg in request_data["messages"]:
|
||||
# Handle null content in assistant messages with tool calls
|
||||
if (msg.get("role") == "assistant" and
|
||||
msg.get("content") is None and
|
||||
msg.get("tool_calls") is not None and
|
||||
len(msg.get("tool_calls")) > 0):
|
||||
msg["content"] = "Calling tool"
|
||||
|
||||
# Handle role issues
|
||||
if msg.get("role") == "tool":
|
||||
msg["role"] = "developer"
|
||||
elif not msg.get("role"):
|
||||
msg["role"] = "user"
|
||||
|
||||
print('*'*200)
|
||||
print("Request:")
|
||||
print('*'*200)
|
||||
pprint(request_data)
|
||||
print('='*200)
|
||||
|
||||
|
||||
async def process_stream():
|
||||
try:
|
||||
async for event_type, event_data in run_turn_streamed(
|
||||
messages=request_data.get("messages", []),
|
||||
start_agent_name=request_data.get("startAgent", ""),
|
||||
agent_configs=request_data.get("agents", []),
|
||||
tool_configs=request_data.get("tools", []),
|
||||
start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
|
||||
state=request_data.get("state", {}),
|
||||
additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
|
||||
complete_request=request_data
|
||||
):
|
||||
if event_type == 'message':
|
||||
print('*'*200)
|
||||
print("Yielding message:")
|
||||
print('*'*200)
|
||||
to_yield = f"event: message\ndata: {json.dumps(event_data)}\n\n"
|
||||
print(to_yield)
|
||||
print('='*200)
|
||||
yield to_yield
|
||||
elif event_type == 'done':
|
||||
print('*'*200)
|
||||
print("Yielding done:")
|
||||
print('*'*200)
|
||||
to_yield = f"event: done\ndata: {json.dumps(event_data)}\n\n"
|
||||
print(to_yield)
|
||||
print('='*200)
|
||||
yield to_yield
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming error: {str(e)}")
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||
|
||||
def generate():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
async def get_all_chunks():
|
||||
chunks = []
|
||||
async for chunk in process_stream():
|
||||
chunks.append(chunk)
|
||||
return chunks
|
||||
|
||||
chunks = loop.run_until_complete(get_all_chunks())
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in generate: {e}")
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting async server...")
|
||||
config = Config()
|
||||
config.bind = ["0.0.0.0:4040"]
|
||||
asyncio.run(serve(app, config))
|
||||
0
apps/rowboat_agents/src/graph/__init__.py
Normal file
0
apps/rowboat_agents/src/graph/__init__.py
Normal file
367
apps/rowboat_agents/src/graph/core.py
Normal file
367
apps/rowboat_agents/src/graph/core.py
Normal file
|
|
@ -0,0 +1,367 @@
|
|||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
|
||||
import logging
|
||||
from .helpers.access import (
|
||||
get_agent_by_name,
|
||||
get_external_tools,
|
||||
)
|
||||
from .helpers.state import (
|
||||
construct_state_from_response
|
||||
)
|
||||
from .helpers.control import get_latest_assistant_msg, get_latest_non_assistant_messages, get_last_agent_name
|
||||
from .swarm_wrapper import run as swarm_run, run_streamed as swarm_run_streamed, create_response, get_agents
|
||||
from src.utils.common import common_logger as logger
|
||||
import asyncio
|
||||
|
||||
# Create a dedicated logger for swarm wrapper
|
||||
logger.setLevel(logging.INFO)
|
||||
print("Logger level set to INFO")
|
||||
|
||||
def order_messages(messages):
|
||||
"""
|
||||
Sorts each message's keys in a specified order and returns a new list of ordered messages.
|
||||
"""
|
||||
ordered_messages = []
|
||||
for msg in messages:
|
||||
# Filter out None values
|
||||
msg = {k: v for k, v in msg.items() if v is not None}
|
||||
|
||||
# Specify the exact order
|
||||
ordered = {}
|
||||
for key in ['role', 'sender', 'content', 'created_at', 'timestamp']:
|
||||
if key in msg:
|
||||
ordered[key] = msg[key]
|
||||
|
||||
# Add remaining keys in alphabetical order
|
||||
remaining_keys = sorted(k for k in msg if k not in ordered)
|
||||
for key in remaining_keys:
|
||||
ordered[key] = msg[key]
|
||||
|
||||
ordered_messages.append(ordered)
|
||||
return ordered_messages
|
||||
|
||||
|
||||
def clean_up_history(agent_data):
|
||||
"""
|
||||
Ensures each agent's history is sorted using order_messages.
|
||||
"""
|
||||
for data in agent_data:
|
||||
data["history"] = order_messages(data["history"])
|
||||
return agent_data
|
||||
|
||||
def create_final_response(response, turn_messages, tokens_used, all_agents):
|
||||
"""
|
||||
Constructs the final response data (messages, tokens_used, updated state) that a caller would need.
|
||||
"""
|
||||
# Ensure response has a messages attribute
|
||||
if not hasattr(response, 'messages'):
|
||||
response.messages = []
|
||||
|
||||
# Assign the appropriate messages to the response
|
||||
response.messages = turn_messages
|
||||
|
||||
# Ensure tokens_used is a valid dictionary
|
||||
if not isinstance(tokens_used, dict):
|
||||
tokens_used = {"total": 100, "prompt": 50, "completion": 50} # Default values if not a dictionary
|
||||
|
||||
# Ensure response has a tokens_used attribute that's a dictionary
|
||||
if not hasattr(response, 'tokens_used') or not isinstance(response.tokens_used, dict):
|
||||
response.tokens_used = {}
|
||||
|
||||
response.tokens_used = tokens_used
|
||||
|
||||
# Ensure response has an agent attribute for state construction
|
||||
if not hasattr(response, 'agent'):
|
||||
if all_agents and len(all_agents) > 0:
|
||||
response.agent = all_agents[0] # Set default agent if missing
|
||||
|
||||
new_state = construct_state_from_response(response, all_agents)
|
||||
return response.messages, response.tokens_used, new_state
|
||||
|
||||
|
||||
def run_turn(
|
||||
messages, start_agent_name, agent_configs, tool_configs, start_turn_with_start_agent, state={}, additional_tool_configs=[], complete_request={}
|
||||
):
|
||||
"""
|
||||
Coordinates a single 'turn' of conversation or processing among agents.
|
||||
Includes validation, agent setup, optional greeting logic, error handling, and post-processing steps.
|
||||
"""
|
||||
logger.info("Running stateless turn")
|
||||
print("Running stateless turn")
|
||||
|
||||
# Sort messages by the specified ordering
|
||||
#messages = order_messages(messages)
|
||||
|
||||
# Merge any additional tool configs
|
||||
tool_configs = tool_configs + additional_tool_configs
|
||||
|
||||
# Determine if this is a greeting turn
|
||||
greeting_turn = not any(msg.get("role") != "system" for msg in messages)
|
||||
turn_messages = []
|
||||
# Initialize tokens_used as a dictionary
|
||||
tokens_used = {"total": 0, "prompt": 0, "completion": 0}
|
||||
|
||||
agent_data = state.get("agent_data", [])
|
||||
|
||||
# If not a greeting turn, localize the last user or system messages
|
||||
if not greeting_turn:
|
||||
latest_assistant_msg = get_latest_assistant_msg(messages)
|
||||
latest_non_assistant_msgs = get_latest_non_assistant_messages(messages)
|
||||
msg_type = latest_non_assistant_msgs[-1]["role"]
|
||||
|
||||
# Determine the last agent from state/config
|
||||
last_agent_name = get_last_agent_name(
|
||||
state=state,
|
||||
agent_configs=agent_configs,
|
||||
start_agent_name=start_agent_name,
|
||||
msg_type=msg_type,
|
||||
latest_assistant_msg=latest_assistant_msg,
|
||||
start_turn_with_start_agent=start_turn_with_start_agent
|
||||
)
|
||||
else:
|
||||
# For a greeting turn, we assume the last agent is the start_agent_name
|
||||
last_agent_name = start_agent_name
|
||||
|
||||
state["agent_data"] = agent_data
|
||||
|
||||
# Initialize all agents
|
||||
logger.info("Initializing agents")
|
||||
print("Initializing agents")
|
||||
new_agents = get_agents(
|
||||
agent_configs=agent_configs,
|
||||
tool_configs=tool_configs,
|
||||
complete_request=complete_request
|
||||
)
|
||||
# Prepare escalation agent
|
||||
last_new_agent = get_agent_by_name(last_agent_name, new_agents)
|
||||
|
||||
# Gather external tools for Swarm
|
||||
external_tools = get_external_tools(tool_configs)
|
||||
logger.info(f"Found {len(external_tools)} external tools")
|
||||
print(f"Found {len(external_tools)} external tools")
|
||||
|
||||
# If no validation error yet, proceed with the main run
|
||||
|
||||
logger.info("Running swarm run")
|
||||
print("Running swarm run")
|
||||
|
||||
response = swarm_run(
|
||||
agent=last_new_agent,
|
||||
messages=messages,
|
||||
external_tools=external_tools,
|
||||
tokens_used=tokens_used
|
||||
)
|
||||
|
||||
logger.info("Swarm run completed")
|
||||
print("Swarm run completed")
|
||||
|
||||
# Initialize response.messages if it doesn't exist
|
||||
if not hasattr(response, 'messages'):
|
||||
response.messages = []
|
||||
|
||||
# Convert the ResponseOutputMessage to a standard message format
|
||||
if hasattr(response, 'new_items') and response.new_items and hasattr(response.new_items[-1], 'raw_item'):
|
||||
raw_item = response.new_items[-1].raw_item
|
||||
# Extract text content from ResponseOutputText objects
|
||||
content = ""
|
||||
if hasattr(raw_item, 'content') and raw_item.content:
|
||||
for content_item in raw_item.content:
|
||||
if hasattr(content_item, 'text'):
|
||||
content += content_item.text
|
||||
|
||||
# Create a standard message dictionary
|
||||
standard_message = {
|
||||
"role": raw_item.role if hasattr(raw_item, 'role') else "assistant",
|
||||
"content": content,
|
||||
"sender": last_new_agent.name,
|
||||
"created_at": None,
|
||||
"response_type": "internal"
|
||||
}
|
||||
|
||||
# Add the converted message to response messages
|
||||
response.messages.append(standard_message)
|
||||
|
||||
logger.info("Converted message added to response messages")
|
||||
print("Converted message added to response messages")
|
||||
|
||||
# Use a dictionary for tokens_used instead of a hard-coded integer
|
||||
tokens_used = {"total": 100, "prompt": 50, "completion": 50} # Dummy values as placeholders
|
||||
|
||||
# Ensure turn_messages can be extended with response.messages
|
||||
if hasattr(response, 'messages') and isinstance(response.messages, list):
|
||||
turn_messages.extend(response.messages)
|
||||
|
||||
logger.info(f"Completed run of agent: {last_new_agent.name}")
|
||||
print(f"Completed run of agent: {last_new_agent.name}")
|
||||
|
||||
|
||||
# Otherwise, duplicate the last response as external
|
||||
logger.info("No post-processing agent found. Duplicating last response and setting to external.")
|
||||
print("No post-processing agent found. Duplicating last response and setting to external.")
|
||||
if turn_messages:
|
||||
duplicate_msg = deepcopy(turn_messages[-1])
|
||||
duplicate_msg["response_type"] = "external"
|
||||
duplicate_msg["sender"] += " >> External"
|
||||
|
||||
# Ensure tokens_used remains a proper dictionary
|
||||
if not isinstance(tokens_used, dict):
|
||||
tokens_used = {"total": 100, "prompt": 50, "completion": 50} # Default values if not a dictionary
|
||||
|
||||
response = create_response(
|
||||
messages=[duplicate_msg],
|
||||
tokens_used=tokens_used,
|
||||
agent=last_new_agent,
|
||||
error_msg=''
|
||||
)
|
||||
|
||||
# Ensure response has messages attribute
|
||||
if hasattr(response, 'messages') and isinstance(response.messages, list):
|
||||
turn_messages.extend(response.messages)
|
||||
|
||||
# Finalize the response
|
||||
logger.info("Finalizing response")
|
||||
print("Finalizing response")
|
||||
return create_final_response(
|
||||
response=response,
|
||||
turn_messages=turn_messages,
|
||||
tokens_used=tokens_used,
|
||||
all_agents=new_agents
|
||||
)
|
||||
|
||||
async def run_turn_streamed(
|
||||
messages,
|
||||
start_agent_name,
|
||||
agent_configs,
|
||||
tool_configs,
|
||||
start_turn_with_start_agent,
|
||||
state={},
|
||||
additional_tool_configs=[],
|
||||
complete_request={}
|
||||
):
|
||||
final_state = None # Initialize outside try block
|
||||
try:
|
||||
# Initialize agents and get external tools
|
||||
new_agents = get_agents(agent_configs=agent_configs, tool_configs=tool_configs, complete_request=complete_request)
|
||||
last_agent_name = get_last_agent_name(
|
||||
state=state,
|
||||
agent_configs=agent_configs,
|
||||
start_agent_name=start_agent_name,
|
||||
msg_type="user",
|
||||
latest_assistant_msg=None,
|
||||
start_turn_with_start_agent=start_turn_with_start_agent
|
||||
)
|
||||
last_new_agent = get_agent_by_name(last_agent_name, new_agents)
|
||||
external_tools = get_external_tools(tool_configs)
|
||||
|
||||
current_agent = last_new_agent
|
||||
tokens_used = {"total": 0, "prompt": 0, "completion": 0}
|
||||
|
||||
stream_result = await swarm_run_streamed(
|
||||
agent=last_new_agent,
|
||||
messages=messages,
|
||||
external_tools=external_tools,
|
||||
tokens_used=tokens_used
|
||||
)
|
||||
|
||||
# Process streaming events
|
||||
async for event in stream_result.stream_events():
|
||||
# print('='*50)
|
||||
# print("Received event: ", event)
|
||||
# print('-'*50)
|
||||
|
||||
# Handle raw response events and accumulate tokens
|
||||
if event.type == "raw_response_event":
|
||||
if hasattr(event.data, 'type') and event.data.type == "response.completed":
|
||||
if hasattr(event.data.response, 'usage'):
|
||||
tokens_used["total"] += event.data.response.usage.total_tokens
|
||||
tokens_used["prompt"] += event.data.response.usage.input_tokens
|
||||
tokens_used["completion"] += event.data.response.usage.output_tokens
|
||||
print('-'*50)
|
||||
print(f"Found usage information. Updated cumulative tokens: {tokens_used}")
|
||||
print('-'*50)
|
||||
continue
|
||||
|
||||
# Update current agent when it changes
|
||||
elif event.type == "agent_updated_stream_event":
|
||||
current_agent = event.new_agent
|
||||
message = {
|
||||
'content': f"Agent changed to {current_agent.name}",
|
||||
'role': 'assistant',
|
||||
'sender': current_agent.name,
|
||||
'tool_calls': None,
|
||||
'tool_call_id': None,
|
||||
'response_type': 'internal'
|
||||
}
|
||||
print("Yielding message: ", message)
|
||||
yield ('message', message)
|
||||
continue
|
||||
|
||||
# Handle run items (tools, messages, etc)
|
||||
elif event.type == "run_item_stream_event":
|
||||
if event.item.type == "tool_call_item":
|
||||
message = {
|
||||
'content': None,
|
||||
'role': 'assistant',
|
||||
'sender': current_agent.name if current_agent else None,
|
||||
'tool_calls': [{
|
||||
'function': {
|
||||
'name': event.item.raw_item.name,
|
||||
'arguments': event.item.raw_item.arguments
|
||||
},
|
||||
'id': event.item.raw_item.id,
|
||||
'type': 'function'
|
||||
}],
|
||||
'tool_call_id': None,
|
||||
'tool_name': None,
|
||||
'response_type': 'internal'
|
||||
}
|
||||
print("Yielding message: ", message)
|
||||
yield ('message', message)
|
||||
|
||||
elif event.item.type == "tool_call_output_item":
|
||||
message = {
|
||||
'content': str(event.item.output),
|
||||
'role': 'tool',
|
||||
'sender': None,
|
||||
'tool_calls': None,
|
||||
'tool_call_id': event.item.raw_item['call_id'],
|
||||
'tool_name': event.item.raw_item.get('name', None),
|
||||
'response_type': 'internal'
|
||||
}
|
||||
print("Yielding message: ", message)
|
||||
yield ('message', message)
|
||||
|
||||
elif event.item.type == "message_output_item":
|
||||
content = ""
|
||||
if hasattr(event.item.raw_item, 'content'):
|
||||
for content_item in event.item.raw_item.content:
|
||||
if hasattr(content_item, 'text'):
|
||||
content += content_item.text
|
||||
|
||||
message = {
|
||||
'content': content,
|
||||
'role': 'assistant',
|
||||
'sender': current_agent.name,
|
||||
'tool_calls': None,
|
||||
'tool_call_id': None,
|
||||
'tool_name': None,
|
||||
'response_type': 'external'
|
||||
}
|
||||
print("Yielding message: ", message)
|
||||
yield ('message', message)
|
||||
|
||||
print(f"\n{'='*50}\n")
|
||||
|
||||
# After all events are processed, set final state
|
||||
final_state = {
|
||||
"last_agent_name": current_agent.name if current_agent else None,
|
||||
"tokens": tokens_used
|
||||
}
|
||||
yield ('done', {'state': final_state})
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
print(f"Error in stream processing: {str(e)}")
|
||||
yield ('error', {'error': str(e), 'state': final_state}) # Include final_state in error response
|
||||
218
apps/rowboat_agents/src/graph/guardrails.py
Normal file
218
apps/rowboat_agents/src/graph/guardrails.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
# Guardrails
|
||||
from src.utils.common import generate_llm_output
|
||||
import os
|
||||
import copy
|
||||
|
||||
from .swarm_wrapper import Agent, Response, create_response
|
||||
|
||||
from src.utils.common import common_logger, generate_openai_output, update_tokens_used
|
||||
logger = common_logger
|
||||
|
||||
def classify_hallucination(context: str, assistant_response: str, chat_history: list, model: str) -> str:
|
||||
"""
|
||||
Checks if an assistant's response contains hallucinations by comparing against provided context.
|
||||
|
||||
Args:
|
||||
context (str): The context/knowledge base to check the response against
|
||||
assistant_response (str): The response from the assistant to validate
|
||||
chat_history (list): List of previous chat messages for context
|
||||
|
||||
Returns:
|
||||
str: Verdict indicating level of hallucination:
|
||||
'yes-absolute' - completely supported by context
|
||||
'yes-common-sensical' - supported with common sense interpretation
|
||||
'no-absolute' - not supported by context
|
||||
'no-subtle' - not supported but difference is subtle
|
||||
"""
|
||||
chat_history_str = "\n".join([f"{message['role']}: {message['content']}" for message in chat_history])
|
||||
|
||||
prompt = f"""
|
||||
You are a guardrail agent. Your job is to check if the response is hallucinating.
|
||||
|
||||
------------------------------------------------------------------------
|
||||
Here is the context:
|
||||
{context}
|
||||
|
||||
------------------------------------------------------------------------
|
||||
Here is the chat history message:
|
||||
{chat_history_str}
|
||||
|
||||
------------------------------------------------------------------------
|
||||
Here is the response:
|
||||
{assistant_response}
|
||||
|
||||
------------------------------------------------------------------------
|
||||
As a hallucination guardrail, your job is to go through each line of the response and check if it is completely supported by the context. Even if a single line is not supported, the response is no.
|
||||
|
||||
Output a single verdict for the entire response. don't provide any reasoning. The output classes are
|
||||
|
||||
yes-absolute: completely supported by the context
|
||||
yes-common-sensical: but with some common sense interpretation
|
||||
no-absolute: not supported by the context
|
||||
no-subtle: not supported by the context but the difference is subtle
|
||||
|
||||
Output of of the classes:
|
||||
verdict : yes-absolute/yes-common-sensical/no-absolute/no-subtle
|
||||
|
||||
Example 1: The response is completely supported by the context.
|
||||
User Input:
|
||||
Context: "Our airline provides complimentary meals and beverages on all international flights. Passengers are allowed one carry-on bag and one personal item."
|
||||
Chat History:
|
||||
User: "Do international flights with your airline offer free meals?"
|
||||
Response: "Yes, all international flights with our airline offer free meals and beverages."
|
||||
Output: verdict: yes-absolute
|
||||
|
||||
Example 2: The response is generally true and could be deduced with common sense interpretation, though not explicitly stated in the context.
|
||||
User Input:
|
||||
Context: "Flights may experience delays due to weather conditions. In such cases, the airline staff will provide updates at the airport."
|
||||
Chat History:
|
||||
User: "Will there be announcements if my flight is delayed?"
|
||||
Response: "Yes, if your flight is delayed, there will be announcements at the airport."
|
||||
Output: verdict: yes-common-sensical
|
||||
|
||||
Example 3: The response is not supported by the context and contains glaring inaccuracies.
|
||||
User Input:
|
||||
Context: "You can cancel your ticket online up to 24 hours before the flight's departure time and receive a full refund."
|
||||
Chat History:
|
||||
User: "Can I get a refund if I cancel 12 hours before the flight?"
|
||||
Response: "Yes, you can get a refund if you cancel 12 hours before the flight."
|
||||
Output: verdict: no-absolute
|
||||
|
||||
Example 4: The response is not supported by the context but the difference is subtle.
|
||||
User Input:
|
||||
Context: "Our frequent flyer program offers discounts on checked bags for members who have achieved Gold status."
|
||||
Chat History:
|
||||
User: "As a member, do I get discounts on checked bags?"
|
||||
Response: "Yes, members of our frequent flyer program get discounts on checked bags."
|
||||
Output: verdict: no-subtle
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": prompt,
|
||||
},
|
||||
]
|
||||
response = generate_llm_output(messages, model)
|
||||
return response
|
||||
|
||||
def post_process_response(messages: list, post_processing_agent_name: str, post_process_instructions: str, style_prompt: str = None, context: str = None, model: str = "gpt-4o", tokens_used: dict = {}, last_agent: Agent = None) -> dict:
|
||||
agent_instructions = last_agent.instructions
|
||||
agent_history = last_agent.history
|
||||
# agent_instructions = ''
|
||||
# agent_history = []
|
||||
|
||||
pending_msg = copy.deepcopy(messages[-1])
|
||||
logger.debug(f"Pending message keys: {pending_msg.keys()}")
|
||||
|
||||
skip = False
|
||||
|
||||
if pending_msg.get("tool_calls"):
|
||||
logger.info("Last message is a tool call, skipping post processing and setting last message to external")
|
||||
skip = True
|
||||
|
||||
elif not pending_msg['response_type'] == "internal":
|
||||
logger.info("Last message is not internal, skipping post processing and setting last message to external")
|
||||
skip = True
|
||||
|
||||
elif not pending_msg['content']:
|
||||
logger.info("Last message has no content, skipping post processing and setting last message to external")
|
||||
skip = True
|
||||
|
||||
elif not post_process_instructions:
|
||||
logger.info("No post process instructions, skipping post processing and setting last message to external")
|
||||
skip = True
|
||||
|
||||
if skip:
|
||||
pending_msg['response_type'] = "external"
|
||||
response = Response(
|
||||
messages=[],
|
||||
tokens_used=tokens_used,
|
||||
agent=last_agent,
|
||||
error_msg=''
|
||||
)
|
||||
return response
|
||||
|
||||
agent_history_str = f"\n{'*'*100}\n".join([f"Role: {message['role']} | Content: {message.get('content', 'None')} | Tool Calls: {message.get('tool_calls', 'None')}" for message in agent_history[:-1]])
|
||||
logger.debug(f"Agent history: {agent_history_str}")
|
||||
|
||||
prompt = f"""
|
||||
# ROLE
|
||||
|
||||
You are a post processing agent responsible for rewriting a response generated by an agent, according to instructions provided below. Ensure that the response you produce adheres to the instructions provided to you (if any).
|
||||
------------------------------------------------------------------------
|
||||
|
||||
# ADDITIONAL INSTRUCTIONS
|
||||
|
||||
Here are additional instructions that the admin might have configured for you:
|
||||
{post_process_instructions}
|
||||
|
||||
------------------------------------------------------------------------
|
||||
|
||||
# CHAT HISTORY
|
||||
|
||||
Here is the chat history:
|
||||
{agent_history_str}
|
||||
"""
|
||||
if context:
|
||||
context_prompt = f"""
|
||||
------------------------------------------------------------------------
|
||||
# CONTEXT
|
||||
|
||||
Here is the context:
|
||||
{context}
|
||||
"""
|
||||
prompt += context_prompt
|
||||
|
||||
if style_prompt:
|
||||
style_prompt = f"""
|
||||
------------------------------------------------------------------------
|
||||
# STYLE PROMPT
|
||||
|
||||
Here is the style prompt:
|
||||
{style_prompt}
|
||||
"""
|
||||
prompt += style_prompt
|
||||
|
||||
agent_response_and_instructions = f"""
|
||||
|
||||
------------------------------------------------------------------------
|
||||
# AGENT INSTRUCTIONS
|
||||
|
||||
Here are the instructions to the agent generating the response:
|
||||
{agent_instructions}
|
||||
|
||||
------------------------------------------------------------------------
|
||||
# AGENT RESPONSE
|
||||
|
||||
Here is the response that the agent has generated:
|
||||
{pending_msg['content']}
|
||||
|
||||
"""
|
||||
prompt += agent_response_and_instructions
|
||||
|
||||
logger.debug(f"Sanitizing response for style. Original response: {pending_msg['content']}")
|
||||
completion = generate_openai_output(
|
||||
messages=[
|
||||
{"role": "system", "content": prompt}
|
||||
],
|
||||
model = model,
|
||||
return_completion=True
|
||||
)
|
||||
content = completion.choices[0].message.content
|
||||
if content:
|
||||
content = content.strip().lstrip().rstrip()
|
||||
tokens_used = update_tokens_used(provider="openai", model=model, tokens_used=tokens_used, completion=completion)
|
||||
logger.debug(f"Response after style check: {content}, tokens used: {tokens_used}")
|
||||
|
||||
pending_msg['content'] = content if content else pending_msg['content']
|
||||
pending_msg['response_type'] = "external"
|
||||
pending_msg['sender'] = pending_msg['sender'] + f' >> {post_processing_agent_name}'
|
||||
|
||||
response = Response(
|
||||
messages=[pending_msg],
|
||||
tokens_used=tokens_used,
|
||||
agent=last_agent,
|
||||
error_msg=''
|
||||
)
|
||||
|
||||
return response
|
||||
48
apps/rowboat_agents/src/graph/helpers/access.py
Normal file
48
apps/rowboat_agents/src/graph/helpers/access.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
from src.utils.common import common_logger
|
||||
logger = common_logger
|
||||
|
||||
def get_external_tools(tool_configs):
|
||||
logger.debug("Getting external tools")
|
||||
tools = [tool["name"] for tool in tool_configs]
|
||||
logger.debug(f"Found {len(tools)} external tools")
|
||||
return tools
|
||||
|
||||
def get_agent_by_name(agent_name, agents):
|
||||
agent = next((a for a in agents if getattr(a, "name", None) == agent_name), None)
|
||||
if not agent:
|
||||
logger.error(f"Agent with name {agent_name} not found")
|
||||
raise ValueError(f"Agent with name {agent_name} not found")
|
||||
return agent
|
||||
|
||||
def get_agent_config_by_name(agent_name, agent_configs):
|
||||
agent_config = next((ac for ac in agent_configs if ac.get("name") == agent_name), None)
|
||||
if not agent_config:
|
||||
logger.error(f"Agent config with name {agent_name} not found")
|
||||
raise ValueError(f"Agent config with name {agent_name} not found")
|
||||
return agent_config
|
||||
|
||||
def pop_agent_config_by_type(agent_configs, agent_type):
|
||||
agent_config = next((ac for ac in agent_configs if ac.get("type") == agent_type), None)
|
||||
if agent_config:
|
||||
agent_configs.remove(agent_config)
|
||||
return agent_config, agent_configs
|
||||
|
||||
def get_agent_by_type(agents, agent_type):
|
||||
return next((a for a in agents if a.type == agent_type), None)
|
||||
|
||||
def get_prompt_by_type(prompt_configs, prompt_type):
|
||||
return next((pc.get("prompt") for pc in prompt_configs if pc.get("type") == prompt_type), None)
|
||||
|
||||
def get_agent_data_by_name(agent_name, agent_data):
|
||||
for data in agent_data:
|
||||
name = data.get("name", "")
|
||||
if name == agent_name:
|
||||
return data
|
||||
|
||||
return None
|
||||
|
||||
def get_tool_config_by_name(tool_configs, tool_name):
|
||||
return next((tc for tc in tool_configs if tc.get("name", "") == tool_name), None)
|
||||
|
||||
def get_tool_config_by_type(tool_configs, tool_type):
|
||||
return next((tc for tc in tool_configs if tc.get("type", "") == tool_type), None)
|
||||
50
apps/rowboat_agents/src/graph/helpers/control.py
Normal file
50
apps/rowboat_agents/src/graph/helpers/control.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from .access import get_agent_config_by_name, get_agent_data_by_name
|
||||
from src.graph.types import ControlType
|
||||
from src.utils.common import common_logger
|
||||
logger = common_logger
|
||||
|
||||
def get_last_agent_name(state, agent_configs, start_agent_name, msg_type, latest_assistant_msg, start_turn_with_start_agent):
|
||||
default_last_agent_name = state.get("last_agent_name", '')
|
||||
last_agent_config = get_agent_config_by_name(default_last_agent_name, agent_configs)
|
||||
specific_agent_data = get_agent_data_by_name(default_last_agent_name, state.get("agent_data", []))
|
||||
|
||||
# Overrides for special cases
|
||||
logger.info("Setting agent control based on last agent and control type")
|
||||
if msg_type == "tool":
|
||||
last_agent_name = default_last_agent_name
|
||||
assert last_agent_name == latest_assistant_msg.get("sender", ''), "Last agent name does not match sender of latest assistant message during tool call handling"
|
||||
|
||||
elif start_turn_with_start_agent:
|
||||
last_agent_name = start_agent_name
|
||||
|
||||
else:
|
||||
control_type = last_agent_config.get("controlType", ControlType.RETAIN.value)
|
||||
if control_type == ControlType.PARENT_AGENT.value:
|
||||
last_agent_name = specific_agent_data.get("most_recent_parent_name", None) if specific_agent_data else None
|
||||
if not last_agent_name:
|
||||
logger.error("Most recent parent is empty, defaulting to same agent instead")
|
||||
last_agent_name = default_last_agent_name
|
||||
elif control_type == ControlType.START_AGENT.value:
|
||||
last_agent_name = start_agent_name
|
||||
else:
|
||||
last_agent_name = default_last_agent_name
|
||||
|
||||
if default_last_agent_name != last_agent_name:
|
||||
logger.info(f"Last agent name changed from {default_last_agent_name} to {last_agent_name} due to control settings")
|
||||
|
||||
return last_agent_name
|
||||
|
||||
|
||||
def get_latest_assistant_msg(messages):
|
||||
# Find the latest message with role assistant
|
||||
for i in range(len(messages)-1, -1, -1):
|
||||
if messages[i].get("role") == "assistant":
|
||||
return messages[i]
|
||||
return None
|
||||
|
||||
def get_latest_non_assistant_messages(messages):
|
||||
# Find all messages after the last assistant message
|
||||
for i in range(len(messages)-1, -1, -1):
|
||||
if messages[i].get("role") == "assistant":
|
||||
return messages[i+1:]
|
||||
return messages
|
||||
39
apps/rowboat_agents/src/graph/helpers/instructions.py
Normal file
39
apps/rowboat_agents/src/graph/helpers/instructions.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
from src.graph.instructions import TRANSFER_CHILDREN_INSTRUCTIONS, TRANSFER_PARENT_AWARE_INSTRUCTIONS, RAG_INSTRUCTIONS, ERROR_ESCALATION_AGENT_INSTRUCTIONS, TRANSFER_GIVE_UP_CONTROL_INSTRUCTIONS, SYSTEM_MESSAGE
|
||||
|
||||
def add_transfer_instructions_to_parent_agents(agent, children, transfer_functions):
|
||||
other_agent_name_descriptions_tools = f'\n{'-'*100}\n'.join([f"Name: {agent.name}\nDescription: {agent.description if agent.description else ''}\nTool for transfer: {transfer_functions[agent.name].__name__}" for agent in children.values()])
|
||||
|
||||
prompt = TRANSFER_CHILDREN_INSTRUCTIONS.format(other_agent_name_descriptions_tools=other_agent_name_descriptions_tools)
|
||||
agent.instructions = agent.instructions + f'\n\n{'-'*100}\n\n' + prompt
|
||||
|
||||
return agent
|
||||
|
||||
def add_transfer_instructions_to_child_agents(child, children_aware_of_parent):
|
||||
if children_aware_of_parent:
|
||||
candidate_parents_name_description_tools = f'\n{'-'*100}\n'.join([f"Name: {parent_name}\nTool for transfer: {func.__name__}" for parent_name, func in child.candidate_parent_functions.items()])
|
||||
prompt = TRANSFER_PARENT_AWARE_INSTRUCTIONS.format(candidate_parents_name_description_tools=candidate_parents_name_description_tools)
|
||||
else:
|
||||
candidate_parents_name_description_tools = f'\n{'-'*100}\n'.join(list(set([f"Tool for transfer: {func.__name__}" for _, func in child.candidate_parent_functions.items()])))
|
||||
prompt = TRANSFER_GIVE_UP_CONTROL_INSTRUCTIONS.format(candidate_parents_name_description_tools=candidate_parents_name_description_tools)
|
||||
|
||||
child.instructions = child.instructions + f'\n\n{'-'*100}\n\n' + prompt
|
||||
return child
|
||||
|
||||
def add_rag_instructions_to_agent(agent_config, rag_tool_name):
|
||||
prompt = RAG_INSTRUCTIONS.format(rag_tool_name=rag_tool_name)
|
||||
agent_config["instructions"] = agent_config["instructions"] + f'\n\n{'-'*100}\n\n' + prompt
|
||||
return agent_config
|
||||
|
||||
def add_error_escalation_instructions(agent):
|
||||
prompt = ERROR_ESCALATION_AGENT_INSTRUCTIONS
|
||||
agent.instructions = agent.instructions + f'\n\n{'-'*100}\n\n' + prompt
|
||||
return agent
|
||||
|
||||
def get_universal_system_message(messages):
|
||||
if messages and messages[0].get("role") == "system":
|
||||
return SYSTEM_MESSAGE.format(system_message=messages[0].get("content"))
|
||||
return ""
|
||||
|
||||
def add_universal_system_message_to_agent(agent, universal_sys_msg):
|
||||
agent.instructions = agent.instructions + f'\n\n{'-'*100}\n\n' + universal_sys_msg
|
||||
return agent
|
||||
60
apps/rowboat_agents/src/graph/helpers/state.py
Normal file
60
apps/rowboat_agents/src/graph/helpers/state.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
from src.utils.common import common_logger
|
||||
logger = common_logger
|
||||
from .access import get_agent_data_by_name
|
||||
|
||||
def reset_current_turn(messages):
|
||||
# Set all messages' current_turn to False
|
||||
for msg in messages:
|
||||
msg["current_turn"] = False
|
||||
|
||||
# Find most recent user message
|
||||
messages[-1]["current_turn"] = True
|
||||
|
||||
return messages
|
||||
|
||||
def reset_current_turn_agent_history(agent_data, agent_names):
|
||||
for name in agent_names:
|
||||
data = get_agent_data_by_name(name, agent_data)
|
||||
if data:
|
||||
for msg in data["history"]:
|
||||
msg["current_turn"] = False
|
||||
return agent_data
|
||||
|
||||
def add_recent_messages_to_history(recent_messages, last_agent_name, agent_data, messages, parent_has_child_history):
|
||||
last_msg = messages[-1]
|
||||
specific_agent_data = get_agent_data_by_name(last_agent_name, agent_data)
|
||||
if specific_agent_data:
|
||||
specific_agent_data["history"].extend(recent_messages)
|
||||
if parent_has_child_history:
|
||||
current_agent_data = specific_agent_data
|
||||
while current_agent_data.get("most_recent_parent_name"):
|
||||
parent_name = current_agent_data.get("most_recent_parent_name")
|
||||
parent_agent_data = get_agent_data_by_name(parent_name, agent_data)
|
||||
if parent_agent_data:
|
||||
parent_agent_data["history"].extend(recent_messages)
|
||||
current_agent_data = parent_agent_data
|
||||
else:
|
||||
logger.error(f"Parent agent data for {current_agent_data['name']} not found in agent_data")
|
||||
raise ValueError(f"Parent agent data for {current_agent_data['name']} not found in agent_data")
|
||||
else:
|
||||
agent_data.append({
|
||||
"name": last_agent_name,
|
||||
"history": [last_msg]
|
||||
})
|
||||
|
||||
return agent_data
|
||||
|
||||
def construct_state_from_response(response, agents):
|
||||
agent_data = []
|
||||
for agent in agents:
|
||||
agent_data.append({
|
||||
"name": agent.name,
|
||||
"instructions": agent.instructions
|
||||
})
|
||||
|
||||
state = {
|
||||
"last_agent_name": response.agent.name,
|
||||
"agent_data": agent_data
|
||||
}
|
||||
|
||||
return state
|
||||
44
apps/rowboat_agents/src/graph/helpers/transfer.py
Normal file
44
apps/rowboat_agents/src/graph/helpers/transfer.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
from src.utils.common import common_logger
|
||||
logger = common_logger
|
||||
|
||||
def create_transfer_function_to_agent(agent):
|
||||
agent_name = agent.name
|
||||
|
||||
fn_spec = {
|
||||
"name": f"transfer_to_{agent_name.lower().replace(' ', '_')}",
|
||||
"description": f"Function to transfer the chat to {agent_name}.",
|
||||
"return_value": agent
|
||||
}
|
||||
|
||||
def generated_function(*args, **kwargs):
|
||||
logger.info(f"Transferring chat to {agent_name}")
|
||||
return fn_spec.get('return_value', None)
|
||||
|
||||
generated_function.__name__ = fn_spec['name']
|
||||
generated_function.__doc__ = fn_spec.get('description', '')
|
||||
|
||||
return generated_function
|
||||
|
||||
def create_transfer_function_to_parent_agent(parent_agent, children_aware_of_parent, transfer_functions):
|
||||
if children_aware_of_parent:
|
||||
name = f"{transfer_functions[parent_agent.name].__name__}_from_child"
|
||||
description = f"Function to transfer the chat to your parent agent: {parent_agent.name}."
|
||||
else:
|
||||
name = "give_up_chat_control"
|
||||
description = "Function to give up control of the chat when you are unable to handle it."
|
||||
|
||||
|
||||
fn_spec = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"return_value": parent_agent
|
||||
}
|
||||
|
||||
def generated_function(*args, **kwargs):
|
||||
logger.info(f"Transferring chat to parent agent: {parent_agent.name}")
|
||||
return fn_spec.get('return_value', None)
|
||||
|
||||
generated_function.__name__ = fn_spec['name']
|
||||
generated_function.__doc__ = fn_spec.get('description', '')
|
||||
|
||||
return generated_function
|
||||
70
apps/rowboat_agents/src/graph/instructions.py
Normal file
70
apps/rowboat_agents/src/graph/instructions.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
########################
|
||||
# Instructions for agents that use RAG
|
||||
########################
|
||||
RAG_INSTRUCTIONS = f"""
|
||||
# Instructions about using the article retrieval tool
|
||||
- Where relevant, use the articles tool: {{rag_tool_name}} to fetch articles with knowledge relevant to the query and use its contents to respond to the user.
|
||||
- Do not send a separate message first asking the user to wait while you look up information. Immediately fetch the articles and respond to the user with the answer to their query.
|
||||
- Do not make up information. If the article's contents do not have the answer, give up control of the chat (or transfer to your parent agent, as per your transfer instructions). Do not say anything to the user.
|
||||
"""
|
||||
|
||||
########################
|
||||
# Instructions for child agents that are aware of parent agents
|
||||
########################
|
||||
TRANSFER_PARENT_AWARE_INSTRUCTIONS = f"""
|
||||
# Instructions about using your parent agents
|
||||
You have the following candidate parent agents that you can transfer the chat to, using the appropriate tool calls for the transfer:
|
||||
{{candidate_parents_name_description_tools}}.
|
||||
|
||||
## Notes:
|
||||
- During runtime, you will be provided with a tool call for exactly one of these parent agents that you can use. Use that tool call to transfer the chat to the parent agent in case you are unable to handle the chat (e.g. if it is not in your scope of instructions).
|
||||
- Transfer the chat to the appropriate agent, based on the chat history and / or the user's request.
|
||||
- When you transfer the chat to another agent, you should not provide any response to the user. For example, do not say 'Transferring chat to X agent' or anything like that. Just invoke the tool call to transfer to the other agent.
|
||||
- Do NOT ever mention the existence of other agents. For example, do not say 'Please check with X agent for details regarding processing times.' or anything like that.
|
||||
- If any other agent transfers the chat to you without responding to the user, it means that they don't know how to help. Do not transfer the chat to back to the same agent in this case. In such cases, you should transfer to the escalation agent using the appropriate tool call. Never ask the user to contact support.
|
||||
"""
|
||||
|
||||
########################
|
||||
# Instructions for child agents that give up control to parent agents
|
||||
########################
|
||||
TRANSFER_GIVE_UP_CONTROL_INSTRUCTIONS = f"""
|
||||
# Instructions about giving up chat control
|
||||
If you are unable to handle the chat (e.g. if it is not in your scope of instructions), you should use the tool call provided to give up control of the chat.
|
||||
{{candidate_parents_name_description_tools}}
|
||||
|
||||
## Notes:
|
||||
- When you give up control of the chat, you should not provide any response to the user. Just invoke the tool call to give up control.
|
||||
"""
|
||||
|
||||
########################
|
||||
# Instructions for parent agents that need to transfer the chat to other specialized (children) agents
|
||||
########################
|
||||
TRANSFER_CHILDREN_INSTRUCTIONS = f"""
|
||||
# Instructions about using other specialized agents
|
||||
You have the following specialized agents that you can transfer the chat to, using the appropriate tool calls for the transfer:
|
||||
{{other_agent_name_descriptions_tools}}
|
||||
|
||||
## Notes:
|
||||
- Transfer the chat to the appropriate agent, based on the chat history and / or the user's request.
|
||||
- When you transfer the chat to another agent, you should not provide any response to the user. For example, do not say 'Transferring chat to X agent' or anything like that. Just invoke the tool call to transfer to the other agent.
|
||||
- Do NOT ever mention the existence of other agents. For example, do not say 'Please check with X agent for details regarding processing times.' or anything like that.
|
||||
- If any other agent transfers the chat to you without responding to the user, it means that they don't know how to help. Do not transfer the chat to back to the same agent in this case. In such cases, you should transfer to the escalation agent using the appropriate tool call. Never ask the user to contact support.
|
||||
"""
|
||||
|
||||
|
||||
########################
|
||||
# Additional instruction for escalation agent when called due to an error
|
||||
########################
|
||||
ERROR_ESCALATION_AGENT_INSTRUCTIONS = f"""
|
||||
# Context
|
||||
The rest of the parts of the chatbot were unable to handle the chat. Hence, the chat has been escalated to you. In addition to your other instructions, tell the user that you are having trouble handling the chat - say "I'm having trouble helping with your request. Sorry about that.". Remember you are a part of the chatbot as well.
|
||||
"""
|
||||
|
||||
|
||||
########################
|
||||
# Universal system message formatting
|
||||
########################
|
||||
SYSTEM_MESSAGE = f"""
|
||||
# Additional System-Wide Context or Instructions:
|
||||
{{system_message}}
|
||||
"""
|
||||
391
apps/rowboat_agents/src/graph/swarm_wrapper.py
Normal file
391
apps/rowboat_agents/src/graph/swarm_wrapper.py
Normal file
|
|
@ -0,0 +1,391 @@
|
|||
import logging
|
||||
import json
|
||||
import aiohttp
|
||||
# Import helper functions needed for get_agents
|
||||
from .helpers.access import (
|
||||
get_tool_config_by_name,
|
||||
get_tool_config_by_type
|
||||
)
|
||||
from .helpers.instructions import (
|
||||
add_rag_instructions_to_agent
|
||||
)
|
||||
|
||||
from agents import Agent as NewAgent, Runner, FunctionTool, RunContextWrapper
|
||||
# Add import for OpenAI functionality
|
||||
from src.utils.common import common_logger as logger, generate_openai_output
|
||||
from typing import Any
|
||||
from dataclasses import asdict
|
||||
import asyncio
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict
|
||||
from .tool_calling import call_rag_tool
|
||||
|
||||
class NewResponse(BaseModel):
|
||||
messages: List[Dict]
|
||||
agent: Optional[Any] = None
|
||||
tokens_used: Optional[dict] = {}
|
||||
error_msg: Optional[str] = ""
|
||||
|
||||
async def mock_tool(tool_name: str, args: str, tool_config: str) -> str:
|
||||
"""
|
||||
Handles tool execution by either using mock instructions or generating a response.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the tool
|
||||
args: The arguments passed to the tool
|
||||
tool_config: The configuration of the tool
|
||||
|
||||
Returns:
|
||||
The response from the tool
|
||||
"""
|
||||
print(f"Mock tool called for: {tool_name}")
|
||||
|
||||
# For non-mocked tools, generate a realistic response
|
||||
description = tool_config.get("description", "")
|
||||
mock_instructions = tool_config.get("mockInstructions", "")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'.Here is the description of the tool: {description}. Here are the instructions for the mock tool: {mock_instructions}. Generate a realistic response as if the tool was actually executed with the given parameters."},
|
||||
{"role": "user", "content": f"Generate a realistic response for the tool '{tool_name}' with these parameters: {args}. The response should be concise and focused on what the tool would actually return."}
|
||||
]
|
||||
|
||||
print(f"Generating simulated response for tool: {tool_name}")
|
||||
response_content = generate_openai_output(messages, output_type='text', model="gpt-4o")
|
||||
return response_content
|
||||
|
||||
async def call_webhook(tool_name: str, args: str, webhook_url: str) -> str:
|
||||
"""
|
||||
Calls the webhook with the given tool name and arguments.
|
||||
|
||||
Args:
|
||||
tool_name (str): The name of the tool to call.
|
||||
args (str): The arguments for the tool as a JSON string.
|
||||
|
||||
Returns:
|
||||
str: The response from the webhook, or an error message if the call fails.
|
||||
"""
|
||||
content_dict = {
|
||||
"toolCall": {
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": args
|
||||
}
|
||||
}
|
||||
}
|
||||
request_body = {
|
||||
"content": json.dumps(content_dict)
|
||||
}
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(webhook_url, json=request_body) as response:
|
||||
if response.status == 200:
|
||||
response_json = await response.json()
|
||||
return response_json.get("result", "")
|
||||
else:
|
||||
error_msg = await response.text()
|
||||
print(f"Webhook error: {error_msg}")
|
||||
return f"Error: {error_msg}"
|
||||
except Exception as e:
|
||||
print(f"Exception in call_webhook: {str(e)}")
|
||||
return f"Error: Failed to call webhook - {str(e)}"
|
||||
|
||||
async def call_mcp(tool_name: str, args: str, mcp_server_url: str) -> str:
|
||||
"""
|
||||
Calls the MCP with the given tool name and arguments.
|
||||
"""
|
||||
|
||||
async with sse_client(url=mcp_server_url) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
jargs = json.loads(args)
|
||||
response = await session.call_tool(tool_name, arguments=jargs)
|
||||
json_output = json.dumps([item.__dict__ for item in response.content], indent=2)
|
||||
|
||||
return json_output
|
||||
|
||||
async def catch_all(ctx: RunContextWrapper[Any], args: str, tool_name: str, tool_config: dict, complete_request: dict) -> str:
|
||||
"""
|
||||
Handles all tool calls by dispatching to appropriate functions.
|
||||
"""
|
||||
print(f"Catch all called for tool: {tool_name}")
|
||||
print(f"Args: {args}")
|
||||
print(f"Tool config: {tool_config}")
|
||||
|
||||
# Create event loop for async operations
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
response_content = None
|
||||
if tool_config.get("mockTool", False):
|
||||
# Call mock_tool to handle the response (it will decide whether to use mock instructions or generate a response)
|
||||
response_content = await mock_tool(tool_name, args, tool_config)
|
||||
print(response_content)
|
||||
elif tool_config.get("isMcp", False):
|
||||
mcp_server_name = tool_config.get("mcpServerName", "")
|
||||
mcp_servers = complete_request.get("mcpServers", {})
|
||||
mcp_server_url = next((server.get("url", "") for server in mcp_servers if server.get("name") == mcp_server_name), "")
|
||||
response_content = await call_mcp(tool_name, args, mcp_server_url)
|
||||
else:
|
||||
webhook_url = complete_request.get("toolWebhookUrl", "")
|
||||
response_content = await call_webhook(tool_name, args, webhook_url)
|
||||
return response_content
|
||||
|
||||
|
||||
def get_rag_tool(config: dict, complete_request: dict) -> FunctionTool:
|
||||
"""
|
||||
Creates a RAG tool based on the provided configuration.
|
||||
"""
|
||||
project_id = complete_request.get("projectId", "")
|
||||
if config.get("ragDataSources", None):
|
||||
print("getArticleInfo")
|
||||
params = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The query to search for"
|
||||
}
|
||||
},
|
||||
"additionalProperties": False,
|
||||
"required": [
|
||||
"query"
|
||||
]
|
||||
}
|
||||
tool = FunctionTool(
|
||||
name="getArticleInfo",
|
||||
description="Get information about an article",
|
||||
params_json_schema=params,
|
||||
on_invoke_tool=lambda ctx, args: call_rag_tool(project_id, json.loads(args)['query'], config.get("ragDataSources", []), "chunks", 3)
|
||||
)
|
||||
return tool
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def get_agents(agent_configs, tool_configs, complete_request):
|
||||
"""
|
||||
Creates and initializes Agent objects based on their configurations and connections.
|
||||
"""
|
||||
if not isinstance(agent_configs, list):
|
||||
raise ValueError("Agents config is not a list in get_agents")
|
||||
if not isinstance(tool_configs, list):
|
||||
raise ValueError("Tools config is not a list in get_agents")
|
||||
|
||||
new_agents = []
|
||||
new_agent_to_children = {}
|
||||
new_agent_name_to_index = {}
|
||||
# Create Agent objects from config
|
||||
for agent_config in agent_configs:
|
||||
logger.debug(f"Processing config for agent: {agent_config['name']}")
|
||||
print("="*100)
|
||||
print(f"Processing config for agent: {agent_config['name']}")
|
||||
|
||||
# If hasRagSources, append the RAG tool to the agent's tools
|
||||
if agent_config.get("hasRagSources", False):
|
||||
rag_tool_name = get_tool_config_by_type(tool_configs, "rag").get("name", "")
|
||||
agent_config["tools"].append(rag_tool_name)
|
||||
agent_config = add_rag_instructions_to_agent(agent_config, rag_tool_name)
|
||||
|
||||
# Prepare tool lists for this agent
|
||||
external_tools = []
|
||||
|
||||
logger.debug(f"Agent {agent_config['name']} has {len(agent_config['tools'])} configured tools")
|
||||
print(f"Agent {agent_config['name']} has {len(agent_config['tools'])} configured tools")
|
||||
|
||||
new_tools = []
|
||||
rag_tool = get_rag_tool(agent_config, complete_request)
|
||||
if rag_tool:
|
||||
new_tools.append(rag_tool)
|
||||
logger.debug(f"Added rag tool to agent {agent_config['name']}")
|
||||
print(f"Added rag tool to agent {agent_config['name']}")
|
||||
|
||||
for tool_name in agent_config["tools"]:
|
||||
|
||||
tool_config = get_tool_config_by_name(tool_configs, tool_name)
|
||||
|
||||
if tool_config:
|
||||
external_tools.append({
|
||||
"type": "function",
|
||||
"function": tool_config
|
||||
})
|
||||
#TODO: Remove this once we have a way to handle the additionalProperties
|
||||
tool_config['parameters']['additionalProperties'] = False
|
||||
tool = FunctionTool(
|
||||
name=tool_name,
|
||||
description=tool_config["description"],
|
||||
params_json_schema=tool_config["parameters"],
|
||||
on_invoke_tool=lambda ctx, args, _tool_name=tool_name, _tool_config=tool_config, _complete_request=complete_request:
|
||||
catch_all(ctx, args, _tool_name, _tool_config, _complete_request)
|
||||
)
|
||||
new_tools.append(tool)
|
||||
logger.debug(f"Added tool {tool_name} to agent {agent_config['name']}")
|
||||
print(f"Added tool {tool_name} to agent {agent_config['name']}")
|
||||
else:
|
||||
logger.warning(f"Tool {tool_name} not found in tool_configs")
|
||||
print(f"WARNING: Tool {tool_name} not found in tool_configs")
|
||||
|
||||
# Create the agent object
|
||||
logger.debug(f"Creating Agent object for {agent_config['name']}")
|
||||
print(f"Creating Agent object for {agent_config['name']}")
|
||||
try:
|
||||
new_agent = NewAgent(
|
||||
name=agent_config["name"],
|
||||
instructions=agent_config["instructions"],
|
||||
handoff_description=agent_config["description"],
|
||||
tools=new_tools,
|
||||
model=agent_config["model"]
|
||||
)
|
||||
new_agent_to_children[agent_config["name"]] = agent_config.get("connectedAgents", [])
|
||||
new_agent_name_to_index[agent_config["name"]] = len(new_agents)
|
||||
new_agents.append(new_agent)
|
||||
logger.debug(f"Successfully created agent: {agent_config['name']}")
|
||||
print(f"Successfully created agent: {agent_config['name']}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create agent {agent_config['name']}: {str(e)}")
|
||||
print(f"ERROR: Failed to create agent {agent_config['name']}: {str(e)}")
|
||||
raise
|
||||
|
||||
for new_agent in new_agents:
|
||||
# Initialize the handoffs attribute if it doesn't exist
|
||||
if not hasattr(new_agent, 'handoffs'):
|
||||
new_agent.handoffs = []
|
||||
# Look up the agent's children from the old agent and create a list called handoffs in new_agent with pointers to the children in new_agents
|
||||
new_agent.handoffs = [new_agents[new_agent_name_to_index[child]] for child in new_agent_to_children[new_agent.name]]
|
||||
|
||||
print("Returning created agents")
|
||||
print("="*100)
|
||||
return new_agents
|
||||
|
||||
|
||||
def create_response(messages=None, tokens_used=None, agent=None, error_msg=''):
|
||||
"""
|
||||
Create a Response object with the given parameters.
|
||||
|
||||
Args:
|
||||
messages: List of messages
|
||||
tokens_used: Dictionary tracking token usage
|
||||
agent: The agent that generated the response
|
||||
error_msg: Error message if any
|
||||
|
||||
Returns:
|
||||
Response object
|
||||
"""
|
||||
if messages is None:
|
||||
messages = []
|
||||
if tokens_used is None:
|
||||
tokens_used = {}
|
||||
|
||||
return NewResponse(
|
||||
messages=messages,
|
||||
agent=agent,
|
||||
tokens_used=tokens_used,
|
||||
error_msg=error_msg
|
||||
)
|
||||
|
||||
|
||||
def run(
|
||||
agent,
|
||||
messages,
|
||||
external_tools=None,
|
||||
tokens_used=None
|
||||
):
|
||||
"""
|
||||
Wrapper function for initializing and running the Swarm client.
|
||||
"""
|
||||
logger.info(f"Initializing Swarm client for agent: {agent.name}")
|
||||
print(f"Initializing Swarm client for agent: {agent.name}")
|
||||
|
||||
# Initialize default parameters
|
||||
if external_tools is None:
|
||||
external_tools = []
|
||||
if tokens_used is None:
|
||||
tokens_used = {}
|
||||
|
||||
# Format messages to ensure they're compatible with the OpenAI API
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, dict) and "content" in msg:
|
||||
formatted_msg = {
|
||||
"role": msg.get("role", "user"),
|
||||
"content": msg["content"]
|
||||
}
|
||||
formatted_messages.append(formatted_msg)
|
||||
else:
|
||||
formatted_messages.append({
|
||||
"role": "user",
|
||||
"content": str(msg)
|
||||
})
|
||||
|
||||
# Create a new event loop for this thread
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Run the agent with the formatted messages
|
||||
logger.info("Beginning Swarm run with run_sync")
|
||||
print("Beginning Swarm run with run_sync")
|
||||
|
||||
try:
|
||||
response = loop.run_until_complete(Runner.run(agent, formatted_messages))
|
||||
except Exception as e:
|
||||
logger.error(f"Error during run: {str(e)}")
|
||||
print(f"Error during run: {str(e)}")
|
||||
raise
|
||||
|
||||
logger.info(f"Completed Swarm run for agent: {agent.name}")
|
||||
print(f"Completed Swarm run for agent: {agent.name}")
|
||||
return response
|
||||
|
||||
async def run_streamed(
|
||||
agent,
|
||||
messages,
|
||||
external_tools=None,
|
||||
tokens_used=None
|
||||
):
|
||||
"""
|
||||
Wrapper function for initializing and running the Swarm client in streaming mode.
|
||||
"""
|
||||
logger.info(f"Initializing Swarm streaming client for agent: {agent.name}")
|
||||
print(f"Initializing Swarm streaming client for agent: {agent.name}")
|
||||
|
||||
# Initialize default parameters
|
||||
if external_tools is None:
|
||||
external_tools = []
|
||||
if tokens_used is None:
|
||||
tokens_used = {}
|
||||
|
||||
# Format messages to ensure they're compatible with the OpenAI API
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, dict) and "content" in msg:
|
||||
formatted_msg = {
|
||||
"role": msg.get("role", "user"),
|
||||
"content": msg["content"]
|
||||
}
|
||||
formatted_messages.append(formatted_msg)
|
||||
else:
|
||||
formatted_messages.append({
|
||||
"role": "user",
|
||||
"content": str(msg)
|
||||
})
|
||||
|
||||
logger.info("Beginning Swarm streaming run")
|
||||
print("Beginning Swarm streaming run")
|
||||
|
||||
try:
|
||||
# Use the Runner.run_streamed method
|
||||
stream_result = Runner.run_streamed(agent, formatted_messages)
|
||||
return stream_result
|
||||
except Exception as e:
|
||||
logger.error(f"Error during streaming run: {str(e)}")
|
||||
print(f"Error during streaming run: {str(e)}")
|
||||
raise
|
||||
143
apps/rowboat_agents/src/graph/tool_calling.py
Normal file
143
apps/rowboat_agents/src/graph/tool_calling.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
from bson.objectid import ObjectId
|
||||
from openai import OpenAI
|
||||
import os
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Any
|
||||
from qdrant_client import QdrantClient
|
||||
import json
|
||||
# Initialize MongoDB client
|
||||
mongo_uri = os.environ.get("MONGODB_URI", "mongodb://localhost:27017")
|
||||
mongo_client = AsyncIOMotorClient(mongo_uri)
|
||||
db = mongo_client.rowboat
|
||||
data_sources_collection = db['sources']
|
||||
data_source_docs_collection = db['source_docs']
|
||||
|
||||
|
||||
qdrant_client = QdrantClient(url=os.environ.get("QDRANT_URL"))
|
||||
# Initialize OpenAI client
|
||||
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
# Define embedding model
|
||||
embedding_model = "text-embedding-3-small"
|
||||
|
||||
async def embed(model: str, value: str) -> dict:
|
||||
"""
|
||||
Generate embeddings using OpenAI's embedding models.
|
||||
|
||||
Args:
|
||||
model (str): The embedding model to use (e.g., "text-embedding-3-small").
|
||||
value (str): The text to embed.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the embedding.
|
||||
"""
|
||||
response = client.embeddings.create(
|
||||
model=model,
|
||||
input=value
|
||||
)
|
||||
return {"embedding": response.data[0].embedding}
|
||||
|
||||
async def call_rag_tool(
|
||||
project_id: str,
|
||||
query: str,
|
||||
source_ids: list[str],
|
||||
return_type: str,
|
||||
k: int,
|
||||
) -> dict:
|
||||
"""
|
||||
Runs the RAG tool call to retrieve information based on the query and source IDs.
|
||||
|
||||
Args:
|
||||
project_id (str): The ID of the project.
|
||||
query (str): The query string to search for.
|
||||
source_ids (list[str]): List of source IDs to filter the search.
|
||||
return_type (str): The type of return, e.g., 'chunks' or other.
|
||||
k (int): The number of results to return.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the results of the search.
|
||||
"""
|
||||
|
||||
print("\n\n calling rag tool \n\n")
|
||||
print(query)
|
||||
# Create embedding for the query
|
||||
embed_result = await embed(model=embedding_model, value=query)
|
||||
|
||||
print(embed_result)
|
||||
# Fetch all active data sources for this project
|
||||
sources = await data_sources_collection.find({
|
||||
"projectId": project_id,
|
||||
"active": True
|
||||
}).to_list(length=None)
|
||||
|
||||
print(sources)
|
||||
# Filter sources to those in source_ids
|
||||
valid_source_ids = [
|
||||
str(s["_id"]) for s in sources if str(s["_id"]) in source_ids
|
||||
]
|
||||
|
||||
print(valid_source_ids)
|
||||
# If no valid sources are found, return empty results
|
||||
if not valid_source_ids:
|
||||
return ''
|
||||
|
||||
# Perform Qdrant vector search
|
||||
qdrant_results = qdrant_client.search(
|
||||
collection_name="embeddings",
|
||||
query_vector=embed_result["embedding"],
|
||||
query_filter={
|
||||
"must": [
|
||||
{"key": "projectId", "match": {"value": project_id}},
|
||||
{"key": "sourceId", "match": {"any": valid_source_ids}},
|
||||
]
|
||||
},
|
||||
limit=k,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# Map the Qdrant results to the desired format
|
||||
results = [
|
||||
{
|
||||
"title": point.payload["title"],
|
||||
"name": point.payload["name"],
|
||||
"content": point.payload["content"],
|
||||
"docId": point.payload["docId"],
|
||||
"sourceId": point.payload["sourceId"],
|
||||
}
|
||||
for point in qdrant_results
|
||||
]
|
||||
|
||||
print(return_type)
|
||||
print(results)
|
||||
# If return_type is 'chunks', return the results directly
|
||||
if return_type == "chunks":
|
||||
return json.dumps({"Information": results}, indent=2)
|
||||
|
||||
# Otherwise, fetch the full document contents from MongoDB
|
||||
doc_ids = [ObjectId(r["docId"]) for r in results]
|
||||
docs = await data_source_docs_collection.find({"_id": {"$in": doc_ids}}).to_list(length=None)
|
||||
|
||||
# Create a dictionary for quick lookup of documents by their string ID
|
||||
doc_dict = {str(doc["_id"]): doc for doc in docs}
|
||||
|
||||
# Update the results with the full document content
|
||||
results = [
|
||||
{**r, "content": doc_dict.get(r["docId"], {}).get("content", "")}
|
||||
for r in results
|
||||
]
|
||||
|
||||
# Convert results to a JSON string
|
||||
formatted_string = json.dumps({"Information": results}, indent=2)
|
||||
print(formatted_string)
|
||||
return formatted_string
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(call_rag_tool(
|
||||
project_id="faf2bfb3-41d4-4299-b0d2-048581ea9bd8",
|
||||
query="What is the range on your scooter",
|
||||
source_ids=["67e102c9fab4514d7aaeb5a4"],
|
||||
return_type="docs",
|
||||
k=3))
|
||||
81
apps/rowboat_agents/src/graph/tools.py
Normal file
81
apps/rowboat_agents/src/graph/tools.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
import json
|
||||
import random
|
||||
|
||||
from src.utils.common import common_logger
|
||||
logger = common_logger
|
||||
|
||||
RAG_TOOL = {
|
||||
"name": "getArticleInfo",
|
||||
"type": "rag",
|
||||
"description": "Fetch articles with knowledge relevant to the query",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The query to retrieve articles for"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"query"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
CLOSE_CHAT_TOOL = {
|
||||
"name": "close_chat",
|
||||
"type": "close_chat",
|
||||
"description": "Close the chat",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"error_message": {
|
||||
"type": "string", "description": "The error message to close the chat with"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def tool_raise_error(error_message):
|
||||
logger.error(f"Raising error: {error_message}")
|
||||
raise ValueError(f"Raising error: {error_message}")
|
||||
|
||||
def respond_to_tool_raise_error(tool_calls, mock=False):
|
||||
error_message = json.loads(tool_calls[0]["function"]["arguments"]).get("error_message", "")
|
||||
return _create_tool_response(tool_calls, tool_raise_error(error_message))
|
||||
|
||||
def tool_close_chat(error_message):
|
||||
logger.error(f"Closing chat: {error_message}")
|
||||
raise ValueError(f"Closing chat: {error_message}")
|
||||
|
||||
def respond_to_tool_close_chat(tool_calls, mock=False):
|
||||
error_message = json.loads(tool_calls[0]["function"]["arguments"]).get("error_message", "")
|
||||
return _create_tool_response(tool_calls, tool_close_chat(error_message))
|
||||
|
||||
def _create_tool_response(tool_calls, content, mock=False):
|
||||
"""
|
||||
Creates a standardized tool response format.
|
||||
"""
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": content,
|
||||
"tool_call_id": tool_calls[0]["id"],
|
||||
"name": tool_calls[0]["function"]["name"]
|
||||
}
|
||||
|
||||
def create_error_tool_call(error_message):
|
||||
error_message_tool_call = {
|
||||
"role": "assistant",
|
||||
"sender": "system",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "raise_error",
|
||||
"arguments": "{\"error_message\":\"" + error_message + "\"}"
|
||||
},
|
||||
"id": "call_" + ''.join(random.choices('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', k=24)),
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
}
|
||||
return error_message_tool_call
|
||||
18
apps/rowboat_agents/src/graph/types.py
Normal file
18
apps/rowboat_agents/src/graph/types.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from enum import Enum
|
||||
class AgentRole(Enum):
|
||||
ESCALATION = "escalation"
|
||||
POST_PROCESSING = "post_process"
|
||||
GUARDRAILS = "guardrails"
|
||||
|
||||
class ControlType(Enum):
|
||||
RETAIN = "retain"
|
||||
PARENT_AGENT = "relinquish_to_parent"
|
||||
START_AGENT = "relinquish_to_start"
|
||||
|
||||
class PromptType(Enum):
|
||||
STYLE = "style_prompt"
|
||||
GREETING = "greeting"
|
||||
|
||||
class ErrorType(Enum):
|
||||
FATAL = "fatal"
|
||||
ESCALATE = "escalate"
|
||||
0
apps/rowboat_agents/src/utils/__init__.py
Normal file
0
apps/rowboat_agents/src/utils/__init__.py
Normal file
200
apps/rowboat_agents/src/utils/common.py
Normal file
200
apps/rowboat_agents/src/utils/common.py
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
|
||||
load_dotenv()
|
||||
|
||||
def setup_logger(name, log_file='./run.log', level=logging.INFO, log_to_file=True):
|
||||
"""Function to set up a logger with a specific name and log file."""
|
||||
formatter = logging.Formatter('%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s')
|
||||
|
||||
if log_to_file:
|
||||
handler = logging.FileHandler(log_file)
|
||||
else:
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
# Create a logger and set its level
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(level)
|
||||
|
||||
# Clear any existing handlers to avoid duplicates
|
||||
if logger.hasHandlers():
|
||||
logger.handlers.clear()
|
||||
|
||||
# Prevent propagation to parent loggers
|
||||
logger.propagate = False
|
||||
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
common_logger = setup_logger('logger')
|
||||
logger = common_logger
|
||||
|
||||
def read_json_from_file(file_name):
|
||||
logger.info(f"Reading json from {file_name}")
|
||||
try:
|
||||
with open(file_name, 'r') as file:
|
||||
out = file.read()
|
||||
out = json.loads(out)
|
||||
return out
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return None
|
||||
|
||||
def get_api_key(key_name):
|
||||
api_key = os.getenv(key_name)
|
||||
# Check if the API key was loaded successfully
|
||||
if not api_key:
|
||||
raise ValueError(f"{key_name} not found. Did you set it in the .env file?")
|
||||
return api_key
|
||||
|
||||
openai_client = OpenAI(
|
||||
api_key=get_api_key("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
def generate_gpt4o_output_from_multi_turn_conv(messages, output_type='json', model="gpt-4o"):
|
||||
return generate_openai_output(messages, output_type, model)
|
||||
|
||||
def generate_openai_output(messages, output_type='not_json', model="gpt-4o", return_completion=False):
|
||||
try:
|
||||
if output_type == 'json':
|
||||
chat_completion = openai_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
else:
|
||||
chat_completion = openai_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model,
|
||||
)
|
||||
|
||||
if return_completion:
|
||||
return chat_completion
|
||||
return chat_completion.choices[0].message.content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return None
|
||||
|
||||
def generate_llm_output(messages, model):
|
||||
model_provider = None
|
||||
if "gpt" in model:
|
||||
model_provider = "openai"
|
||||
else:
|
||||
raise ValueError(f"Model {model} not supported")
|
||||
|
||||
if model_provider == "openai":
|
||||
response = generate_openai_output(messages, output_type='text', model=model)
|
||||
return response
|
||||
|
||||
def generate_gpt4o_output_from_multi_turn_conv_multithreaded(messages, retries=5, delay=1, output_type='json'):
|
||||
while retries > 0:
|
||||
try:
|
||||
# Call GPT-4o API
|
||||
output = generate_gpt4o_output_from_multi_turn_conv(messages, output_type='json')
|
||||
return output # If the request is successful, break out of the loop
|
||||
except openai.RateLimitError:
|
||||
print(f'Rate limit exceeded. Retrying in {delay} seconds...')
|
||||
time.sleep(delay)
|
||||
delay *= 2 # Exponential backoff
|
||||
retries -= 1
|
||||
|
||||
if retries == 0:
|
||||
print(f'Failed to process due to rate limit.')
|
||||
return []
|
||||
|
||||
def convert_message_content_json_to_strings(messages):
|
||||
for msg in messages:
|
||||
if 'content' in msg.keys() and isinstance(msg['content'], dict):
|
||||
msg['content'] = json.dumps(msg['content'])
|
||||
return messages
|
||||
|
||||
def merge_defaultdicts(dict_parent, dict_child):
|
||||
for key, value in dict_child.items():
|
||||
if key in dict_parent:
|
||||
# If the key exists in both, handle merging based on type
|
||||
if isinstance(dict_parent[key], list):
|
||||
dict_parent[key].extend(value)
|
||||
elif isinstance(dict_parent[key], dict):
|
||||
dict_parent[key].update(value)
|
||||
elif isinstance(dict_parent[key], set):
|
||||
dict_parent[key].update(value)
|
||||
else:
|
||||
dict_parent[key] += value # For other types like int, float, etc.
|
||||
else:
|
||||
dict_parent[key] = value
|
||||
|
||||
return dict_parent
|
||||
|
||||
def read_jsonl_from_file(file_name):
|
||||
# logger.info(f"Reading jsonl from {file_name}")
|
||||
try:
|
||||
with open(file_name, 'r') as file:
|
||||
lines = file.readlines()
|
||||
dataset = [json.loads(line.strip()) for line in lines]
|
||||
return dataset
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return None
|
||||
|
||||
def write_jsonl_to_file(list_dicts, file_name):
|
||||
try:
|
||||
with open(file_name, 'w') as file:
|
||||
for d in list_dicts:
|
||||
file.write(json.dumps(d)+'\n')
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return False
|
||||
|
||||
def read_text_from_file(file_name):
|
||||
try:
|
||||
with open(file_name, 'r') as file:
|
||||
out = file.read()
|
||||
return out
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return None
|
||||
|
||||
def write_json_to_file(data, file_name):
|
||||
try:
|
||||
with open(file_name, 'w') as file:
|
||||
json.dump(data, file, indent=4)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return False
|
||||
|
||||
|
||||
def get_git_path(path):
|
||||
# Run `git rev-parse --show-toplevel` to get the root of the Git repository
|
||||
try:
|
||||
git_root = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], text=True).strip()
|
||||
return f"{git_root}/{path}"
|
||||
except subprocess.CalledProcessError:
|
||||
raise RuntimeError("Not inside a Git repository")
|
||||
|
||||
def update_tokens_used(provider, model, tokens_used, completion):
|
||||
provider_model = f"{provider}/{model}"
|
||||
input_tokens = completion.usage.prompt_tokens
|
||||
output_tokens = completion.usage.completion_tokens
|
||||
|
||||
if provider_model not in tokens_used:
|
||||
tokens_used[provider_model] = {
|
||||
'input_tokens': 0,
|
||||
'output_tokens': 0,
|
||||
}
|
||||
|
||||
tokens_used[provider_model]['input_tokens'] += input_tokens
|
||||
tokens_used[provider_model]['output_tokens'] += output_tokens
|
||||
|
||||
return tokens_used
|
||||
Loading…
Add table
Add a link
Reference in a new issue