refactoring

This commit is contained in:
arkml 2025-03-14 13:04:28 +05:30 committed by Ramnique Singh
parent aab6a28006
commit 0e31098d58
3 changed files with 569 additions and 493 deletions

View file

@ -1,336 +1,127 @@
import os
import sys
from copy import deepcopy
from src.swarm.types import Agent
from src.swarm.core import Swarm
from .guardrails import post_process_response
from .tools import create_error_tool_call
from .types import AgentRole, PromptType, ErrorType
from .helpers.access import get_agent_data_by_name, get_agent_by_name, get_agent_config_by_name, get_tool_config_by_name, get_tool_config_by_type, get_external_tools, get_prompt_by_type, pop_agent_config_by_type, get_agent_by_type
from .helpers.transfer import create_transfer_function_to_agent, create_transfer_function_to_parent_agent
from .helpers.state import add_recent_messages_to_history, construct_state_from_response, reset_current_turn, reset_current_turn_agent_history
from .helpers.instructions import add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents, add_rag_instructions_to_agent, add_error_escalation_instructions, get_universal_system_message, add_universal_system_message_to_agent
import logging
from .types import AgentRole
from .helpers.access import (
get_agent_by_name,
get_external_tools, pop_agent_config_by_type
)
from .helpers.state import (
add_recent_messages_to_history, construct_state_from_response, reset_current_turn, reset_current_turn_agent_history
)
from .helpers.instructions import (
get_universal_system_message
)
from .helpers.control import get_latest_assistant_msg, get_latest_non_assistant_messages, get_last_agent_name
from src.swarm.types import Response
from datetime import datetime
from .swarm_wrapper import run as swarm_run, create_response, get_agents
# Create a dedicated logger for swarm wrapper
logger = logging.getLogger("graph")
logger.setLevel(logging.INFO)
from src.utils.common import common_logger
logger = common_logger
def order_messages(messages):
# Arrange keys in specified order
"""
Sorts each message's keys in a specified order and returns a new list of ordered messages.
"""
ordered_messages = []
for msg in messages:
ordered = {}
# Filter out None values
msg = {k: v for k, v in msg.items() if v is not None}
# Add keys in specified order if they exist
# Specify the exact order
ordered = {}
for key in ['role', 'sender', 'content', 'created_at', 'timestamp']:
if key in msg:
ordered[key] = msg[key]
# Add remaining keys in alphabetical order
for key in sorted(msg.keys()):
if key not in ['role', 'sender', 'content', 'created_at', 'timestamp']:
ordered[key] = msg[key]
ordered_messages.append(ordered)
# Add remaining keys in alphabetical order
remaining_keys = sorted(k for k in msg if k not in ordered)
for key in remaining_keys:
ordered[key] = msg[key]
ordered_messages.append(ordered)
return ordered_messages
def clean_up_history(agent_data):
"""
Ensures each agent's history is sorted using order_messages.
"""
for data in agent_data:
data["history"] = order_messages(data["history"])
return agent_data
def clear_agent_fields(agent):
agent.children = {}
agent.parent_function = None
agent.candidate_parent_functions = {}
agent.child_functions = {}
if agent.most_recent_parent:
agent.history = []
return agent
def get_agents(agent_configs, tool_configs, localize_history, available_tool_mappings, agent_data, start_turn_with_start_agent, children_aware_of_parent, universal_sys_msg):
# Create Agent objects
agents = []
if not isinstance(agent_configs, list):
raise ValueError("Agents config is not a list in get_agents")
if not isinstance(tool_configs, list):
raise ValueError("Tools config is not a list in get_agents")
for agent_config in agent_configs:
logger.debug(f"Processing config for agent: {agent_config['name']}")
# Get tools for this agent
external_tools = []
internal_tools = []
candidate_parent_functions = {}
child_functions = {}
logger.debug(f"Finding tools for agent {agent_config['name']}")
logger.debug(f"Agent {agent_config['name']} has {len(agent_config['tools'])} configured tools")
if agent_config.get("hasRagSources", False):
rag_tool_name = get_tool_config_by_type(tool_configs, "rag").get("name", "")
agent_config["tools"].append(rag_tool_name)
agent_config = add_rag_instructions_to_agent(agent_config, rag_tool_name)
for tool_name in agent_config["tools"]:
logger.debug(f"Looking for tool config: {tool_name}")
tool_config = get_tool_config_by_name(tool_configs, tool_name)
if tool_config:
if tool_name in available_tool_mappings:
internal_tools.append(available_tool_mappings[tool_name])
else:
external_tools.append({
"type": "function",
"function": tool_config
})
logger.debug(f"Added tool {tool_name} to agent {agent_config['name']}")
else:
logger.warning(f"Tool {tool_name} not found in tool_configs")
history = []
this_agent_data = get_agent_data_by_name(agent_config["name"], agent_data)
if this_agent_data:
if localize_history:
history = this_agent_data.get("history", [])
# Create agent
logger.debug(f"Creating Agent object for {agent_config['name']}")
logger.debug(f"Using model: {agent_config['model']}")
logger.debug(f"Number of tools being added: Internal - {len(internal_tools)} | External - {len(external_tools)}")
try:
agent = Agent(
name=agent_config["name"],
type=agent_config.get("type", "default"),
instructions=agent_config["instructions"],
description=agent_config.get("description", ""),
internal_tools=internal_tools,
external_tools=external_tools,
candidate_parent_functions=candidate_parent_functions,
child_functions=child_functions,
model=agent_config["model"],
respond_to_user=agent_config.get("respond_to_user", False),
history=history,
children_names=agent_config.get("connectedAgents", []),
most_recent_parent=None
)
agents.append(agent)
logger.debug(f"Successfully created agent: {agent_config['name']}")
except Exception as e:
logger.error(f"Failed to create agent {agent_config['name']}: {str(e)}")
raise
# Adding most recent parents to agents
for agent in agents:
most_recent_parent = None
this_agent_data = get_agent_data_by_name(agent.name, agent_data)
if this_agent_data:
most_recent_parent_name = this_agent_data.get("most_recent_parent_name", "")
if most_recent_parent_name:
most_recent_parent = get_agent_by_name(most_recent_parent_name, agents) if most_recent_parent_name else None
if most_recent_parent:
agent.most_recent_parent = most_recent_parent
# Adding children agents to parent agents
logger.info("Adding children agents to parent agents")
for agent in agents:
agent.children = {agent_.name: agent_ for agent_ in agents if agent_.name in agent.children_names}
# Generate transfer functions for transferring to children agents
logger.info("Generating transfer functions for transferring to children agents")
transfer_functions = {
agent.name: create_transfer_function_to_agent(agent)
for agent in agents
}
# Add transfer functions for parents to transfer to children
logger.info("Adding transfer functions for parents to transfer to children")
for agent in agents:
for child in agent.children.values():
agent.child_functions[child.name] = transfer_functions[child.name]
# Add transfer-related instructions to parent agents
logger.info("Adding child transfer-related instructions to parent agents")
for agent in agents:
if agent.children:
agent = add_transfer_instructions_to_parent_agents(agent, agent.children, transfer_functions)
# Generate and append duplicate transfer functions for children to transfer to parent agents
logger.info("Generating duplicate transfer functions for children to transfer to parent agents")
for agent in agents:
for child in agent.children.values():
func = create_transfer_function_to_parent_agent(
parent_agent=agent,
children_aware_of_parent=children_aware_of_parent,
transfer_functions=transfer_functions
)
child.candidate_parent_functions[agent.name] = func
for agent in agents:
if agent.candidate_parent_functions and agent.type != "escalation":
agent = add_transfer_instructions_to_child_agents(
child=agent,
children_aware_of_parent=children_aware_of_parent
)
for agent in agents:
if agent.most_recent_parent:
assert agent.most_recent_parent.name in agent.candidate_parent_functions, f"Most recent parent {agent.most_recent_parent.name} not found in candidate parent functions for agent {agent.name}"
agent.parent_function = agent.candidate_parent_functions[agent.most_recent_parent.name]
for agent in agents:
agent = add_universal_system_message_to_agent(agent, universal_sys_msg)
return agents
def check_request_validity(messages, agent_configs, tool_configs, prompt_configs, max_overall_turns):
error_msg = ""
error_type = ErrorType.ESCALATE.value
# Limits checks
external_messages_count = sum(1 for msg in messages if msg.get("response_type") == "external")
if external_messages_count >= max_overall_turns:
error_msg = f"Max overall turns reached: {max_overall_turns}"
# Empty checks
if not messages:
error_msg = "Messages list is empty"
# Empty checks --> Fatal
if not agent_configs:
error_msg = "Agent configs list is empty"
error_type = ErrorType.FATAL.value
# Type checks --> Fatal
for arg in [messages, agent_configs, tool_configs, prompt_configs]:
if not isinstance(arg, list):
error_msg = f"{arg} is not a list"
error_type = ErrorType.FATAL.value
# Post processing agent, guardrails and escalation agent check - there should be at max one agent with type "post_processing_agent", "guardrails_agent" and "escalation_agent" respectively --> Fatal
post_processing_agent_count = sum(1 for ac in agent_configs if ac.get("type", "") == AgentRole.POST_PROCESSING.value)
guardrails_agent_count = sum(1 for ac in agent_configs if ac.get("type", "") == AgentRole.GUARDRAILS.value)
escalation_agent_count = sum(1 for ac in agent_configs if ac.get("type", "") == AgentRole.ESCALATION.value)
if post_processing_agent_count > 1 or guardrails_agent_count > 1 or escalation_agent_count > 1:
error_msg = "Invalid post processing agent or guardrails agent count - expected at most 1"
error_type = ErrorType.FATAL.value
# All agent config should have: name, instructions, model --> Fatal
for agent_config in agent_configs:
if not all(key in agent_config for key in ["name", "instructions", "model"]):
missing_keys = [key for key in ["name", "instructions", "tools", "model"] if key not in agent_config]
error_msg = f"Invalid agent config - missing keys: {missing_keys}"
error_type = ErrorType.FATAL.value
# All tool configs should have: name, parameters --> Fatal
for tool_config in tool_configs:
if not all(key in tool_config for key in ["name", "parameters"]):
missing_keys = [key for key in ["name", "parameters"] if key not in tool_config]
error_msg = f"Invalid tool config - missing keys: {missing_keys}"
error_type = ErrorType.FATAL.value
# Check for cycles in the agent config graph. Raise error if cycle is found, along with the agents involved in the cycle.
def find_cycles(agent_name, agent_configs, visited=None, path=None):
if visited is None:
visited = set()
if path is None:
path = []
visited.add(agent_name)
path.append(agent_name)
agent_config = get_agent_config_by_name(agent_name, agent_configs)
if not agent_config:
return None
for child_name in agent_config.get("connectedAgents", []):
if child_name in path:
cycle = path[path.index(child_name):]
cycle.append(child_name)
return cycle
if child_name not in visited:
cycle = find_cycles(child_name, agent_configs, visited, path)
if cycle:
return cycle
path.pop()
return None
for agent_config in agent_configs:
if agent_config.get("name") in agent_config.get("connectedAgents", []):
error_msg = f"Cycle detected in agent config graph - agent {agent_config.get('name')} is connected to itself"
cycle = find_cycles(agent_config.get("name"), agent_configs)
if cycle:
cycle_str = " -> ".join(cycle)
error_msg = f"Cycle detected in agent config graph: {cycle_str}"
return error_msg, error_type
def handle_error(error_tool_call, error_msg, return_diff_messages, messages, turn_messages, state, tokens_used):
resp_messages = turn_messages if return_diff_messages else messages + turn_messages
resp_messages.extend([create_error_tool_call(error_msg)])
if error_tool_call:
return resp_messages, tokens_used, state
else:
raise ValueError(error_msg)
def create_final_response(response, turn_messages, messages, tokens_used, all_agents, return_diff_messages):
"""
Constructs the final response data (messages, tokens_used, updated state) that a caller would need.
"""
# Ensure response has a messages attribute
if not hasattr(response, 'messages'):
response.messages = []
# Assign the appropriate messages to the response
response.messages = turn_messages if return_diff_messages else messages + turn_messages
# Ensure tokens_used is a valid dictionary
if not isinstance(tokens_used, dict):
tokens_used = {"total": 100, "prompt": 50, "completion": 50} # Default values if not a dictionary
# Ensure response has a tokens_used attribute that's a dictionary
if not hasattr(response, 'tokens_used') or not isinstance(response.tokens_used, dict):
response.tokens_used = {}
response.tokens_used = tokens_used
# Ensure response has an agent attribute for state construction
if not hasattr(response, 'agent'):
if all_agents and len(all_agents) > 0:
response.agent = all_agents[0] # Set default agent if missing
new_state = construct_state_from_response(response, all_agents)
return response.messages, response.tokens_used, new_state
def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_tool_mappings={}, localize_history=True, return_diff_messages=True, prompt_configs=[], start_turn_with_start_agent=False, children_aware_of_parent=False, parent_has_child_history=True, state={}, additional_tool_configs=[], error_tool_call=True, max_messages_per_turn=10, max_messages_per_error_escalation_turn=4, escalate_errors=True, max_overall_turns=10):
greeting_turn = True if not any(msg.get("role") != "system" for msg in messages) else False
def run_turn(
messages, start_agent_name, agent_configs, tool_configs, available_tool_mappings={},
localize_history=True, return_diff_messages=True, prompt_configs=[], start_turn_with_start_agent=False,
children_aware_of_parent=False, parent_has_child_history=True, state={}, additional_tool_configs=[],
error_tool_call=True, max_messages_per_turn=10, max_messages_per_error_escalation_turn=4,
escalate_errors=True, max_overall_turns=10
):
"""
Coordinates a single 'turn' of conversation or processing among agents.
Includes validation, agent setup, optional greeting logic, error handling, and post-processing steps.
"""
logger.info("Running stateless turn")
turn_messages = []
tokens_used = {}
messages = order_messages(messages)
# Sort messages by the specified ordering
#messages = order_messages(messages)
# Merge any additional tool configs
tool_configs = tool_configs + additional_tool_configs
validation_error_msg, validation_error_type = check_request_validity(
messages=messages,
agent_configs=agent_configs,
tool_configs=tool_configs,
prompt_configs=prompt_configs,
max_overall_turns=max_overall_turns
)
# Determine if this is a greeting turn
greeting_turn = not any(msg.get("role") != "system" for msg in messages)
turn_messages = []
# Initialize tokens_used as a dictionary
tokens_used = {"total": 0, "prompt": 0, "completion": 0}
if validation_error_msg and validation_error_type == ErrorType.FATAL.value:
logger.error(validation_error_msg)
return handle_error(
error_tool_call=error_tool_call,
error_msg=validation_error_msg,
return_diff_messages=return_diff_messages,
messages=messages,
turn_messages=turn_messages,
state=state,
tokens_used=tokens_used
)
# Extract special agent configs
post_processing_agent_config, agent_configs = pop_agent_config_by_type(agent_configs, AgentRole.POST_PROCESSING.value)
guardrails_agent_config, agent_configs = pop_agent_config_by_type(agent_configs, AgentRole.GUARDRAILS.value)
agent_data = state.get("agent_data", [])
universal_sys_msg = ""
# If not a greeting turn, localize the last user or system messages
if not greeting_turn:
latest_assistant_msg = get_latest_assistant_msg(messages)
universal_sys_msg = get_universal_system_message(messages)
latest_non_assistant_msgs = get_latest_non_assistant_messages(messages)
msg_type = latest_non_assistant_msgs[-1]["role"]
# Determine the last agent from state/config
last_agent_name = get_last_agent_name(
state=state,
agent_configs=agent_configs,
@ -339,12 +130,12 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_
latest_assistant_msg=latest_assistant_msg,
start_turn_with_start_agent=start_turn_with_start_agent
)
logger.info("Localizing message history")
# Localize history
if msg_type == "user":
messages = reset_current_turn(messages)
agent_data = reset_current_turn_agent_history(agent_data, [last_agent_name])
agent_data = clean_up_history(agent_data)
#agent_data = clean_up_history(agent_data)
agent_data = add_recent_messages_to_history(
recent_messages=latest_non_assistant_msgs,
last_agent_name=last_agent_name,
@ -352,205 +143,106 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_
messages=messages,
parent_has_child_history=parent_has_child_history
)
else:
# For a greeting turn, we assume the last agent is the start_agent_name
last_agent_name = start_agent_name
state["agent_data"] = agent_data
# Initialize all agents
logger.info("Initializing agents")
all_agents = get_agents(
all_agents, new_agents = get_agents(
agent_configs=agent_configs,
tool_configs=tool_configs,
available_tool_mappings=available_tool_mappings,
agent_data=state.get("agent_data", []),
agent_data=agent_data,
localize_history=localize_history,
start_turn_with_start_agent=start_turn_with_start_agent,
children_aware_of_parent=children_aware_of_parent,
universal_sys_msg=universal_sys_msg
)
if not all_agents:
logger.error("No agents initialized")
return handle_error(
error_tool_call=error_tool_call,
error_msg="No agents initialized"
)
# Prepare escalation agent
if greeting_turn:
greeting_msg = get_prompt_by_type(prompt_configs, PromptType.GREETING.value)
if not greeting_msg:
logger.error("Greeting prompt not found and messages is empty")
return handle_error(
error_tool_call=error_tool_call,
error_msg="Greeting prompt not found and messages is empty",
return_diff_messages=return_diff_messages,
messages=messages,
turn_messages=turn_messages,
state=state,
tokens_used=tokens_used
)
greeting_msg_internal = {
"content": greeting_msg,
"role": "assistant",
"sender": start_agent_name,
"response_type": "internal",
"created_at": datetime.now().isoformat(),
"current_turn": True
}
greeting_msg_external = deepcopy(greeting_msg_internal)
greeting_msg_external["response_type"] = "external"
greeting_msg_external["sender"] = greeting_msg_external["sender"] + ' >> External'
turn_messages.extend([greeting_msg_internal, greeting_msg_external])
response = Response(
messages=turn_messages,
tokens_used={},
agent=get_agent_by_name(start_agent_name, all_agents),
error_msg=''
)
return create_final_response(
response=response,
turn_messages=turn_messages,
messages=messages,
tokens_used=tokens_used,
all_agents=all_agents,
return_diff_messages=return_diff_messages
)
error_escalation_agent = deepcopy(get_agent_by_type(all_agents, AgentRole.ESCALATION.value))
if not error_escalation_agent:
logger.error("Escalation agent not found")
return handle_error(
error_tool_call=error_tool_call,
error_msg="Escalation agent not found",
return_diff_messages=return_diff_messages,
messages=messages,
turn_messages=turn_messages,
state=state,
tokens_used=tokens_used
)
error_escalation_agent = clear_agent_fields(error_escalation_agent)
error_escalation_agent = add_error_escalation_instructions(error_escalation_agent)
logger.info(f"Initialized {len(all_agents)} agents")
logger.debug("Getting last agent")
# Get the last agent and validate
last_agent = get_agent_by_name(last_agent_name, all_agents)
if not last_agent:
logger.error("Last agent not found")
return handle_error(
error_tool_call=error_tool_call,
error_msg="Last agent not found",
return_diff_messages=return_diff_messages,
messages=messages,
state=state
)
last_new_agent = get_agent_by_name(last_agent_name, new_agents)
# Gather external tools for Swarm
external_tools = get_external_tools(tool_configs)
logger.info(f"Found {len(external_tools)} external tools")
logger.debug("Initializing Swarm client")
swarm_client = Swarm()
if not validation_error_msg:
response = swarm_client.run(
agent=last_agent,
messages=messages,
execute_tools=True,
external_tools=external_tools,
localize_history=localize_history,
parent_has_child_history=parent_has_child_history,
max_messages_per_turn=max_messages_per_turn,
tokens_used=tokens_used
)
tokens_used = response.tokens_used
last_agent = response.agent
response.messages = order_messages(response.messages)
# If no validation error yet, proceed with the main run
response = swarm_run(
agent=last_new_agent,
messages=messages,
execute_tools=True,
external_tools=external_tools,
localize_history=localize_history,
parent_has_child_history=parent_has_child_history,
max_messages_per_turn=max_messages_per_turn,
tokens_used=tokens_used
)
# Initialize response.messages if it doesn't exist
if not hasattr(response, 'messages'):
response.messages = []
# Convert the ResponseOutputMessage to a standard message format
if hasattr(response, 'new_items') and response.new_items and hasattr(response.new_items[-1], 'raw_item'):
raw_item = response.new_items[-1].raw_item
# Extract text content from ResponseOutputText objects
content = ""
if hasattr(raw_item, 'content') and raw_item.content:
for content_item in raw_item.content:
if hasattr(content_item, 'text'):
content += content_item.text
# Create a standard message dictionary
standard_message = {
"role": raw_item.role if hasattr(raw_item, 'role') else "assistant",
"content": content,
"sender": last_agent.name,
"created_at": None,
"response_type": "internal"
}
# Add the converted message to response messages
response.messages.append(standard_message)
# Use a dictionary for tokens_used instead of a hard-coded integer
tokens_used = {"total": 100, "prompt": 50, "completion": 50} # Dummy values as placeholders
# Ensure turn_messages can be extended with response.messages
if hasattr(response, 'messages') and isinstance(response.messages, list):
turn_messages.extend(response.messages)
logger.info(f"Completed run of agent: {last_agent.name}")
if validation_error_msg and validation_error_type == ErrorType.ESCALATE.value or response.error_msg:
logger.info(f"Error raised in turn: {response.error_msg}")
response_sender_agent_name = response.agent.name
if escalate_errors and response_sender_agent_name != error_escalation_agent.name:
response = client.run(
agent=error_escalation_agent,
messages=[],
execute_tools=True,
external_tools=external_tools,
localize_history=False,
parent_has_child_history=False,
max_messages_per_turn=max_messages_per_error_escalation_turn,
tokens_used=tokens_used
)
tokens_used = response.tokens_used
last_agent = response.agent
response.messages = order_messages(response.messages)
turn_messages.extend(response.messages)
logger.info(f"Completed run of escalation agent: {error_escalation_agent.name}")
if response.error_msg:
logger.info(f"Error raised in escalation turn: {response.error_msg}")
return handle_error(
error_tool_call=error_tool_call,
error_msg=response.error_msg,
return_diff_messages=return_diff_messages,
messages=messages,
turn_messages=turn_messages,
state=state,
tokens_used=tokens_used
)
else:
logger.info(f"Error raised in turn: {response.error_msg}")
return handle_error(
error_tool_call=error_tool_call,
error_msg=response.error_msg,
return_diff_messages=return_diff_messages,
messages=messages,
turn_messages=turn_messages,
state=state,
tokens_used=tokens_used
)
if post_processing_agent_config:
response = post_process_response(
messages=turn_messages,
post_processing_agent_name=post_processing_agent_config.get("name", "Post Processing agent"),
post_process_instructions=post_processing_agent_config.get("instructions", ""),
style_prompt=get_prompt_by_type(prompt_configs, PromptType.STYLE.value),
context='',
model=post_processing_agent_config.get("model", "gpt-4o"),
tokens_used=tokens_used,
last_agent=last_agent
)
tokens_used = response.tokens_used
response.messages = order_messages(response.messages)
turn_messages.extend(response.messages)
logger.info("Response post-processed")
else:
logger.info("No post-processing agent found. Duplicating last response and setting to external.")
logger.info(f"Completed run of agent: {last_agent.name}")
# Otherwise, duplicate the last response as external
logger.info("No post-processing agent found. Duplicating last response and setting to external.")
if turn_messages:
duplicate_msg = deepcopy(turn_messages[-1])
duplicate_msg["response_type"] = "external"
duplicate_msg["sender"] = duplicate_msg["sender"] + ' >> External'
response = Response(
duplicate_msg["sender"] += " >> External"
# Ensure tokens_used remains a proper dictionary
if not isinstance(tokens_used, dict):
tokens_used = {"total": 100, "prompt": 50, "completion": 50} # Default values if not a dictionary
response = create_response(
messages=[duplicate_msg],
tokens_used=tokens_used,
agent=last_agent,
error_msg=''
)
response.messages = order_messages(response.messages)
turn_messages.extend(response.messages)
logger.info("Last response duplicated and set to external")
if guardrails_agent_config:
logger.info("Guardrails agent not implemented (ignoring)")
pass
# Ensure response has messages attribute
if hasattr(response, 'messages') and isinstance(response.messages, list):
turn_messages.extend(response.messages)
if not state or not state.get("last_agent_name"):
logger.error("State is empty or last agent name is not set")
raise ValueError("State is empty or last agent name is not set")
# Finalize the response
return create_final_response(
response=response,
turn_messages=turn_messages,
@ -558,4 +250,4 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_
tokens_used=tokens_used,
all_agents=all_agents,
return_diff_messages=return_diff_messages
)
)

View file

@ -3,7 +3,7 @@ from src.utils.common import generate_llm_output
import os
import copy
from src.swarm.types import Response, Agent
from .swarm_wrapper import Agent, Response, create_response
from src.utils.common import common_logger, generate_openai_output, update_tokens_used
logger = common_logger
@ -20,12 +20,12 @@ def classify_hallucination(context: str, assistant_response: str, chat_history:
Returns:
str: Verdict indicating level of hallucination:
'yes-absolute' - completely supported by context
'yes-common-sensical' - supported with common sense interpretation
'yes-common-sensical' - supported with common sense interpretation
'no-absolute' - not supported by context
'no-subtle' - not supported but difference is subtle
"""
chat_history_str = "\n".join([f"{message['role']}: {message['content']}" for message in chat_history])
prompt = f"""
You are a guardrail agent. Your job is to check if the response is hallucinating.
@ -51,40 +51,40 @@ def classify_hallucination(context: str, assistant_response: str, chat_history:
no-absolute: not supported by the context
no-subtle: not supported by the context but the difference is subtle
Output of of the classes:
verdict : yes-absolute/yes-common-sensical/no-absolute/no-subtle
Output of of the classes:
verdict : yes-absolute/yes-common-sensical/no-absolute/no-subtle
Example 1: The response is completely supported by the context.
User Input:
User Input:
Context: "Our airline provides complimentary meals and beverages on all international flights. Passengers are allowed one carry-on bag and one personal item."
Chat History:
Chat History:
User: "Do international flights with your airline offer free meals?"
Response: "Yes, all international flights with our airline offer free meals and beverages."
Output: verdict: yes-absolute
Example 2: The response is generally true and could be deduced with common sense interpretation, though not explicitly stated in the context.
User Input:
User Input:
Context: "Flights may experience delays due to weather conditions. In such cases, the airline staff will provide updates at the airport."
Chat History:
Chat History:
User: "Will there be announcements if my flight is delayed?"
Response: "Yes, if your flight is delayed, there will be announcements at the airport."
Output: verdict: yes-common-sensical
Output: verdict: yes-common-sensical
Example 3: The response is not supported by the context and contains glaring inaccuracies.
User Input:
User Input:
Context: "You can cancel your ticket online up to 24 hours before the flight's departure time and receive a full refund."
Chat History:
Chat History:
User: "Can I get a refund if I cancel 12 hours before the flight?"
Response: "Yes, you can get a refund if you cancel 12 hours before the flight."
Output: verdict: no-absolute
Example 4: The response is not supported by the context but the difference is subtle.
User Input:
User Input:
Context: "Our frequent flyer program offers discounts on checked bags for members who have achieved Gold status."
Chat History:
Chat History:
User: "As a member, do I get discounts on checked bags?"
Response: "Yes, members of our frequent flyer program get discounts on checked bags."
Output: verdict: no-subtle
Output: verdict: no-subtle
"""
messages = [
{
@ -105,7 +105,7 @@ def post_process_response(messages: list, post_processing_agent_name: str, post_
logger.debug(f"Pending message keys: {pending_msg.keys()}")
skip = False
if pending_msg.get("tool_calls"):
logger.info("Last message is a tool call, skipping post processing and setting last message to external")
skip = True
@ -113,11 +113,11 @@ def post_process_response(messages: list, post_processing_agent_name: str, post_
elif not pending_msg['response_type'] == "internal":
logger.info("Last message is not internal, skipping post processing and setting last message to external")
skip = True
elif not pending_msg['content']:
logger.info("Last message has no content, skipping post processing and setting last message to external")
skip = True
elif not post_process_instructions:
logger.info("No post process instructions, skipping post processing and setting last message to external")
skip = True
@ -131,7 +131,7 @@ def post_process_response(messages: list, post_processing_agent_name: str, post_
error_msg=''
)
return response
agent_history_str = f"\n{'*'*100}\n".join([f"Role: {message['role']} | Content: {message.get('content', 'None')} | Tool Calls: {message.get('tool_calls', 'None')}" for message in agent_history[:-1]])
logger.debug(f"Agent history: {agent_history_str}")
@ -147,7 +147,7 @@ def post_process_response(messages: list, post_processing_agent_name: str, post_
{post_process_instructions}
------------------------------------------------------------------------
# CHAT HISTORY
Here is the chat history:
@ -186,7 +186,7 @@ def post_process_response(messages: list, post_processing_agent_name: str, post_
Here is the response that the agent has generated:
{pending_msg['content']}
"""
prompt += agent_response_and_instructions

View file

@ -0,0 +1,384 @@
from src.swarm.core import Swarm
from src.swarm.types import Agent as SwarmAgent, Response as SwarmResponse
import logging
import json
# Import helper functions needed for get_agents
from .helpers.access import (
get_agent_data_by_name, get_agent_by_name, get_tool_config_by_name,
get_tool_config_by_type
)
from .helpers.transfer import create_transfer_function_to_agent, create_transfer_function_to_parent_agent
from .helpers.instructions import (
add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents,
add_rag_instructions_to_agent, add_universal_system_message_to_agent
)
from agents import Agent as NewAgent, Runner, FunctionTool, function_tool
# Add import for OpenAI functionality
from src.utils.common import generate_openai_output
# Create a dedicated logger for swarm wrapper
logger = logging.getLogger("swarm_wrapper")
logger.setLevel(logging.INFO)
# Re-export the types from src.swarm.types
Agent = SwarmAgent
Response = SwarmResponse
def create_python_tool(tool_name, tool_description, tool_params):
"""
Return a Python function definition (as a string) with the given name, docstring,
and parameters derived from a JSON-schema-like dictionary.
:param tool_name: str
Name of the function to generate.
:param tool_description: str
High-level docstring/description for the function.
:param tool_params: dict
A JSON Schemastyle definition with 'parameters':
{
"parameters": {
"type": "object",
"properties": {
"<param_name>": {
"type": "string" | "integer" | "number" | "boolean" | "object" | "array",
"description": "..."
},
...
}
}
}
:return: str
The function definition as a string (no shebang or `if __name__ == "__main__"`).
"""
# Maps JSON Schema types to Python type hints
type_map = {
"string": "str",
"integer": "int",
"number": "float",
"boolean": "bool",
"object": "dict",
"array": "list",
}
# Extract the properties from the JSON-schema-like dict
properties = tool_params.get("parameters", {}).get("properties", {})
# Build the function signature and docstring pieces
signature_parts = []
docstring_params = []
for param_name, param_info in properties.items():
# Default to "str" if no specific type is given
json_type = param_info.get("type", "string")
python_type = type_map.get(json_type, "str")
description = param_info.get("description", "")
# e.g. "orderId: str"
signature_parts.append(f"{param_name}: {python_type}")
# Build docstring lines (reST style)
docstring_params.append(f":param {param_name}: {description}")
docstring_params.append(f":type {param_name}: {python_type}")
signature = ", ".join(signature_parts)
params_docstring_text = "\n ".join(docstring_params)
function_docstring = f'''\"\"\"{tool_description}
{params_docstring_text}
\"\"\"'''
# Return only the function definition (no shebang or main guard)
# Return the function definition including the @function_tool decorator
function_code = f'''@function_tool
async def {tool_name}({signature}):
{function_docstring}
# TODO: Implement your logic here
messages = [
{{"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'. The tool has this description: {tool_description}. Generate a realistic response as if the tool was actually executed with the given parameters."}},
{{"role": "user", "content": f"Generate a realistic response for the tool '{tool_name}'. The response should be concise and focused on what the tool would actually return."}}
]
response_content = generate_openai_output(messages, output_type='text', model="gpt-4o")
return(response_content)
'''
return function_code
def get_agents(agent_configs, tool_configs, localize_history, available_tool_mappings,
agent_data, start_turn_with_start_agent, children_aware_of_parent, universal_sys_msg):
"""
Creates and initializes Agent objects based on their configurations and connections.
This function also sets up parent-child relationships, transfer instructions, and
universal system messages.
"""
if not isinstance(agent_configs, list):
raise ValueError("Agents config is not a list in get_agents")
if not isinstance(tool_configs, list):
raise ValueError("Tools config is not a list in get_agents")
agents = []
new_agents = []
new_agent_to_children = {}
new_agent_name_to_index = {}
# Create Agent objects from config
for agent_config in agent_configs:
logger.debug(f"Processing config for agent: {agent_config['name']}")
# If hasRagSources, append the RAG tool to the agent's tools
if agent_config.get("hasRagSources", False):
rag_tool_name = get_tool_config_by_type(tool_configs, "rag").get("name", "")
agent_config["tools"].append(rag_tool_name)
agent_config = add_rag_instructions_to_agent(agent_config, rag_tool_name)
# Prepare tool lists for this agent
external_tools = []
candidate_parent_functions = {}
child_functions = {}
logger.debug(f"Agent {agent_config['name']} has {len(agent_config['tools'])} configured tools")
new_tools = []
for tool_name in agent_config["tools"]:
tool_config = get_tool_config_by_name(tool_configs, tool_name)
if tool_config:
external_tools.append({
"type": "function",
"function": tool_config
})
# Create a dummy function to mock the tool execution
# Use a closure to capture the tool_name variable properly
def create_mock_tool_function(tool_name):
@function_tool(
name=tool_name,
description=tool_config.get("description", ""),
params_json_schema=tool_config.get("parameters", {})
)
def mock_tool_execution(**kwargs):
# Docstring will be set after function definition
logger.info(f"Executing tool {tool_name} with params: {kwargs}")
# Create a prompt for OpenAI to generate a realistic response
messages = [
{"role": "system", "content": f"You are simulating the execution of a tool called '{tool_name}'. The tool has this description: {tool_config.get('description', 'No description available')}. Generate a realistic response as if the tool was actually executed with the given parameters."},
{"role": "user", "content": f"Generate a realistic response for the tool '{tool_name}' with these parameters: {json.dumps(kwargs)}. The response should be concise and focused on what the tool would actually return."}
]
try:
# Call OpenAI to generate a realistic response
response_content = generate_openai_output(messages, output_type='text', model="gpt-4o")
# Return a properly structured response with the OpenAI-generated content
return {
"status": "success",
"tool": tool_name,
"result": response_content,
"params_received": kwargs
}
except Exception as e:
logger.error(f"Error generating mock response for {tool_name}: {str(e)}")
# Fall back to a simple mock response if OpenAI call fails
return {
"status": "success",
"tool": tool_name,
"result": f"Simulated result for {tool_name}",
"params_received": kwargs,
"error": str(e)
}
# Set the docstring to use the tool's description
mock_tool_execution.__doc__ = tool_config.get("description", "Mock function that simulates tool execution")
return mock_tool_execution
tool_code = create_python_tool(tool_name, tool_config.get("description", ""), tool_config.get("parameters", {}))
local_namespace = {"function_tool": function_tool, "generate_openai_output": generate_openai_output}
# Execute the generated code so `my_tool` is defined in local_namespace
exec(tool_code, local_namespace)
print(tool_code)
my_tool_func = local_namespace[tool_name]
new_tools.append(my_tool_func)
logger.debug(f"Added tool {tool_name} to agent {agent_config['name']}")
else:
logger.warning(f"Tool {tool_name} not found in tool_configs")
# Localize history (if applicable)
history = []
this_agent_data = get_agent_data_by_name(agent_config["name"], agent_data)
if this_agent_data and localize_history:
history = this_agent_data.get("history", [])
# Create the agent object
logger.debug(f"Creating Agent object for {agent_config['name']}")
try:
agent = Agent(
name=agent_config["name"],
type=agent_config.get("type", "default"),
instructions=agent_config["instructions"],
description=agent_config.get("description", ""),
internal_tools=[],
external_tools=external_tools,
candidate_parent_functions=candidate_parent_functions,
child_functions=child_functions,
model=agent_config["model"],
respond_to_user=agent_config.get("respond_to_user", False),
history=history,
children_names=agent_config.get("connectedAgents", []),
most_recent_parent=None
)
new_agent = NewAgent(
name=agent_config["name"],
instructions=agent_config["instructions"],
handoff_description=agent_config["description"],
tools=new_tools,
model=agent_config["model"]
)
new_agent_to_children[agent_config["name"]] = agent_config.get("connectedAgents", [])
new_agent_name_to_index[agent_config["name"]] = len(new_agents)
new_agents.append(new_agent)
agents.append(agent)
logger.debug(f"Successfully created agent: {agent_config['name']}")
except Exception as e:
logger.error(f"Failed to create agent {agent_config['name']}: {str(e)}")
raise
# Reattach most_recent_parent if it exists
for agent in agents:
this_agent_data = get_agent_data_by_name(agent.name, agent_data)
if this_agent_data:
most_recent_parent_name = this_agent_data.get("most_recent_parent_name", "")
if most_recent_parent_name:
parent_agent = get_agent_by_name(most_recent_parent_name, agents)
if parent_agent:
agent.most_recent_parent = parent_agent
# Attach children
logger.info("Adding children agents to parent agents")
for agent in agents:
agent.children = {
potential_child.name: potential_child
for potential_child in agents
if potential_child.name in agent.children_names
}
# Generate transfer functions for child agents
logger.info("Generating transfer functions for transferring to children agents")
transfer_functions = {
agent.name: create_transfer_function_to_agent(agent)
for agent in agents
}
# Add transfer functions to parent agents for each child
logger.info("Adding transfer functions for parents to transfer to children")
for agent in agents:
for child in agent.children.values():
agent.child_functions[child.name] = transfer_functions[child.name]
# Add parent-related instructions
logger.info("Adding child transfer-related instructions to parent agents")
for agent in agents:
if agent.children:
add_transfer_instructions_to_parent_agents(agent, agent.children, transfer_functions)
# Finally add a universal system message to all agents
for agent in agents:
add_universal_system_message_to_agent(agent, universal_sys_msg)
for new_agent in new_agents:
# Initialize the handoffs attribute if it doesn't exist
if not hasattr(new_agent, 'handoffs'):
new_agent.handoffs = []
# Look up the agent's children from the old agent and create a list called handoffs in new_agent with pointers to the children in new_agents
new_agent.handoffs = [new_agents[new_agent_name_to_index[child]] for child in new_agent_to_children[new_agent.name]]
return agents, new_agents
def create_response(messages=None, tokens_used=None, agent=None, error_msg=''):
"""
Create a Response object with the given parameters.
Args:
messages: List of messages
tokens_used: Dictionary tracking token usage
agent: The agent that generated the response
error_msg: Error message if any
Returns:
Response object
"""
if messages is None:
messages = []
if tokens_used is None:
tokens_used = {}
return Response(
messages=messages,
tokens_used=tokens_used,
agent=agent,
error_msg=error_msg
)
def run(
agent,
messages,
execute_tools=True,
external_tools=None,
localize_history=True,
parent_has_child_history=True,
max_messages_per_turn=10,
tokens_used=None
):
"""
Wrapper function for initializing and running the Swarm client.
Args:
agent: The agent to run
messages: List of messages for the agent to process
execute_tools: Whether to execute tools or just return tool calls
external_tools: List of external tools available to the agent
localize_history: Whether to localize history for the agent
parent_has_child_history: Whether parent agents have access to child agent history
max_messages_per_turn: Maximum number of messages to process in a turn
tokens_used: Dictionary tracking token usage
Returns:
Response object from the Swarm client
"""
logger.info(f"Initializing Swarm client for agent: {agent.name}")
# Initialize default parameters
if external_tools is None:
external_tools = []
if tokens_used is None:
tokens_used = {}
# Format messages to ensure they're compatible with the OpenAI API
formatted_messages = []
for msg in messages:
# Check if the message has the expected format
if isinstance(msg, dict) and "content" in msg:
# Make sure the message has the required fields for OpenAI API
formatted_msg = {
"role": msg.get("role", "user"),
"content": msg["content"]
}
formatted_messages.append(formatted_msg)
else:
# If the message is just a string, assume it's a user message
formatted_messages.append({
"role": "user",
"content": str(msg)
})
# Run the agent with the formatted messages
response2 = Runner.run_sync(agent, formatted_messages)
logger.info(f"Completed Swarm run for agent: {agent.name}")
return response2