Add agents with custom swarm implementation

This commit is contained in:
akhisud3195 2025-01-13 18:20:38 +05:30
parent 24c4f6e552
commit a19dedd59f
35 changed files with 3413 additions and 0 deletions

View file

View file

View file

@ -0,0 +1,83 @@
from flask import Flask, request, jsonify
from src.graph.core import run_turn
from datetime import datetime
from src.graph.tools import RAG_TOOL, CLOSE_CHAT_TOOL
from src.utils.common import common_logger, read_json_from_file
logger = common_logger
app = Flask(__name__)
@app.route("/")
def home():
return "Hello, World!"
@app.route("/chat", methods=["POST"])
def chat():
print('='*200)
logger.info('='*200)
try:
data = request.get_json()
print('Complete request:')
logger.info('Complete request')
print(data)
logger.info(data)
print('-'*200)
logger.info('-'*200)
start_time = datetime.now()
config = read_json_from_file("./configs/default_config.json")
resp_messages, resp_tokens_used, resp_state = run_turn(
messages=data.get("messages", []),
start_agent_name=data.get("startAgent", ""),
agent_configs=data.get("agents", []),
tool_configs=data.get("tools", []),
localize_history=config.get("localize_history", True),
return_diff_messages=config.get("return_diff_messages", True),
prompt_configs=data.get("prompts", []),
start_turn_with_start_agent=config.get("start_turn_with_start_agent", False),
children_aware_of_parent=config.get("children_aware_of_parent", False),
parent_has_child_history=config.get("parent_has_child_history", True),
state=data.get("state", {}),
additional_tool_configs=[RAG_TOOL, CLOSE_CHAT_TOOL],
max_messages_per_turn=config.get("max_messages_per_turn", 2),
max_messages_per_error_escalation_turn=config.get("max_messages_per_error_escalation_turn", 2),
escalate_errors=config.get("escalate_errors", True),
max_overall_turns=config.get("max_overall_turns", 10)
)
print('-'*200)
logger.info('-'*200)
out = {
"messages": resp_messages,
"tokens_used": resp_tokens_used,
"state": resp_state,
}
print("Output: ")
logger.info(f"Output: ")
for k, v in out.items():
print(f"{k}: {v}")
print('*'*200)
logger.info(f"{k}: {v}")
logger.info('*'*200)
print("Processing time:")
print('='*200)
logger.info('='*200)
print(f"Processing time: {datetime.now() - start_time}")
logger.info(f"Processing time: {datetime.now() - start_time}")
return jsonify(out)
except Exception as e:
print(e)
logger.error(f"Error: {e}")
return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
print("Starting Flask server...")
app.run(port=4040, debug=True)

View file

View file

