# rowboat/apps/agents/src/graph/core.py

import os
import sys
from copy import deepcopy
from src.swarm.types import Agent
from src.swarm.core import Swarm
from .guardrails import post_process_response
from .tools import create_error_tool_call
from .types import AgentRole, PromptType, ErrorType
from .helpers.access import get_agent_data_by_name, get_agent_by_name, get_agent_config_by_name, get_tool_config_by_name, get_tool_config_by_type, get_external_tools, get_prompt_by_type, pop_agent_config_by_type, get_agent_by_type
from .helpers.transfer import create_transfer_function_to_agent, create_transfer_function_to_parent_agent
from .helpers.state import add_recent_messages_to_history, construct_state_from_response, reset_current_turn, reset_current_turn_agent_history
from .helpers.instructions import add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents, add_rag_instructions_to_agent, add_error_escalation_instructions, get_universal_system_message, add_universal_system_message_to_agent
from .helpers.control import get_latest_assistant_msg, get_latest_non_assistant_messages, get_last_agent_name
# 2025-03-05 18:41:42 +05:30
from src.swarm.types import Response
from datetime import datetime
from src.utils.common import common_logger
# Module-level logger: an alias for the shared ``common_logger`` imported from src.utils.common.
logger = common_logger
def order_messages(messages):
    """Return reordered copies of *messages* with ``None``-valued keys removed.

    In each output dict, the keys ``role``, ``sender``, ``content``,
    ``created_at`` and ``timestamp`` come first, in that order, when present.
    All remaining keys follow in alphabetical order. The input dictionaries
    are left untouched.
    """
    priority = ('role', 'sender', 'content', 'created_at', 'timestamp')

    def _reorder(message):
        present = {key: value for key, value in message.items() if value is not None}
        head = [key for key in priority if key in present]
        # Sort every key first, then filter, so ordering matches a full sort.
        tail = [key for key in sorted(present) if key not in priority]
        return {key: present[key] for key in head + tail}

    return [_reorder(message) for message in messages]
def clean_up_history(agent_data):
    """Normalise each entry's ``history`` with ``order_messages``.

    Every element of *agent_data* must have a ``"history"`` key. Its value is
    replaced, in place, with the reordered copy returned by
    ``order_messages``. The same *agent_data* object is returned.
    """
    for entry in agent_data:
        ordered_history = order_messages(entry["history"])
        entry["history"] = ordered_history
    return agent_data
def clear_agent_fields(agent):
    """Reset *agent*'s graph-linkage fields in place and return it.

    Clears ``children``, ``child_functions`` and
    ``candidate_parent_functions``, and sets ``parent_function`` to None.
    ``history`` is emptied only when ``most_recent_parent`` is truthy.
    """
    agent.children = {}
    agent.child_functions = {}
    agent.candidate_parent_functions = {}
    agent.parent_function = None
    # History is only wiped for agents that currently have a parent link.
    if agent.most_recent_parent:
        agent.history = []
    return agent