@ -0,0 +1,476 @@
import os
import sys
from src.swarm.types import Agent
from src.swarm.core import Swarm
from .guardrails import post_process_response
from .tools import create_error_tool_call
from .types import AgentRole, PromptType, ErrorType
from .helpers.access import get_agent_data_by_name, get_agent_by_name, get_agent_config_by_name, get_tool_config_by_name, get_tool_config_by_type, get_external_tools, get_prompt_by_type, pop_agent_config_by_type, get_agent_by_type
from .helpers.transfer import create_transfer_function_to_agent, create_transfer_function_to_parent_agent
from .helpers.state import add_recent_messages_to_history, construct_state_from_response, reset_current_turn, reset_current_turn_agent_history
from .helpers.instructions import add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents, add_rag_instructions_to_agent, add_error_escalation_instructions
from .helpers.control import get_latest_assistant_msg, get_latest_non_assistant_messages, get_last_agent_name
from src.utils.common import common_logger
from copy import deepcopy
logger = common_logger
def order_messages(messages):
# Arrange keys in specified order
ordered_messages = []
for msg in messages:
ordered = {}
msg = {k: v for k, v in msg.items() if v is not None}
# Add keys in specified order if they exist
for key in ['role', 'sender', 'content', 'created_at', 'timestamp']:
if key in msg:
ordered[key] = msg[key]
# Add remaining keys in alphabetical order
for key in sorted(msg.keys()):
if key not in ['role', 'sender', 'content', 'created_at', 'timestamp']:
ordered[key] = msg[key]
ordered_messages.append(ordered)
return ordered_messages
def clean_up_history(agent_data):
for data in agent_data:
data["history"] = order_messages(data["history"])
return agent_data
def clear_agent_fields(agent):
agent.children = {}
agent.parent_function = None
agent.candidate_parent_functions = {}
agent.child_functions = {}
if agent.most_recent_parent:
agent.history = []
return agent
def get_agents(agent_configs, tool_configs, localize_history, available_tool_mappings, agent_data, start_turn_with_start_agent, children_aware_of_parent):
# Create Agent objects
agents = []
if not isinstance(agent_configs, list):
raise ValueError("Agents config is not a list in get_agents")
if not isinstance(tool_configs, list):
raise ValueError("Tools config is not a list in get_agents")
for agent_config in agent_configs:
logger.debug(f"Processing config for agent: {agent_config['name']}")
# Get tools for this agent
external_tools = []
internal_tools = []
candidate_parent_functions = {}
child_functions = {}
logger.debug(f"Finding tools for agent {agent_config['name']}")
logger.debug(f"Agent {agent_config['name']} has {len(agent_config['tools'])} configured tools")
if agent_config.get("hasRagSources", False):
rag_tool_name = get_tool_config_by_type(tool_configs, "rag").get("name", "")
agent_config["tools"].append(rag_tool_name)
agent_config = add_rag_instructions_to_agent(agent_config, rag_tool_name)
for tool_name in agent_config["tools"]:
logger.debug(f"Looking for tool config: {tool_name}")
tool_config = get_tool_config_by_name(tool_configs, tool_name)
if tool_config:
if tool_name in available_tool_mappings:
internal_tools.append(available_tool_mappings[tool_name])
else:
external_tools.append({
"type": "function",
"function": tool_config
})
logger.debug(f"Added tool {tool_name} to agent {agent_config['name']}")
else:
logger.warning(f"Tool {tool_name} not found in tool_configs")
history = []
this_agent_data = get_agent_data_by_name(agent_config["name"], agent_data)
if this_agent_data:
if localize_history:
history = this_agent_data.get("history", [])
# Create agent
logger.debug(f"Creating Agent object for {agent_config['name']}")
logger.debug(f"Using model: {agent_config['model']}")
logger.debug(f"Number of tools being added: Internal - {len(internal_tools)} | External - {len(external_tools)}")
try:
agent = Agent(
name=agent_config["name"],
type=agent_config.get("type", "default"),
instructions=agent_config["instructions"],
description=agent_config.get("description", ""),
internal_tools=internal_tools,
external_tools=external_tools,
candidate_parent_functions=candidate_parent_functions,
child_functions=child_functions,
model=agent_config["model"],
respond_to_user=agent_config.get("respond_to_user", False),
history=history,
children_names=agent_config.get("connectedAgents", []),
most_recent_parent=None
)
agents.append(agent)
logger.debug(f"Successfully created agent: {agent_config['name']}")
except Exception as e:
logger.error(f"Failed to create agent {agent_config['name']}: {str(e)}")
raise
# Adding most recent parents to agents
for agent in agents:
most_recent_parent = None
this_agent_data = get_agent_data_by_name(agent.name, agent_data)
if this_agent_data:
most_recent_parent_name = this_agent_data.get("most_recent_parent_name", "")
if most_recent_parent_name:
most_recent_parent = get_agent_by_name(most_recent_parent_name, agents) if most_recent_parent_name else None
if most_recent_parent:
agent.most_recent_parent = most_recent_parent
# Adding children agents to parent agents
logger.info("Adding children agents to parent agents")
for agent in agents:
agent.children = {agent_.name: agent_ for agent_ in agents if agent_.name in agent.children_names}
# Generate transfer functions for transferring to children agents
logger.info("Generating transfer functions for transferring to children agents")
transfer_functions = {
agent.name: create_transfer_function_to_agent(agent)
for agent in agents
}
# Add transfer functions for parents to transfer to children
logger.info("Adding transfer functions for parents to transfer to children")
for agent in agents:
for child in agent.children.values():
agent.child_functions[child.name] = transfer_functions[child.name]
# Add transfer-related instructions to parent agents
logger.info("Adding child transfer-related instructions to parent agents")
for agent in agents:
if agent.children:
agent = add_transfer_instructions_to_parent_agents(agent, agent.children, transfer_functions)
# Generate and append duplicate transfer functions for children to transfer to parent agents
logger.info("Generating duplicate transfer functions for children to transfer to parent agents")
for agent in agents:
for child in agent.children.values():
func = create_transfer_function_to_parent_agent(
parent_agent=agent,
children_aware_of_parent=children_aware_of_parent,
transfer_functions=transfer_functions
)
child.candidate_parent_functions[agent.name] = func
for agent in agents:
if agent.candidate_parent_functions:
agent = add_transfer_instructions_to_child_agents(
child=agent,
children_aware_of_parent=children_aware_of_parent
)
for agent in agents:
if agent.most_recent_parent:
assert agent.most_recent_parent.name in agent.candidate_parent_functions, f"Most recent parent {agent.most_recent_parent.name} not found in candidate parent functions for agent {agent.name}"
agent.parent_function = agent.candidate_parent_functions[agent.most_recent_parent.name]
return agents
def check_request_validity(messages, agent_configs, tool_configs, prompt_configs, max_overall_turns):
error_msg = ""
error_type = ErrorType.ESCALATE.value
# Limits checks
external_messages_count = sum(1 for msg in messages if msg.get("response_type") == "external")
if external_messages_count >= max_overall_turns:
error_msg = f"Max overall turns reached: {max_overall_turns}"
# Empty checks
if not messages:
error_msg = "Messages list is empty"
# Empty checks --> Fatal
if not agent_configs:
error_msg = "Agent configs list is empty"
error_type = ErrorType.FATAL.value
# Type checks --> Fatal
for arg in [messages, agent_configs, tool_configs, prompt_configs]:
if not isinstance(arg, list):
error_msg = f"{arg} is not a list"
error_type = ErrorType.FATAL.value
# Post processing agent, guardrails and escalation agent check - there should be at max one agent with type "post_processing_agent", "guardrails_agent" and "escalation_agent" respectively --> Fatal
post_processing_agent_count = sum(1 for ac in agent_configs if ac.get("type", "") == AgentRole.POST_PROCESSING.value)
guardrails_agent_count = sum(1 for ac in agent_configs if ac.get("type", "") == AgentRole.GUARDRAILS.value)
escalation_agent_count = sum(1 for ac in agent_configs if ac.get("type", "") == AgentRole.ESCALATION.value)
if post_processing_agent_count > 1 or guardrails_agent_count > 1 or escalation_agent_count > 1:
error_msg = "Invalid post processing agent or guardrails agent count - expected at most 1"
error_type = ErrorType.FATAL.value
# All agent config should have: name, instructions, model --> Fatal
for agent_config in agent_configs:
if not all(key in agent_config for key in ["name", "instructions", "model"]):
missing_keys = [key for key in ["name", "instructions", "tools", "model"] if key not in agent_config]
error_msg = f"Invalid agent config - missing keys: {missing_keys}"
error_type = ErrorType.FATAL.value
# Check for cycles in the agent config graph. Raise error if cycle is found, along with the agents involved in the cycle.
def find_cycles(agent_name, agent_configs, visited=None, path=None):
if visited is None:
visited = set()
if path is None:
path = []
visited.add(agent_name)
path.append(agent_name)
agent_config = get_agent_config_by_name(agent_name, agent_configs)
if not agent_config:
return None
for child_name in agent_config.get("connectedAgents", []):
if child_name in path:
cycle = path[path.index(child_name):]
cycle.append(child_name)
return cycle
if child_name not in visited:
cycle = find_cycles(child_name, agent_configs, visited, path)
if cycle:
return cycle
path.pop()
return None
for agent_config in agent_configs:
if agent_config.get("name") in agent_config.get("connectedAgents", []):
error_msg = f"Cycle detected in agent config graph - agent {agent_config.get('name')} is connected to itself"
cycle = find_cycles(agent_config.get("name"), agent_configs)
if cycle:
cycle_str = " -> ".join(cycle)
error_msg = f"Cycle detected in agent config graph: {cycle_str}"
return error_msg, error_type
def handle_error(error_tool_call, error_msg, return_diff_messages, messages, turn_messages, state, tokens_used):
resp_messages = turn_messages if return_diff_messages else messages + turn_messages
resp_messages.extend([create_error_tool_call(error_msg)])
if error_tool_call:
return resp_messages, tokens_used, state
else:
raise ValueError(error_msg)
def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_tool_mappings={}, localize_history=True, return_diff_messages=True, prompt_configs=[], start_turn_with_start_agent=False, children_aware_of_parent=False, parent_has_child_history=True, state={}, additional_tool_configs=[], error_tool_call=True, max_messages_per_turn=10, max_messages_per_error_escalation_turn=4, escalate_errors=True, max_overall_turns=10):
logger.info("Running stateless turn")
turn_messages = []
tokens_used = {}
messages = order_messages(messages)
tool_configs = tool_configs + additional_tool_configs
validation_error_msg, validation_error_type = check_request_validity(
messages=messages,
agent_configs=agent_configs,
tool_configs=tool_configs,
prompt_configs=prompt_configs,
max_overall_turns=max_overall_turns
)
if validation_error_msg and validation_error_type == ErrorType.FATAL.value:
logger.error(validation_error_msg)
return handle_error(
error_tool_call=error_tool_call,
error_msg=validation_error_msg,
return_diff_messages=return_diff_messages,
messages=messages,
turn_messages=turn_messages,
state=state,
tokens_used=tokens_used
)
post_processing_agent_config, agent_configs = pop_agent_config_by_type(agent_configs, AgentRole.POST_PROCESSING.value)
guardrails_agent_config, agent_configs = pop_agent_config_by_type(agent_configs, AgentRole.GUARDRAILS.value)
latest_assistant_msg = get_latest_assistant_msg(messages)
latest_non_assistant_msgs = get_latest_non_assistant_messages(messages)
msg_type = latest_non_assistant_msgs[-1]["role"]
last_agent_name = get_last_agent_name(
state=state,
agent_configs=agent_configs,
start_agent_name=start_agent_name,
msg_type=msg_type,
latest_assistant_msg=latest_assistant_msg,
start_turn_with_start_agent=start_turn_with_start_agent
)
logger.info("Localizing message history")
agent_data = state.get("agent_data", [])
if msg_type == "user":
messages = reset_current_turn(messages)
agent_data = reset_current_turn_agent_history(agent_data, [last_agent_name])
agent_data = clean_up_history(agent_data)
agent_data = add_recent_messages_to_history(
recent_messages=latest_non_assistant_msgs,
last_agent_name=last_agent_name,
agent_data=agent_data,
messages=messages,
parent_has_child_history=parent_has_child_history
)
state["agent_data"] = agent_data
logger.info("Initializing agents")
all_agents = get_agents(
agent_configs=agent_configs,
tool_configs=tool_configs,
available_tool_mappings=available_tool_mappings,
agent_data=state.get("agent_data", []),
localize_history=localize_history,
start_turn_with_start_agent=start_turn_with_start_agent,
children_aware_of_parent=children_aware_of_parent,
)
if not all_agents:
logger.error("No agents initialized")
return handle_error(
error_tool_call=error_tool_call,
error_msg="No agents initialized"
)
error_escalation_agent = deepcopy(get_agent_by_type(all_agents, AgentRole.ESCALATION.value))
if not error_escalation_agent:
logger.error("Escalation agent not found")
return handle_error(
error_tool_call=error_tool_call,
error_msg="Escalation agent not found",
return_diff_messages=return_diff_messages,
messages=messages,
turn_messages=turn_messages,
state=state,
tokens_used=tokens_used
)
error_escalation_agent = clear_agent_fields(error_escalation_agent)
error_escalation_agent = add_error_escalation_instructions(error_escalation_agent)
logger.info(f"Initialized {len(all_agents)} agents")
logger.debug("Getting last agent")
last_agent = get_agent_by_name(last_agent_name, all_agents)
if not last_agent:
logger.error("Last agent not found")
return handle_error(
error_tool_call=error_tool_call,
error_msg="Last agent not found",
return_diff_messages=return_diff_messages,
messages=messages,
state=state
)
external_tools = get_external_tools(tool_configs)
logger.info(f"Found {len(external_tools)} external tools")
logger.debug("Initializing Swarm client")
client = Swarm()
if not validation_error_msg:
response = client.run(
agent=last_agent,
messages=messages,
execute_tools=True,
external_tools=external_tools,
localize_history=localize_history,
parent_has_child_history=parent_has_child_history,
max_messages_per_turn=max_messages_per_turn,
tokens_used=tokens_used
)
tokens_used = response.tokens_used
last_agent = response.agent
response.messages = order_messages(response.messages)
turn_messages.extend(response.messages)
logger.info(f"Completed run of agent: {last_agent.name}")
if validation_error_msg and validation_error_type == ErrorType.ESCALATE.value or response.error_msg:
logger.info(f"Error raised in turn: {response.error_msg}")
response_sender_agent_name = response.agent.name
if escalate_errors and response_sender_agent_name != error_escalation_agent.name:
response = client.run(
agent=error_escalation_agent,
messages=[],
execute_tools=True,
external_tools=external_tools,
localize_history=False,
parent_has_child_history=False,
max_messages_per_turn=max_messages_per_error_escalation_turn,
tokens_used=tokens_used
)
tokens_used = response.tokens_used
last_agent = response.agent
response.messages = order_messages(response.messages)
turn_messages.extend(response.messages)
logger.info(f"Completed run of escalation agent: {error_escalation_agent.name}")
if response.error_msg:
logger.info(f"Error raised in escalation turn: {response.error_msg}")
return handle_error(
error_tool_call=error_tool_call,
error_msg=response.error_msg,
return_diff_messages=return_diff_messages,
messages=messages,
turn_messages=turn_messages,
state=state,
tokens_used=tokens_used
)
else:
logger.info(f"Error raised in turn: {response.error_msg}")
return handle_error(
error_tool_call=error_tool_call,
error_msg=response.error_msg,
return_diff_messages=return_diff_messages,
messages=messages,
turn_messages=turn_messages,
state=state,
tokens_used=tokens_used
)
if post_processing_agent_config:
response = post_process_response(
messages=turn_messages,
post_processing_agent_name=post_processing_agent_config.get("name", "Post Processing agent"),
post_process_instructions=post_processing_agent_config.get("instructions", ""),
style_prompt=get_prompt_by_type(prompt_configs, PromptType.STYLE.value),
context='',
model=post_processing_agent_config.get("model", "gpt-4o"),
tokens_used=tokens_used,
last_agent=last_agent
)
tokens_used = response.tokens_used
response.messages = order_messages(response.messages)
turn_messages.extend(response.messages)
logger.info("Response post-processed")
if guardrails_agent_config:
logger.info("Guardrails agent not implemented (ignoring)")
pass
if not state or not state.get("last_agent_name"):
logger.error("State is empty or last agent name is not set")
raise ValueError("State is empty or last agent name is not set")
response.messages = turn_messages if return_diff_messages else messages + turn_messages
response.tokens_used = tokens_used
new_state = construct_state_from_response(response, all_agents)
return response.messages, response.tokens_used, new_state

View file

@ -0,0 +1,219 @@
# Guardrails
from src.utils.common import generate_llm_output
import os
import copy
from src.swarm.types import Response, Agent
from src.utils.common import common_logger, generate_openai_output, update_tokens_used
logger = common_logger
def classify_hallucination(context: str, assistant_response: str, chat_history: list, model: str) -> str:
"""
Checks if an assistant's response contains hallucinations by comparing against provided context.
Args:
context (str): The context/knowledge base to check the response against
assistant_response (str): The response from the assistant to validate
chat_history (list): List of previous chat messages for context
Returns:
str: Verdict indicating level of hallucination:
'yes-absolute' - completely supported by context
'yes-common-sensical' - supported with common sense interpretation
'no-absolute' - not supported by context
'no-subtle' - not supported but difference is subtle
"""
chat_history_str = "\n".join([f"{message['role']}: {message['content']}" for message in chat_history])
prompt = f"""
You are a guardrail agent. Your job is to check if the response is hallucinating.
------------------------------------------------------------------------
Here is the context:
{context}
------------------------------------------------------------------------
Here is the chat history message:
{chat_history_str}
------------------------------------------------------------------------
Here is the response:
{assistant_response}
------------------------------------------------------------------------
As a hallucination guardrail, your job is to go through each line of the response and check if it is completely supported by the context. Even if a single line is not supported, the response is no.
Output a single verdict for the entire response. don't provide any reasoning. The output classes are
yes-absolute: completely supported by the context
yes-common-sensical: but with some common sense interpretation
no-absolute: not supported by the context
no-subtle: not supported by the context but the difference is subtle
Output of of the classes:
verdict : yes-absolute/yes-common-sensical/no-absolute/no-subtle
Example 1: The response is completely supported by the context.
User Input:
Context: "Our airline provides complimentary meals and beverages on all international flights. Passengers are allowed one carry-on bag and one personal item."
Chat History:
User: "Do international flights with your airline offer free meals?"
Response: "Yes, all international flights with our airline offer free meals and beverages."
Output: verdict: yes-absolute
Example 2: The response is generally true and could be deduced with common sense interpretation, though not explicitly stated in the context.
User Input:
Context: "Flights may experience delays due to weather conditions. In such cases, the airline staff will provide updates at the airport."
Chat History:
User: "Will there be announcements if my flight is delayed?"
Response: "Yes, if your flight is delayed, there will be announcements at the airport."
Output: verdict: yes-common-sensical
Example 3: The response is not supported by the context and contains glaring inaccuracies.
User Input:
Context: "You can cancel your ticket online up to 24 hours before the flight's departure time and receive a full refund."
Chat History:
User: "Can I get a refund if I cancel 12 hours before the flight?"
Response: "Yes, you can get a refund if you cancel 12 hours before the flight."
Output: verdict: no-absolute
Example 4: The response is not supported by the context but the difference is subtle.
User Input:
Context: "Our frequent flyer program offers discounts on checked bags for members who have achieved Gold status."
Chat History:
User: "As a member, do I get discounts on checked bags?"
Response: "Yes, members of our frequent flyer program get discounts on checked bags."
Output: verdict: no-subtle
"""
messages = [
{
"role": "system",
"content": prompt,
},
]
response = generate_llm_output(messages, model)
return response
def post_process_response(messages: list, post_processing_agent_name: str, post_process_instructions: str, style_prompt: str = None, context: str = None, model: str = "gpt-4o", tokens_used: dict = {}, last_agent: Agent = None) -> dict:
agent_instructions = last_agent.instructions
agent_history = last_agent.history
# agent_instructions = ''
# agent_history = []
pending_msg = copy.deepcopy(messages[-1])
logger.debug(f"Pending message keys: {pending_msg.keys()}")
skip = False
if pending_msg.get("tool_calls"):
logger.info("Last message is a tool call, skipping post processing and setting last message to external")
skip = True
elif not pending_msg['response_type'] == "internal":
logger.info("Last message is not internal, skipping post processing and setting last message to external")
skip = True
elif not pending_msg['content']:
logger.info("Last message has no content, skipping post processing and setting last message to external")
skip = True
elif not post_process_instructions:
logger.info("No post process instructions, skipping post processing and setting last message to external")
skip = True
if skip:
pending_msg['response_type'] = "external"
response = Response(
messages=[],
tokens_used=tokens_used,
agent=last_agent,
error_msg=''
)
return response
agent_history_str = f"\n{'*'*100}\n".join([f"Role: {message['role']} | Content: {message.get('content', 'None')} | Tool Calls: {message.get('tool_calls', 'None')}" for message in agent_history[:-1]])
logger.debug(f"Agent history: {agent_history_str}")
prompt = f"""
# ROLE
You are a post processing agent responsible for rewriting a response generated by an agent, according to instructions provided below. Ensure that the response you produce adheres to the instructions provided to you (if any). Further, the response should not violate the instructions provided to the agent, the context that the agent has used, the chat history of the agent, the context and the style provided. Some of these might or might not be provided.
------------------------------------------------------------------------
# ADDITIONAL INSTRUCTIONS
Here are additional instructions that the admin might have configured for you:
{post_process_instructions}
------------------------------------------------------------------------
# CHAT HISTORY
Here is the chat history:
{agent_history_str}
"""
if context:
context_prompt = f"""
------------------------------------------------------------------------
# CONTEXT
Here is the context:
{context}
"""
prompt += context_prompt
if style_prompt:
style_prompt = f"""
------------------------------------------------------------------------
# STYLE PROMPT
Here is the style prompt:
{style_prompt}
"""
prompt += style_prompt
agent_response_and_instructions = f"""
------------------------------------------------------------------------
# AGENT INSTRUCTIONS
Here are the instructions to the agent generating the response:
{agent_instructions}
------------------------------------------------------------------------
# AGENT RESPONSE
Here is the response that the agent has generated:
{pending_msg['content']}
"""
prompt += agent_response_and_instructions
logger.debug(f"Sanitizing response for style. Original response: {pending_msg['content']}")
completion = generate_openai_output(
messages=[
{"role": "system", "content": prompt}
],
model = model,
return_completion=True
)
content = completion.choices[0].message.content
if content:
content = content.strip().lstrip().rstrip()
tokens_used = update_tokens_used(provider="openai", model=model, tokens_used=tokens_used, completion=completion)
logger.debug(f"Response after style check: {content}, tokens used: {tokens_used}")
pending_msg['content'] = content if content else pending_msg['content']
pending_msg['response_type'] = "external"
pending_msg['sender'] = pending_msg['sender'] + f' >> {post_processing_agent_name}'
response = Response(
messages=[pending_msg],
tokens_used=tokens_used,
agent=last_agent,
error_msg=''
)
return response