def get_agents(agent_configs, tool_configs, localize_history, available_tool_mappings, agent_data, start_turn_with_start_agent, children_aware_of_parent, universal_sys_msg):
    """Create the linked ``Agent`` objects for one turn.

    Builds one ``Agent`` per entry in *agent_configs*, then wires the graph:
    restores each agent's ``most_recent_parent`` from *agent_data*, attaches
    children (from ``connectedAgents``), creates parent->child and
    child->parent transfer functions, applies the transfer / universal
    system-message instruction helpers, and selects each agent's active
    ``parent_function``.

    Args:
        agent_configs: List of agent config dicts. ``name``, ``instructions``,
            ``model`` and ``tools`` are indexed directly. ``type``,
            ``description``, ``respond_to_user``, ``connectedAgents`` and
            ``hasRagSources`` are optional. The caller's dicts are not modified.
        tool_configs: List of tool config dicts, looked up by name and type.
        localize_history: When true, seed each agent's history from its entry
            in *agent_data*.
        available_tool_mappings: Maps tool names to internal tool objects. A
            configured tool not in this mapping is added as an external
            ``{"type": "function", "function": <config>}`` tool.
        agent_data: Per-agent persisted data (``history``,
            ``most_recent_parent_name``).
        start_turn_with_start_agent: Accepted for interface compatibility; not
            used by this function.
        children_aware_of_parent: Passed through to the transfer helpers.
        universal_sys_msg: Passed to ``add_universal_system_message_to_agent``
            for every agent.

    Returns:
        The list of created ``Agent`` objects, in config order.

    Raises:
        ValueError: If *agent_configs* or *tool_configs* is not a list.
        AssertionError: If an agent's recorded most recent parent is not among
            its candidate parents.
    """
    if not isinstance(agent_configs, list):
        raise ValueError("Agents config is not a list in get_agents")
    if not isinstance(tool_configs, list):
        raise ValueError("Tools config is not a list in get_agents")

    # Create Agent objects
    agents = []
    for agent_config in agent_configs:
        logger.debug(f"Processing config for agent: {agent_config['name']}")
        external_tools = []
        internal_tools = []
        candidate_parent_functions = {}
        child_functions = {}
        logger.debug(f"Finding tools for agent {agent_config['name']}")
        logger.debug(f"Agent {agent_config['name']} has {len(agent_config['tools'])} configured tools")

        if agent_config.get("hasRagSources", False):
            rag_tool_config = get_tool_config_by_type(tool_configs, "rag")
            rag_tool_name = rag_tool_config.get("name", "") if rag_tool_config else ""
            if rag_tool_name:
                # Work on a copy. Appending to the caller's config made repeated
                # calls with the same configs accumulate duplicate RAG tools.
                tools = list(agent_config["tools"])
                if rag_tool_name not in tools:
                    tools.append(rag_tool_name)
                agent_config = {**agent_config, "tools": tools}
                agent_config = add_rag_instructions_to_agent(agent_config, rag_tool_name)
            else:
                logger.warning(f"Agent {agent_config['name']} has RAG sources but no named 'rag' tool config was found; skipping RAG tool")

        for tool_name in agent_config["tools"]:
            logger.debug(f"Looking for tool config: {tool_name}")
            tool_config = get_tool_config_by_name(tool_configs, tool_name)
            if not tool_config:
                logger.warning(f"Tool {tool_name} not found in tool_configs")
                continue
            if tool_name in available_tool_mappings:
                internal_tools.append(available_tool_mappings[tool_name])
            else:
                external_tools.append({
                    "type": "function",
                    "function": tool_config
                })
            logger.debug(f"Added tool {tool_name} to agent {agent_config['name']}")

        history = []
        this_agent_data = get_agent_data_by_name(agent_config["name"], agent_data)
        if this_agent_data and localize_history:
            history = this_agent_data.get("history", [])

        logger.debug(f"Creating Agent object for {agent_config['name']}")
        logger.debug(f"Using model: {agent_config['model']}")
        logger.debug(f"Number of tools being added: Internal - {len(internal_tools)} | External - {len(external_tools)}")
        try:
            agent = Agent(
                name=agent_config["name"],
                type=agent_config.get("type", "default"),
                instructions=agent_config["instructions"],
                description=agent_config.get("description", ""),
                internal_tools=internal_tools,
                external_tools=external_tools,
                candidate_parent_functions=candidate_parent_functions,
                child_functions=child_functions,
                model=agent_config["model"],
                respond_to_user=agent_config.get("respond_to_user", False),
                history=history,
                children_names=agent_config.get("connectedAgents", []),
                most_recent_parent=None
            )
            agents.append(agent)
            logger.debug(f"Successfully created agent: {agent_config['name']}")
        except Exception as e:
            logger.error(f"Failed to create agent {agent_config['name']}: {str(e)}")
            raise

    # Adding most recent parents to agents
    for agent in agents:
        this_agent_data = get_agent_data_by_name(agent.name, agent_data)
        most_recent_parent_name = this_agent_data.get("most_recent_parent_name", "") if this_agent_data else ""
        if most_recent_parent_name:
            most_recent_parent = get_agent_by_name(most_recent_parent_name, agents)
            if most_recent_parent:
                agent.most_recent_parent = most_recent_parent

    # Adding children agents to parent agents
    logger.info("Adding children agents to parent agents")
    for agent in agents:
        agent.children = {other.name: other for other in agents if other.name in agent.children_names}

    # Generate transfer functions for transferring to children agents
    logger.info("Generating transfer functions for transferring to children agents")
    transfer_functions = {
        agent.name: create_transfer_function_to_agent(agent)
        for agent in agents
    }

    # Add transfer functions for parents to transfer to children
    logger.info("Adding transfer functions for parents to transfer to children")
    for agent in agents:
        for child in agent.children.values():
            agent.child_functions[child.name] = transfer_functions[child.name]

    # The instruction helpers below are relied on to update the agent in
    # place; their return values were never used (only rebound to the loop
    # variable) in the original code either.
    logger.info("Adding child transfer-related instructions to parent agents")
    for agent in agents:
        if agent.children:
            add_transfer_instructions_to_parent_agents(agent, agent.children, transfer_functions)

    # Generate and append duplicate transfer functions for children to transfer to parent agents
    logger.info("Generating duplicate transfer functions for children to transfer to parent agents")
    for agent in agents:
        for child in agent.children.values():
            func = create_transfer_function_to_parent_agent(
                parent_agent=agent,
                children_aware_of_parent=children_aware_of_parent,
                transfer_functions=transfer_functions
            )
            child.candidate_parent_functions[agent.name] = func

    for agent in agents:
        if agent.candidate_parent_functions and agent.type != "escalation":
            add_transfer_instructions_to_child_agents(
                child=agent,
                children_aware_of_parent=children_aware_of_parent
            )

    # Activate the transfer-back function for the parent recorded in agent_data.
    for agent in agents:
        if agent.most_recent_parent:
            assert agent.most_recent_parent.name in agent.candidate_parent_functions, f"Most recent parent {agent.most_recent_parent.name} not found in candidate parent functions for agent {agent.name}"
            agent.parent_function = agent.candidate_parent_functions[agent.most_recent_parent.name]

    for agent in agents:
        add_universal_system_message_to_agent(agent, universal_sys_msg)
    return agents