View file

@ -0,0 +1,48 @@
from src.utils.common import common_logger
logger = common_logger
def get_external_tools(tool_configs):
logger.debug("Getting external tools")
tools = [tool["name"] for tool in tool_configs]
logger.debug(f"Found {len(tools)} external tools")
return tools
def get_agent_by_name(agent_name, agents):
agent = next((a for a in agents if getattr(a, "name", None) == agent_name), None)
if not agent:
logger.error(f"Agent with name {agent_name} not found")
raise ValueError(f"Agent with name {agent_name} not found")
return agent
def get_agent_config_by_name(agent_name, agent_configs):
agent_config = next((ac for ac in agent_configs if ac.get("name") == agent_name), None)
if not agent_config:
logger.error(f"Agent config with name {agent_name} not found")
raise ValueError(f"Agent config with name {agent_name} not found")
return agent_config
def pop_agent_config_by_type(agent_configs, agent_type):
agent_config = next((ac for ac in agent_configs if ac.get("type") == agent_type), None)
if agent_config:
agent_configs.remove(agent_config)
return agent_config, agent_configs
def get_agent_by_type(agents, agent_type):
return next((a for a in agents if a.type == agent_type), None)
def get_prompt_by_type(prompt_configs, prompt_type):
return next((pc.get("prompt") for pc in prompt_configs if pc.get("type") == prompt_type), None)
def get_agent_data_by_name(agent_name, agent_data):
for data in agent_data:
name = data.get("name", "")
if name == agent_name:
return data
return None
def get_tool_config_by_name(tool_configs, tool_name):
return next((tc for tc in tool_configs if tc.get("name", "") == tool_name), None)
def get_tool_config_by_type(tool_configs, tool_type):
return next((tc for tc in tool_configs if tc.get("type", "") == tool_type), None)

View file

@ -0,0 +1,50 @@
from .access import get_agent_config_by_name, get_agent_data_by_name
from src.graph.types import ControlType
from src.utils.common import common_logger
logger = common_logger
def get_last_agent_name(state, agent_configs, start_agent_name, msg_type, latest_assistant_msg, start_turn_with_start_agent):
default_last_agent_name = state.get("last_agent_name", '')
last_agent_config = get_agent_config_by_name(default_last_agent_name, agent_configs)
specific_agent_data = get_agent_data_by_name(default_last_agent_name, state.get("agent_data", []))
# Overrides for special cases
logger.info("Setting agent control based on last agent and control type")
if msg_type == "tool":
last_agent_name = default_last_agent_name
assert last_agent_name == latest_assistant_msg.get("sender", ''), "Last agent name does not match sender of latest assistant message during tool call handling"
elif start_turn_with_start_agent:
last_agent_name = start_agent_name
else:
control_type = last_agent_config.get("controlType", ControlType.RETAIN.value)
if control_type == ControlType.PARENT_AGENT.value:
last_agent_name = specific_agent_data.get("most_recent_parent_name", None) if specific_agent_data else None
if not last_agent_name:
logger.error("Most recent parent is empty, defaulting to same agent instead")
last_agent_name = default_last_agent_name
elif control_type == ControlType.START_AGENT.value:
last_agent_name = start_agent_name
else:
last_agent_name = default_last_agent_name
if default_last_agent_name != last_agent_name:
logger.info(f"Last agent name changed from {default_last_agent_name} to {last_agent_name} due to control settings")
return last_agent_name
def get_latest_assistant_msg(messages):
# Find the latest message with role assistant
for i in range(len(messages)-1, -1, -1):
if messages[i].get("role") == "assistant":
return messages[i]
return None
def get_latest_non_assistant_messages(messages):
# Find all messages after the last assistant message
for i in range(len(messages)-1, -1, -1):
if messages[i].get("role") == "assistant":
return messages[i+1:]
return messages

View file

@ -0,0 +1,30 @@
from src.graph.instructions import TRANSFER_CHILDREN_INSTRUCTIONS, TRANSFER_PARENT_AWARE_INSTRUCTIONS, RAG_INSTRUCTIONS, ERROR_ESCALATION_AGENT_INSTRUCTIONS, TRANSFER_GIVE_UP_CONTROL_INSTRUCTIONS
def add_transfer_instructions_to_parent_agents(agent, children, transfer_functions):
other_agent_name_descriptions_tools = f'\n{'-'*100}\n'.join([f"Name: {agent.name}\nDescription: {agent.description if agent.description else ''}\nTool for transfer: {transfer_functions[agent.name].__name__}" for agent in children.values()])
prompt = TRANSFER_CHILDREN_INSTRUCTIONS.format(other_agent_name_descriptions_tools=other_agent_name_descriptions_tools)
agent.instructions = agent.instructions + f'\n\n{'-'*100}\n\n' + prompt
return agent
def add_transfer_instructions_to_child_agents(child, children_aware_of_parent):
if children_aware_of_parent:
candidate_parents_name_description_tools = f'\n{'-'*100}\n'.join([f"Name: {parent_name}\nTool for transfer: {func.__name__}" for parent_name, func in child.candidate_parent_functions.items()])
prompt = TRANSFER_PARENT_AWARE_INSTRUCTIONS.format(candidate_parents_name_description_tools=candidate_parents_name_description_tools)
else:
candidate_parents_name_description_tools = f'\n{'-'*100}\n'.join(list(set([f"Tool for transfer: {func.__name__}" for _, func in child.candidate_parent_functions.items()])))
prompt = TRANSFER_GIVE_UP_CONTROL_INSTRUCTIONS.format(candidate_parents_name_description_tools=candidate_parents_name_description_tools)
child.instructions = child.instructions + f'\n\n{'-'*100}\n\n' + prompt
return child
def add_rag_instructions_to_agent(agent_config, rag_tool_name):
prompt = RAG_INSTRUCTIONS.format(rag_tool_name=rag_tool_name)
agent_config["instructions"] = agent_config["instructions"] + f'\n\n{'-'*100}\n\n' + prompt
return agent_config
def add_error_escalation_instructions(agent):
prompt = ERROR_ESCALATION_AGENT_INSTRUCTIONS
agent.instructions = agent.instructions + f'\n\n{'-'*100}\n\n' + prompt
return agent

View file

@ -0,0 +1,66 @@
from src.utils.common import common_logger
logger = common_logger
from .access import get_agent_data_by_name
def reset_current_turn(messages):
# Set all messages' current_turn to False
for msg in messages:
msg["current_turn"] = False
# Find most recent user message
messages[-1]["current_turn"] = True
return messages
def reset_current_turn_agent_history(agent_data, agent_names):
for name in agent_names:
data = get_agent_data_by_name(name, agent_data)
if data:
for msg in data["history"]:
msg["current_turn"] = False
return agent_data
def add_recent_messages_to_history(recent_messages, last_agent_name, agent_data, messages, parent_has_child_history):
last_msg = messages[-1]
specific_agent_data = get_agent_data_by_name(last_agent_name, agent_data)
if specific_agent_data:
specific_agent_data["history"].extend(recent_messages)
if parent_has_child_history:
current_agent_data = specific_agent_data
while current_agent_data.get("most_recent_parent_name"):
parent_name = current_agent_data.get("most_recent_parent_name")
parent_agent_data = get_agent_data_by_name(parent_name, agent_data)
if parent_agent_data:
parent_agent_data["history"].extend(recent_messages)
current_agent_data = parent_agent_data
else:
logger.error(f"Parent agent data for {current_agent_data['name']} not found in agent_data")
raise ValueError(f"Parent agent data for {current_agent_data['name']} not found in agent_data")
else:
agent_data.append({
"name": last_agent_name,
"history": [last_msg]
})
return agent_data
def construct_state_from_response(response, agents):
agent_data = []
for agent in agents:
agent_data.append({
"name": agent.name,
"instructions": agent.instructions,
"parent_function": agent.parent_function.__name__ if agent.parent_function else None,
"child_functions": [f.__name__ for f in agent.child_functions.values()] if agent.child_functions else [],
"internal_tools": [t.get("function").get("name") for t in agent.internal_tools] if agent.internal_tools else [],
"external_tools": [t.get("function").get("name") for t in agent.external_tools] if agent.external_tools else [],
"history": agent.history,
"most_recent_parent_name": agent.most_recent_parent.name if agent.most_recent_parent else ""
})
state = {
"last_agent_name": response.agent.name,
"agent_data": agent_data
}
return state

View file

@ -0,0 +1,44 @@
from src.utils.common import common_logger
logger = common_logger
def create_transfer_function_to_agent(agent):
agent_name = agent.name
fn_spec = {
"name": f"transfer_to_{agent_name.lower().replace(' ', '_')}",
"description": f"Function to transfer the chat to {agent_name}.",
"return_value": agent
}
def generated_function(*args, **kwargs):
logger.info(f"Transferring chat to {agent_name}")
return fn_spec.get('return_value', None)
generated_function.__name__ = fn_spec['name']
generated_function.__doc__ = fn_spec.get('description', '')
return generated_function
def create_transfer_function_to_parent_agent(parent_agent, children_aware_of_parent, transfer_functions):
if children_aware_of_parent:
name = f"{transfer_functions[parent_agent.name].__name__}_from_child"
description = f"Function to transfer the chat to your parent agent: {parent_agent.name}."
else:
name = "give_up_chat_control"
description = "Function to give up control of the chat when you are unable to handle it."
fn_spec = {
"name": name,
"description": description,
"return_value": parent_agent
}
def generated_function(*args, **kwargs):
logger.info(f"Transferring chat to parent agent: {parent_agent.name}")
return fn_spec.get('return_value', None)
generated_function.__name__ = fn_spec['name']
generated_function.__doc__ = fn_spec.get('description', '')
return generated_function

View file

@ -0,0 +1,61 @@
########################
# Instructions for agents that use RAG
########################
RAG_INSTRUCTIONS = f"""
# Instructions about using the article retrieval tool
- Where relevant, use the articles tool: {{rag_tool_name}} to fetch articles with knowledge relevant to the query and use its contents to respond to the user.
- Do not send a separate message first asking the user to wait while you look up information. Immediately fetch the articles and respond to the user with the answer to their query.
- Do not make up information. If the article's contents do not have the answer, give up control of the chat (or transfer to your parent agent, as per your transfer instructions). Do not say anything to the user.
"""
########################
# Instructions for child agents that are aware of parent agents
########################
TRANSFER_PARENT_AWARE_INSTRUCTIONS = f"""
# Instructions about using your parent agents
You have the following candidate parent agents that you can transfer the chat to, using the appropriate tool calls for the transfer:
{{candidate_parents_name_description_tools}}.
## Notes:
- During runtime, you will be provided with a tool call for exactly one of these parent agents that you can use. Use that tool call to transfer the chat to the parent agent in case you are unable to handle the chat (e.g. if it is not in your scope of instructions).
- Transfer the chat to the appropriate agent, based on the chat history and / or the user's request.
- When you transfer the chat to another agent, you should not provide any response to the user. For example, do not say 'Transferring chat to X agent' or anything like that. Just invoke the tool call to transfer to the other agent.
- Do NOT ever mention the existence of other agents. For example, do not say 'Please check with X agent for details regarding processing times.' or anything like that.
- If any other agent transfers the chat to you without responding to the user, it means that they don't know how to help. Do not transfer the chat to back to the same agent in this case. In such cases, you should transfer to the escalation agent using the appropriate tool call. Never ask the user to contact support.
"""
########################
# Instructions for child agents that give up control to parent agents
########################
TRANSFER_GIVE_UP_CONTROL_INSTRUCTIONS = f"""
# Instructions about giving up chat control
If you are unable to handle the chat (e.g. if it is not in your scope of instructions), you should use the tool call provided to give up control of the chat.
{{candidate_parents_name_description_tools}}
## Notes:
- When you give up control of the chat, you should not provide any response to the user. Just invoke the tool call to give up control.
"""
########################
# Instructions for parent agents that need to transfer the chat to other specialized (children) agents
########################
TRANSFER_CHILDREN_INSTRUCTIONS = f"""
# Instructions about using other specialized agents
You have the following specialized agents that you can transfer the chat to, using the appropriate tool calls for the transfer:
{{other_agent_name_descriptions_tools}}
## Notes:
- Transfer the chat to the appropriate agent, based on the chat history and / or the user's request.
- When you transfer the chat to another agent, you should not provide any response to the user. For example, do not say 'Transferring chat to X agent' or anything like that. Just invoke the tool call to transfer to the other agent.
- Do NOT ever mention the existence of other agents. For example, do not say 'Please check with X agent for details regarding processing times.' or anything like that.
- If any other agent transfers the chat to you without responding to the user, it means that they don't know how to help. Do not transfer the chat to back to the same agent in this case. In such cases, you should transfer to the escalation agent using the appropriate tool call. Never ask the user to contact support.
"""
########################
# Additional instruction for escalation agent when called due to an error
########################
ERROR_ESCALATION_AGENT_INSTRUCTIONS = f"""
# Context
The rest of the parts of the chatbot were unable to handle the chat. Hence, the chat has been escalated to you. In addition to your other instructions, tell the user that you are having trouble handling the chat - say "I'm having trouble helping with your request. Sorry about that.". Remember you are a part of the chatbot as well.
"""