def check_request_validity(messages, agent_configs, tool_configs, prompt_configs, max_overall_turns):
    """Validate a turn request before any agent is built or run.

    Fatal (configuration) problems are checked first, and the first one found
    is returned immediately. This ordering matters because the later checks
    iterate the inputs and cannot run safely on malformed data. Non-fatal
    problems are reported with ``ErrorType.ESCALATE``; if several apply, the
    last one checked wins.

    Args:
        messages: List of message dicts for the conversation.
        agent_configs: List of agent config dicts.
        tool_configs: List of tool config dicts.
        prompt_configs: List of prompt configs (type-checked only).
        max_overall_turns: Limit on the number of ``external`` messages.

    Returns:
        ``(error_msg, error_type)``. ``error_msg`` is ``""`` when no problem
        was found. ``error_type`` is ``ErrorType.FATAL.value`` or
        ``ErrorType.ESCALATE.value``.
    """
    fatal = ErrorType.FATAL.value

    # Type checks --> Fatal. Done first so the checks below can iterate safely.
    for arg_name, arg in (
        ("messages", messages),
        ("agent_configs", agent_configs),
        ("tool_configs", tool_configs),
        ("prompt_configs", prompt_configs),
    ):
        if not isinstance(arg, list):
            return f"{arg_name} is not a list", fatal

    # Empty checks --> Fatal
    if not agent_configs:
        return "Agent configs list is empty", fatal

    # At most one post-processing, guardrails and escalation agent --> Fatal
    for role in (AgentRole.POST_PROCESSING.value, AgentRole.GUARDRAILS.value, AgentRole.ESCALATION.value):
        if sum(1 for ac in agent_configs if ac.get("type", "") == role) > 1:
            return "Invalid post processing, guardrails or escalation agent count - expected at most 1", fatal

    # Every agent config needs: name, instructions, model --> Fatal
    for agent_config in agent_configs:
        missing_keys = [key for key in ["name", "instructions", "model"] if key not in agent_config]
        if missing_keys:
            return f"Invalid agent config - missing keys: {missing_keys}", fatal

    # Every tool config needs: name, parameters --> Fatal
    for tool_config in tool_configs:
        missing_keys = [key for key in ["name", "parameters"] if key not in tool_config]
        if missing_keys:
            return f"Invalid tool config - missing keys: {missing_keys}", fatal

    error_msg = ""

    # Limits checks
    external_messages_count = sum(1 for msg in messages if msg.get("response_type") == "external")
    if external_messages_count >= max_overall_turns:
        error_msg = f"Max overall turns reached: {max_overall_turns}"

    # Empty checks
    if not messages:
        error_msg = "Messages list is empty"

    def find_cycles(agent_name, agent_configs, visited=None, path=None):
        """Depth-first search from *agent_name*; return the first cycle found as a list of names, else None."""
        if visited is None:
            visited = set()
        if path is None:
            path = []
        visited.add(agent_name)
        agent_config = get_agent_config_by_name(agent_name, agent_configs)
        if not agent_config:
            # Unknown agent: do not push it onto the path. Previously it was
            # pushed and never popped, which produced false cycle reports.
            return None
        path.append(agent_name)
        for child_name in agent_config.get("connectedAgents", []):
            if child_name in path:
                return path[path.index(child_name):] + [child_name]
            if child_name not in visited:
                cycle = find_cycles(child_name, agent_configs, visited, path)
                if cycle:
                    return cycle
        path.pop()
        return None

    # Cycle checks. NOTE(review): as in the original, cycles are reported
    # with ESCALATE rather than FATAL; confirm this is intended.
    for agent_config in agent_configs:
        agent_name = agent_config.get("name")
        if agent_name in agent_config.get("connectedAgents", []):
            error_msg = f"Cycle detected in agent config graph - agent {agent_name} is connected to itself"
        cycle = find_cycles(agent_name, agent_configs)
        if cycle:
            cycle_str = " -> ".join(cycle)
            error_msg = f"Cycle detected in agent config graph: {cycle_str}"

    return error_msg, ErrorType.ESCALATE.value