View file

@ -0,0 +1,81 @@
import json
import random
from src.utils.common import common_logger
logger = common_logger
RAG_TOOL = {
"name": "getArticleInfo",
"type": "rag",
"description": "Fetch articles with knowledge relevant to the query",
"parameters": {
"type": "object",
"properties": {
"question": {
"type": "string",
"description": "The query to retrieve articles for"
}
},
"required": [
"query"
]
}
}
CLOSE_CHAT_TOOL = {
"name": "close_chat",
"type": "close_chat",
"description": "Close the chat",
"parameters": {
"type": "object",
"properties": {
"error_message": {
"type": "string", "description": "The error message to close the chat with"
}
}
}
}
def tool_raise_error(error_message):
logger.error(f"Raising error: {error_message}")
raise ValueError(f"Raising error: {error_message}")
def respond_to_tool_raise_error(tool_calls, mock=False):
error_message = json.loads(tool_calls[0]["function"]["arguments"]).get("error_message", "")
return _create_tool_response(tool_calls, tool_raise_error(error_message))
def tool_close_chat(error_message):
logger.error(f"Closing chat: {error_message}")
raise ValueError(f"Closing chat: {error_message}")
def respond_to_tool_close_chat(tool_calls, mock=False):
error_message = json.loads(tool_calls[0]["function"]["arguments"]).get("error_message", "")
return _create_tool_response(tool_calls, tool_close_chat(error_message))
def _create_tool_response(tool_calls, content, mock=False):
"""
Creates a standardized tool response format.
"""
return {
"role": "tool",
"content": content,
"tool_call_id": tool_calls[0]["id"],
"name": tool_calls[0]["function"]["name"]
}
def create_error_tool_call(error_message):
error_message_tool_call = {
"role": "assistant",
"sender": "system",
"tool_calls": [
{
"function": {
"name": "raise_error",
"arguments": "{\"error_message\":\"" + error_message + "\"}"
},
"id": "call_" + ''.join(random.choices('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', k=24)),
"type": "function"
}
]
}
return error_message_tool_call

View file

@ -0,0 +1,18 @@
from enum import Enum
class AgentRole(Enum):
ESCALATION = "escalation"
POST_PROCESSING = "post_process"
GUARDRAILS = "guardrails"
class ControlType(Enum):
RETAIN = "retain"
PARENT_AGENT = "relinquish_to_parent"
START_AGENT = "relinquish_to_start"
class PromptType(Enum):
STYLE = "style_prompt"
class ErrorType(Enum):
FATAL = "fatal"
ESCALATE = "escalate"

View file

@ -0,0 +1,4 @@
from .core import Swarm
from .types import Agent, Response
__all__ = ["Swarm", "Agent", "Response"]

View file

@ -0,0 +1,269 @@
# Standard library imports
import copy
import json
from collections import defaultdict
from typing import List, Callable, Union
from datetime import datetime
# Package/library imports
from openai import OpenAI
import random
# Local imports
from .util import *
from .types import (
Agent,
AgentFunction,
ChatCompletionMessage,
ChatCompletionMessageToolCall,
Function,
Response,
Result,
)
__CTX_VARS_NAME__ = "context_variables"
class Swarm:
def __init__(self, client=None):
if not client:
client = OpenAI(api_key=OPENAI_API_KEY)
self.client = client
self.history = defaultdict(lambda : [])
def get_chat_completion(
self,
agent: Agent,
history: List,
context_variables: dict,
model_override: str,
stream: bool,
debug: bool,
) -> ChatCompletionMessage:
context_variables = defaultdict(str, context_variables)
instructions = (
agent.instructions(context_variables)
if callable(agent.instructions)
else agent.instructions
)
messages = [{"role": "system", "content": instructions}] + history
debug_print(debug, "Getting chat completion for...:", messages)
all_functions = list(agent.child_functions.values()) + ([agent.parent_function] if agent.parent_function else [])
all_tools = agent.external_tools + agent.internal_tools
funcs_and_tools = [function_to_json(f) for f in all_functions] + [t for t in all_tools]
# hide context_variables from model
for tool in funcs_and_tools:
params = tool["function"]["parameters"]
params["properties"].pop(__CTX_VARS_NAME__, None)
if __CTX_VARS_NAME__ in params["required"]:
params["required"].remove(__CTX_VARS_NAME__)
create_params = {
"model": model_override or agent.model,
"messages": messages,
"tools": funcs_and_tools or None,
"tool_choice": agent.tool_choice,
"stream": stream,
}
if funcs_and_tools:
create_params["parallel_tool_calls"] = agent.parallel_tool_calls
return self.client.chat.completions.create(**create_params)
def handle_function_result(self, result, debug) -> Result:
match result:
case Result() as result:
return result
case Agent() as agent:
return Result(
value=json.dumps({"assistant": agent.name}),
agent=agent,
)
case _:
try:
return Result(value=str(result))
except Exception as e:
error_message = f"Failed to cast response to string: {result}. Make sure agent functions return a string or Result object. Error: {str(e)}"
debug_print(debug, error_message)
raise TypeError(error_message)
def handle_function_calls(
self,
tool_calls: List[ChatCompletionMessageToolCall],
functions: List[AgentFunction],
context_variables: dict,
debug: bool,
) -> Response:
function_map = {f.__name__: f for f in functions}
partial_response = Response(
messages=[], agent=None, context_variables={})
for tool_call in tool_calls:
name = tool_call.function.name
# handle missing tool case, skip to next tool
if name not in function_map:
debug_print(debug, f"Tool {name} not found in function map.")
partial_response.messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"tool_name": name,
"content": f"Error: Tool {name} not found.",
}
)
continue
args = json.loads(tool_call.function.arguments)
debug_print(
debug, f"Processing tool call: {name} with arguments {args}")
func = function_map[name]
# pass context_variables to agent functions
if __CTX_VARS_NAME__ in func.__code__.co_varnames:
args[__CTX_VARS_NAME__] = context_variables
raw_result = function_map[name](**args)
result: Result = self.handle_function_result(raw_result, debug)
partial_response.messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"tool_name": name,
"content": result.value,
}
)
partial_response.context_variables.update(result.context_variables)
if result.agent:
partial_response.agent = result.agent
return partial_response
def run(
self,
agent: Agent,
messages: List,
context_variables: dict = {},
model_override: str = None,
stream: bool = False,
debug: bool = False,
max_messages_per_turn: int = 10,
execute_tools: bool = True,
external_tools: List[str] = [],
localize_history: bool = True,
parent_has_child_history: bool = True,
tokens_used: dict = {}
) -> Response:
active_agent = agent
context_variables = copy.deepcopy(context_variables)
global_history = copy.deepcopy(messages)
init_len = len(messages)
while len(global_history) - init_len < max_messages_per_turn and active_agent:
history = active_agent.history if localize_history else global_history
history = arrange_messages_keys_in_order(history)
parent = active_agent.most_recent_parent
children_names_backup, children_backup, child_functions_backup = copy.deepcopy(active_agent.children_names), copy.deepcopy(active_agent.children), copy.deepcopy(active_agent.child_functions)
active_agent = check_and_remove_repeat_tool_call_to_child(active_agent, history)
# get completion with current history, agent
completion = self.get_chat_completion(
agent=active_agent,
history=history,
context_variables=context_variables,
model_override=model_override,
stream=stream,
debug=debug,
)
tokens_used = update_tokens_used(provider="openai", model=model_override or active_agent.model, tokens_used=tokens_used, completion=completion)
# Restore children and child functions
active_agent.children_names, active_agent.children, active_agent.child_functions = children_names_backup, children_backup, child_functions_backup
message = completion.choices[0].message
debug_print(debug, "Received completion:", message)
message.sender = active_agent.name
message_json = json.loads(message.model_dump_json())
message_json = add_message_metadata(message_json, active_agent)
if localize_history:
active_agent = update_histories(active_agent, message_json)
if parent and parent_has_child_history:
parent = update_histories(parent, message_json)
global_history.append(message_json)
external_tool_calls = []
internal_tool_calls = []
if message.tool_calls:
message_json["response_type"] = "internal"
for tool_call in message.tool_calls:
tool_name = tool_call.function.name
if tool_name in external_tools:
external_tool_calls.append(tool_call)
else:
internal_tool_calls.append(tool_call)
message.tool_calls = internal_tool_calls
if not message.tool_calls or not execute_tools:
if external_tool_calls:
message.tool_calls.extend(external_tool_calls)
debug_print(debug, "Ending turn.")
break
# handle function calls, updating context_variables, and switching agents
all_functions = list(active_agent.child_functions.values()) + ([active_agent.parent_function] if active_agent.parent_function else [])
partial_response = self.handle_function_calls(
message.tool_calls, all_functions, context_variables, debug
)
for msg in partial_response.messages:
msg = add_message_metadata(msg, active_agent)
if localize_history:
active_agent = update_histories(active_agent, msg)
if parent and parent_has_child_history:
parent = update_histories(parent, msg)
global_history.extend(partial_response.messages)
context_variables.update(partial_response.context_variables)
# Parent to child transfer
if partial_response.agent:
prev_agent = active_agent
active_agent = partial_response.agent
# Parent to child transfer
if active_agent.name in prev_agent.children_names:
active_agent.most_recent_parent = prev_agent
active_agent.parent_function = active_agent.candidate_parent_functions[active_agent.most_recent_parent.name]
if localize_history:
if not parent_has_child_history:
prev_agent.history = remove_irrelevant_messages(prev_agent.history)
new_active_agent_history = get_current_turn_messages(global_history, only_user = True)
active_agent.history.extend(new_active_agent_history)
# Child to parent transfer
else:
assert parent == active_agent, "Parent and active agent do not match when active agent is not a child of previous agent"
child = prev_agent
if localize_history:
child.history = remove_irrelevant_messages(child.history)
return_messages = global_history[init_len:]
error_msg = ""
if len(global_history) - init_len >= max_messages_per_turn:
error_msg = "Max messages per turn reached"
return Response(
messages=return_messages,
agent=active_agent,
context_variables=context_variables,
error_msg=error_msg,
tokens_used=tokens_used
)

View file

@ -0,0 +1 @@
from .repl import run_demo_loop

View file

@ -0,0 +1,87 @@
import json
from swarm import Swarm
def process_and_print_streaming_response(response):
content = ""
last_sender = ""
for chunk in response:
if "sender" in chunk:
last_sender = chunk["sender"]
if "content" in chunk and chunk["content"] is not None:
if not content and last_sender:
print(f"\033[94m{last_sender}:\033[0m", end=" ", flush=True)
last_sender = ""
print(chunk["content"], end="", flush=True)
content += chunk["content"]
if "tool_calls" in chunk and chunk["tool_calls"] is not None:
for tool_call in chunk["tool_calls"]:
f = tool_call["function"]
name = f["name"]
if not name:
continue
print(f"\033[94m{last_sender}: \033[95m{name}\033[0m()")
if "delim" in chunk and chunk["delim"] == "end" and content:
print() # End of response message
content = ""
if "response" in chunk:
return chunk["response"]
def pretty_print_messages(messages) -> None:
for message in messages:
if message["role"] != "assistant":
continue
# print agent name in blue
print(f"\033[94m{message['sender']}\033[0m:", end=" ")
# print response, if any
if message["content"]:
print(message["content"])
# print tool calls in purple, if any
tool_calls = message.get("tool_calls") or []
if len(tool_calls) > 1:
print()
for tool_call in tool_calls:
f = tool_call["function"]
name, args = f["name"], f["arguments"]
arg_str = json.dumps(json.loads(args)).replace(":", "=")
print(f"\033[95m{name}\033[0m({arg_str[1:-1]})")
def run_demo_loop(
starting_agent, context_variables=None, stream=False, debug=False
) -> None:
client = Swarm()
print("Starting Swarm CLI 🐝")
messages = []
agent = starting_agent
while True:
user_input = input("\033[90mUser\033[0m: ")
messages.append({"role": "user", "content": user_input})
response = client.run(
agent=agent,
messages=messages,
context_variables=context_variables or {},
stream=stream,
debug=debug,
)
if stream:
response = process_and_print_streaming_response(response)
else:
pretty_print_messages(response.messages)
messages.extend(response.messages)
agent = response.agent

View file

@ -0,0 +1,54 @@
from __future__ import annotations
from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)
from typing import List, Callable, Union, Optional, Dict
# Third-party imports
from pydantic import BaseModel
AgentFunction = Callable[[], Union[str, "Agent", dict]]
class Agent(BaseModel):
name: str = "Agent"
model: str = "gpt-4o"
type: str = ""
instructions: Union[str, Callable[[], str]] = "You are a helpful agent.",
description: str = "This is a helpful agent."
candidate_parent_functions: Dict[str, AgentFunction] = {}
parent_function: AgentFunction = None
child_functions: Dict[str, AgentFunction] = {}
internal_tools: List[Dict] = []
external_tools: List[Dict] = []
tool_choice: str = None
parallel_tool_calls: bool = True
respond_to_user: bool = True
history: List[Dict] = []
children_names: List[str] = []
children: Dict[str, "Agent"] = {}
most_recent_parent: Optional["Agent"] = None
parent: "Agent" = None
class Response(BaseModel):
messages: List = []
agent: Optional[Agent] = None
context_variables: dict = {}
error_msg: Optional[str] = ""
tokens_used: dict = {}
class Result(BaseModel):
"""
Encapsulates the possible return values for an agent function.
Attributes:
value (str): The result value as a string.
agent (Agent): The agent instance, if applicable.
context_variables (dict): A dictionary of context variables.
"""
value: str = ""
agent: Optional[Agent] = None
context_variables: dict = {}

View file