def handle_error(error_tool_call, error_msg, return_diff_messages=True, messages=None, turn_messages=None, state=None, tokens_used=None):
    """Report *error_msg*, either as an appended error tool call or as an exception.

    Args:
        error_tool_call: When truthy, return the error as a message; when
            falsy, raise instead.
        error_msg: Human-readable error description.
        return_diff_messages: When true, return only this turn's messages;
            otherwise return *messages* followed by *turn_messages*.
        messages: Conversation history (defaults to an empty list).
        turn_messages: Messages produced so far this turn (defaults to an
            empty list). The list is not modified.
        state: State returned unchanged (defaults to ``{}``).
        tokens_used: Token accounting returned unchanged (defaults to ``{}``).

    Returns:
        ``(resp_messages, tokens_used, state)``, where ``resp_messages`` ends
        with ``create_error_tool_call(error_msg)``.

    Raises:
        ValueError: With *error_msg* when *error_tool_call* is falsy.
    """
    # Defaults let callers that only know the error report it without a
    # TypeError from missing arguments.
    if messages is None:
        messages = []
    if turn_messages is None:
        turn_messages = []
    if state is None:
        state = {}
    if tokens_used is None:
        tokens_used = {}

    if not error_tool_call:
        raise ValueError(error_msg)

    base_messages = turn_messages if return_diff_messages else messages + turn_messages
    # Build a new list rather than extending the caller's turn_messages.
    resp_messages = [*base_messages, create_error_tool_call(error_msg)]
    return resp_messages, tokens_used, state
def create_final_response(response, turn_messages, messages, tokens_used, all_agents, return_diff_messages):
    """Finalise *response* and derive the next conversation state from it.

    Sets ``response.messages`` and ``response.tokens_used``, then builds the
    new state with ``construct_state_from_response``. ``response.messages``
    is this turn's messages only when *return_diff_messages* is true, and
    otherwise the full history followed by them.

    Returns:
        ``(messages, tokens_used, new_state)``.
    """
    if return_diff_messages:
        response.messages = turn_messages
    else:
        response.messages = messages + turn_messages
    response.tokens_used = tokens_used
    next_state = construct_state_from_response(response, all_agents)
    return response.messages, response.tokens_used, next_state