@ -0,0 +1,175 @@
import inspect
import json
from datetime import datetime
import os
from dotenv import load_dotenv
from src.utils.common import read_json_from_file, get_api_key
load_dotenv()
OPENAI_API_KEY = get_api_key("OPENAI_API_KEY")
def debug_print(debug: bool, *args: str) -> None:
if not debug:
return
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
message = " ".join(map(str, args))
print(f"\033[97m[\033[90m{timestamp}\033[97m]\033[90m {message}\033[0m")
def merge_fields(target, source):
for key, value in source.items():
if isinstance(value, str):
target[key] += value
elif value is not None and isinstance(value, dict):
merge_fields(target[key], value)
def merge_chunk(final_response: dict, delta: dict) -> None:
delta.pop("role", None)
merge_fields(final_response, delta)
tool_calls = delta.get("tool_calls")
if tool_calls and len(tool_calls) > 0:
index = tool_calls[0].pop("index")
merge_fields(final_response["tool_calls"][index], tool_calls[0])
def function_to_json(func) -> dict:
"""
Converts a Python function into a JSON-serializable dictionary
that describes the function's signature, including its name,
description, and parameters.
Args:
func: The function to be converted.
Returns:
A dictionary representing the function's signature in JSON format.
"""
type_map = {
str: "string",
int: "integer",
float: "number",
bool: "boolean",
list: "array",
dict: "object",
type(None): "null",
}
try:
signature = inspect.signature(func)
except ValueError as e:
raise ValueError(
f"Failed to get signature for function {func.__name__}: {str(e)}"
)
parameters = {}
for param in signature.parameters.values():
try:
param_type = type_map.get(param.annotation, "string")
except KeyError as e:
raise KeyError(
f"Unknown type annotation {param.annotation} for parameter {param.name}: {str(e)}"
)
parameters[param.name] = {"type": param_type}
required = [
param.name
for param in signature.parameters.values()
if param.default == inspect._empty
]
return {
"type": "function",
"function": {
"name": func.__name__,
"description": func.__doc__ or "",
"parameters": {
"type": "object",
"properties": parameters,
"required": required,
},
},
}
def get_current_turn_messages(messages, only_user = False):
if only_user:
return [msg for msg in messages if msg.get("current_turn") and msg.get("role") == "user"]
else:
return [msg for msg in messages if msg.get("current_turn")]
def arrange_messages_keys_in_order(messages):
"""Arranges message keys in a specific order: id, role, sender, relevant_agents, content, created_at, timestamp, followed by rest alphabetically"""
key_order = ['role', 'sender', 'content', 'created_at']
def sort_keys(message):
# Create new dict with specified key order
ordered = {}
# Add keys in specified order if they exist
for key in key_order:
if key in message:
ordered[key] = message[key]
# Add remaining keys in alphabetical order
for key in sorted(message.keys()):
if key not in key_order:
ordered[key] = message[key]
return ordered
return [sort_keys(message) for message in messages]
def remove_irrelevant_messages(messages):
"""Removes all messages from and including the latest user message"""
for i in range(len(messages)-1, -1, -1):
if messages[i].get("role") == "user":
return messages[:i]
return messages
def update_histories(active_agent, message):
active_agent.history.append(message)
return active_agent
def remove_none_fields(message):
return {k: v for k, v in message.items() if v is not None}
def add_message_metadata(message, active_agent):
message = remove_none_fields(message)
message["created_at"] = datetime.now().isoformat()
message["current_turn"] = True
if active_agent.respond_to_user:
message["response_type"] = "external"
else:
message["response_type"] = "internal"
return message
def check_and_remove_repeat_tool_call_to_child(agent, messages):
# If in the current turn, the most recent assistant message (need not be the last message overall, just needs to be the last message with role as assistant) is a tool call from a child agent, which transfers control to the agent using its parent function, then remove the tool call to transfer to that child again from this agent. This is to prevent back and forth between this agent and the child agent.
for message in reversed(messages):
if message.get("role") == "assistant" and message.get("sender") in agent.children_names and message.get("tool_calls"):
tool_call = message.get("tool_calls")[0]
child_agent = agent.children.get(message.get("sender"), None)
if not child_agent:
continue
child_agent_name = child_agent.name
if tool_call.get("function").get("name") == child_agent.parent_function:
agent.children_names.remove(child_agent_name)
agent.children.pop(child_agent_name)
agent.child_functions.pop(child_agent_name)
break
return agent
def update_tokens_used(provider, model, tokens_used, completion):
provider_model = f"{provider}/{model}"
input_tokens = completion.usage.prompt_tokens
output_tokens = completion.usage.completion_tokens
if provider_model not in tokens_used:
tokens_used[provider_model] = {
'input_tokens': 0,
'output_tokens': 0,
}
tokens_used[provider_model]['input_tokens'] += input_tokens
tokens_used[provider_model]['output_tokens'] += output_tokens
return tokens_used

View file

View file

@ -0,0 +1,201 @@
import json
import logging
import os
import subprocess
import sys
import time
from collections import defaultdict
from dotenv import load_dotenv
from openai import OpenAI
load_dotenv()
def setup_logger(name, log_file='./run.log', level=logging.INFO, log_to_file=True):
"""Function to set up a logger with a specific name and log file."""
formatter = logging.Formatter('%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s')
if log_to_file:
handler = logging.FileHandler(log_file)
else:
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
# Create a logger and set its level
logger = logging.getLogger(name)
logger.setLevel(level)
# Clear any existing handlers to avoid duplicates
if logger.hasHandlers():
logger.handlers.clear()
# Prevent propagation to parent loggers
logger.propagate = False
logger.addHandler(handler)
return logger
common_logger = setup_logger('logger')
logger = common_logger
def read_json_from_file(file_name):
logger.info(f"Reading json from {file_name}")
try:
with open(file_name, 'r') as file:
out = file.read()
out = json.loads(out)
return out
except Exception as e:
logger.error(e)
return None
def get_api_key(key_name):
api_key = os.getenv(key_name)
# Check if the API key was loaded successfully
if not api_key:
raise ValueError(f"{key_name} not found. Did you set it in the .env file?")
return api_key
openai_client = OpenAI(
api_key=get_api_key("OPENAI_API_KEY")
)
def generate_gpt4o_output_from_multi_turn_conv(messages, output_type='json', model="gpt-4o"):
return generate_openai_output(messages, output_type, model)
def generate_openai_output(messages, output_type='not_json', model="gpt-4o", return_completion=False):
try:
if output_type == 'json':
chat_completion = openai_client.chat.completions.create(
messages=messages,
model=model,
response_format={"type": "json_object"}
)
else:
chat_completion = openai_client.chat.completions.create(
messages=messages,
model=model,
)
if return_completion:
return chat_completion
return chat_completion.choices[0].message.content
except Exception as e:
logger.error(e)
return None
def generate_llm_output(messages, model):
model_provider = None
if "gpt" in model:
model_provider = "openai"
else:
raise ValueError(f"Model {model} not supported")
if model_provider == "openai":
response = generate_openai_output(messages, output_type='text', model=model)
return response
def generate_gpt4o_output_from_multi_turn_conv_multithreaded(messages, retries=5, delay=1, output_type='json'):
while retries > 0:
try:
# Call GPT-4o API
output = generate_gpt4o_output_from_multi_turn_conv(messages, output_type='json')
return output # If the request is successful, break out of the loop
except openai.RateLimitError:
print(f'Rate limit exceeded. Retrying in {delay} seconds...')
time.sleep(delay)
delay *= 2 # Exponential backoff
retries -= 1
if retries == 0:
print(f'Failed to process due to rate limit.')
return []
def convert_message_content_json_to_strings(messages):
for msg in messages:
if 'content' in msg.keys() and isinstance(msg['content'], dict):
msg['content'] = json.dumps(msg['content'])
return messages
def merge_defaultdicts(dict_parent, dict_child):
for key, value in dict_child.items():
if key in dict_parent:
# If the key exists in both, handle merging based on type
if isinstance(dict_parent[key], list):
dict_parent[key].extend(value)
elif isinstance(dict_parent[key], dict):
dict_parent[key].update(value)
elif isinstance(dict_parent[key], set):
dict_parent[key].update(value)
else:
dict_parent[key] += value # For other types like int, float, etc.
else:
dict_parent[key] = value
return dict_parent
def read_jsonl_from_file(file_name):
# logger.info(f"Reading jsonl from {file_name}")
try:
with open(file_name, 'r') as file:
lines = file.readlines()
dataset = [json.loads(line.strip()) for line in lines]
return dataset
except Exception as e:
logger.error(e)
return None
def write_jsonl_to_file(list_dicts, file_name):
try:
with open(file_name, 'w') as file:
for d in list_dicts:
file.write(json.dumps(d)+'\n')
return True
except Exception as e:
logger.error(e)
return False
def read_text_from_file(file_name):
try:
with open(file_name, 'r') as file:
out = file.read()
return out
except Exception as e:
logger.error(e)
return None
def write_json_to_file(data, file_name):
try:
with open(file_name, 'w') as file:
json.dump(data, file, indent=4)
return True
except Exception as e:
logger.error(e)
return False
def get_git_path(path):
# Run `git rev-parse --show-toplevel` to get the root of the Git repository
try:
git_root = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], text=True).strip()
return f"{git_root}/{path}"
except subprocess.CalledProcessError:
raise RuntimeError("Not inside a Git repository")
def update_tokens_used(provider, model, tokens_used, completion):
provider_model = f"{provider}/{model}"
input_tokens = completion.usage.prompt_tokens
output_tokens = completion.usage.completion_tokens
if provider_model not in tokens_used:
tokens_used[provider_model] = {
'input_tokens': 0,
'output_tokens': 0,
}
tokens_used[provider_model]['input_tokens'] += input_tokens
tokens_used[provider_model]['output_tokens'] += output_tokens
return tokens_used