def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_tool_mappings=None, localize_history=True, return_diff_messages=True, prompt_configs=None, start_turn_with_start_agent=False, children_aware_of_parent=False, parent_has_child_history=True, state=None, additional_tool_configs=None, error_tool_call=True, max_messages_per_turn=10, max_messages_per_error_escalation_turn=4, escalate_errors=True, max_overall_turns=10):
    """Run one turn of the agent graph and return the resulting messages and state.

    A turn whose messages are all ``system`` messages (or empty) is a greeting
    turn, answered with the configured greeting prompt. Otherwise:

    1. The last active agent is resumed via the Swarm client.
    2. Errors (from that run, or non-fatal validation errors) are handed to
       the escalation agent when *escalate_errors* is set.
    3. The final response is post-processed, or duplicated as an
       ``external`` message when no post-processing agent is configured.

    Optional container arguments default to ``None`` and are replaced with
    fresh empty containers per call. The former ``{}``/``[]`` defaults were
    shared between calls, and ``state`` is written to below, so agent data
    leaked from one call into the next.

    Returns:
        ``(messages, tokens_used, state)``. On handled errors the messages end
        with an error tool call (see ``handle_error``).

    Raises:
        ValueError: On errors when *error_tool_call* is false, or when
            ``state`` has no ``last_agent_name`` at the end of a non-greeting
            turn.
    """
    if available_tool_mappings is None:
        available_tool_mappings = {}
    if prompt_configs is None:
        prompt_configs = []
    if state is None:
        state = {}
    if additional_tool_configs is None:
        additional_tool_configs = []

    greeting_turn = not any(msg.get("role") != "system" for msg in messages)
    logger.info("Running stateless turn")
    turn_messages = []
    tokens_used = {}
    messages = order_messages(messages)
    tool_configs = tool_configs + additional_tool_configs

    def fail(error_msg):
        # Single error exit, so every path passes handle_error its full
        # argument set. The closure reads messages / turn_messages /
        # tokens_used at call time, so it reports the turn's current progress.
        return handle_error(
            error_tool_call=error_tool_call,
            error_msg=error_msg,
            return_diff_messages=return_diff_messages,
            messages=messages,
            turn_messages=turn_messages,
            state=state,
            tokens_used=tokens_used
        )

    validation_error_msg, validation_error_type = check_request_validity(
        messages=messages,
        agent_configs=agent_configs,
        tool_configs=tool_configs,
        prompt_configs=prompt_configs,
        max_overall_turns=max_overall_turns
    )
    if validation_error_msg and validation_error_type == ErrorType.FATAL.value:
        logger.error(validation_error_msg)
        return fail(validation_error_msg)

    post_processing_agent_config, agent_configs = pop_agent_config_by_type(agent_configs, AgentRole.POST_PROCESSING.value)
    guardrails_agent_config, agent_configs = pop_agent_config_by_type(agent_configs, AgentRole.GUARDRAILS.value)

    agent_data = state.get("agent_data", [])
    universal_sys_msg = ""
    if not greeting_turn:
        latest_assistant_msg = get_latest_assistant_msg(messages)
        universal_sys_msg = get_universal_system_message(messages)
        latest_non_assistant_msgs = get_latest_non_assistant_messages(messages)
        msg_type = latest_non_assistant_msgs[-1]["role"]
        last_agent_name = get_last_agent_name(
            state=state,
            agent_configs=agent_configs,
            start_agent_name=start_agent_name,
            msg_type=msg_type,
            latest_assistant_msg=latest_assistant_msg,
            start_turn_with_start_agent=start_turn_with_start_agent
        )
        logger.info("Localizing message history")
        if msg_type == "user":
            messages = reset_current_turn(messages)
            agent_data = reset_current_turn_agent_history(agent_data, [last_agent_name])
        agent_data = clean_up_history(agent_data)
        agent_data = add_recent_messages_to_history(
            recent_messages=latest_non_assistant_msgs,
            last_agent_name=last_agent_name,
            agent_data=agent_data,
            messages=messages,
            parent_has_child_history=parent_has_child_history
        )
        state["agent_data"] = agent_data

    logger.info("Initializing agents")
    all_agents = get_agents(
        agent_configs=agent_configs,
        tool_configs=tool_configs,
        available_tool_mappings=available_tool_mappings,
        agent_data=state.get("agent_data", []),
        localize_history=localize_history,
        start_turn_with_start_agent=start_turn_with_start_agent,
        children_aware_of_parent=children_aware_of_parent,
        universal_sys_msg=universal_sys_msg
    )
    if not all_agents:
        logger.error("No agents initialized")
        return fail("No agents initialized")

    if greeting_turn:
        greeting_msg = get_prompt_by_type(prompt_configs, PromptType.GREETING.value)
        if not greeting_msg:
            logger.error("Greeting prompt not found and messages is empty")
            return fail("Greeting prompt not found and messages is empty")
        greeting_msg_internal = {
            "content": greeting_msg,
            "role": "assistant",
            "sender": start_agent_name,
            "response_type": "internal",
            "created_at": datetime.now().isoformat(),
            "current_turn": True
        }
        greeting_msg_external = deepcopy(greeting_msg_internal)
        greeting_msg_external["response_type"] = "external"
        greeting_msg_external["sender"] = greeting_msg_external["sender"] + ' >> External'
        turn_messages.extend([greeting_msg_internal, greeting_msg_external])
        response = Response(
            messages=turn_messages,
            tokens_used={},
            agent=get_agent_by_name(start_agent_name, all_agents),
            error_msg=''
        )
        return create_final_response(
            response=response,
            turn_messages=turn_messages,
            messages=messages,
            tokens_used=tokens_used,
            all_agents=all_agents,
            return_diff_messages=return_diff_messages
        )

    error_escalation_agent = deepcopy(get_agent_by_type(all_agents, AgentRole.ESCALATION.value))
    if not error_escalation_agent:
        logger.error("Escalation agent not found")
        return fail("Escalation agent not found")
    error_escalation_agent = clear_agent_fields(error_escalation_agent)
    error_escalation_agent = add_error_escalation_instructions(error_escalation_agent)
    logger.info(f"Initialized {len(all_agents)} agents")

    logger.debug("Getting last agent")
    last_agent = get_agent_by_name(last_agent_name, all_agents)
    if not last_agent:
        logger.error("Last agent not found")
        return fail("Last agent not found")

    external_tools = get_external_tools(tool_configs)
    logger.info(f"Found {len(external_tools)} external tools")
    logger.debug("Initializing Swarm client")
    swarm_client = Swarm()

    response = None
    if validation_error_msg:
        # Only non-fatal (escalatable) validation errors reach this point.
        # Skip the regular agent run and go straight to error handling;
        # previously this path crashed on the undefined `response`.
        turn_error_msg = validation_error_msg
    else:
        response = swarm_client.run(
            agent=last_agent,
            messages=messages,
            execute_tools=True,
            external_tools=external_tools,
            localize_history=localize_history,
            parent_has_child_history=parent_has_child_history,
            max_messages_per_turn=max_messages_per_turn,
            tokens_used=tokens_used
        )
        tokens_used = response.tokens_used
        last_agent = response.agent
        response.messages = order_messages(response.messages)
        turn_messages.extend(response.messages)
        logger.info(f"Completed run of agent: {last_agent.name}")
        turn_error_msg = response.error_msg

    if turn_error_msg:
        logger.info(f"Error raised in turn: {turn_error_msg}")
        response_sender_agent_name = response.agent.name if response is not None else last_agent.name
        # Do not escalate an error raised by the escalation agent itself.
        if escalate_errors and response_sender_agent_name != error_escalation_agent.name:
            response = swarm_client.run(
                agent=error_escalation_agent,
                messages=[],
                execute_tools=True,
                external_tools=external_tools,
                localize_history=False,
                parent_has_child_history=False,
                max_messages_per_turn=max_messages_per_error_escalation_turn,
                tokens_used=tokens_used
            )
            tokens_used = response.tokens_used
            last_agent = response.agent
            response.messages = order_messages(response.messages)
            turn_messages.extend(response.messages)
            logger.info(f"Completed run of escalation agent: {error_escalation_agent.name}")
            if response.error_msg:
                logger.info(f"Error raised in escalation turn: {response.error_msg}")
                return fail(response.error_msg)
        else:
            logger.info(f"Error raised in turn: {turn_error_msg}")
            return fail(turn_error_msg)

    if post_processing_agent_config:
        response = post_process_response(
            messages=turn_messages,
            post_processing_agent_name=post_processing_agent_config.get("name", "Post Processing agent"),
            post_process_instructions=post_processing_agent_config.get("instructions", ""),
            style_prompt=get_prompt_by_type(prompt_configs, PromptType.STYLE.value),
            context='',
            model=post_processing_agent_config.get("model", "gpt-4o"),
            tokens_used=tokens_used,
            last_agent=last_agent
        )
        tokens_used = response.tokens_used
        response.messages = order_messages(response.messages)
        turn_messages.extend(response.messages)
        logger.info("Response post-processed")
    else:
        logger.info("No post-processing agent found. Duplicating last response and setting to external.")
        duplicate_msg = deepcopy(turn_messages[-1])
        duplicate_msg["response_type"] = "external"
        duplicate_msg["sender"] = duplicate_msg["sender"] + ' >> External'
        response = Response(
            messages=[duplicate_msg],
            tokens_used=tokens_used,
            agent=last_agent,
            error_msg=''
        )
        response.messages = order_messages(response.messages)
        turn_messages.extend(response.messages)
        logger.info("Last response duplicated and set to external")

    if guardrails_agent_config:
        logger.info("Guardrails agent not implemented (ignoring)")

    # NOTE(review): this check only runs after all agent work is done; it
    # looks like it could run before the Swarm calls. Confirm before moving.
    if not state or not state.get("last_agent_name"):
        logger.error("State is empty or last agent name is not set")
        raise ValueError("State is empty or last agent name is not set")

    return create_final_response(
        response=response,
        turn_messages=turn_messages,
        messages=messages,
        tokens_used=tokens_used,
        all_agents=all_agents,
        return_diff_messages=return_diff_messages
    )