Add support for system message and updated readme

This commit is contained in:
akhisud3195 2025-01-15 16:03:22 +05:30
parent a22d54fd3c
commit 642b6dc91d
10 changed files with 62 additions and 28 deletions

View file

@ -1,8 +1,9 @@
from flask import Flask, request, jsonify
from src.graph.core import run_turn
from datetime import datetime
from src.graph.core import run_turn
from src.graph.tools import RAG_TOOL, CLOSE_CHAT_TOOL
from src.utils.common import common_logger, read_json_from_file
logger = common_logger

View file

@ -1,5 +1,6 @@
import os
import sys
from copy import deepcopy
from src.swarm.types import Agent
from src.swarm.core import Swarm
@ -10,11 +11,10 @@ from .types import AgentRole, PromptType, ErrorType
from .helpers.access import get_agent_data_by_name, get_agent_by_name, get_agent_config_by_name, get_tool_config_by_name, get_tool_config_by_type, get_external_tools, get_prompt_by_type, pop_agent_config_by_type, get_agent_by_type
from .helpers.transfer import create_transfer_function_to_agent, create_transfer_function_to_parent_agent
from .helpers.state import add_recent_messages_to_history, construct_state_from_response, reset_current_turn, reset_current_turn_agent_history
from .helpers.instructions import add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents, add_rag_instructions_to_agent, add_error_escalation_instructions
from .helpers.instructions import add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents, add_rag_instructions_to_agent, add_error_escalation_instructions, get_universal_system_message
from .helpers.control import get_latest_assistant_msg, get_latest_non_assistant_messages, get_last_agent_name
from src.utils.common import common_logger
from copy import deepcopy
logger = common_logger
def order_messages(messages):
@ -305,6 +305,7 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_
guardrails_agent_config, agent_configs = pop_agent_config_by_type(agent_configs, AgentRole.GUARDRAILS.value)
latest_assistant_msg = get_latest_assistant_msg(messages)
universal_sys_msg = get_universal_system_message(messages)
latest_non_assistant_msgs = get_latest_non_assistant_messages(messages)
msg_type = latest_non_assistant_msgs[-1]["role"]
@ -384,10 +385,10 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_
logger.info(f"Found {len(external_tools)} external tools")
logger.debug("Initializing Swarm client")
client = Swarm()
swarm_client = Swarm()
if not validation_error_msg:
response = client.run(
response = swarm_client.run(
agent=last_agent,
messages=messages,
execute_tools=True,
@ -395,7 +396,8 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_
localize_history=localize_history,
parent_has_child_history=parent_has_child_history,
max_messages_per_turn=max_messages_per_turn,
tokens_used=tokens_used
tokens_used=tokens_used,
universal_sys_msg=universal_sys_msg
)
tokens_used = response.tokens_used
last_agent = response.agent

View file

@ -1,4 +1,4 @@
from src.graph.instructions import TRANSFER_CHILDREN_INSTRUCTIONS, TRANSFER_PARENT_AWARE_INSTRUCTIONS, RAG_INSTRUCTIONS, ERROR_ESCALATION_AGENT_INSTRUCTIONS, TRANSFER_GIVE_UP_CONTROL_INSTRUCTIONS
from src.graph.instructions import TRANSFER_CHILDREN_INSTRUCTIONS, TRANSFER_PARENT_AWARE_INSTRUCTIONS, RAG_INSTRUCTIONS, ERROR_ESCALATION_AGENT_INSTRUCTIONS, TRANSFER_GIVE_UP_CONTROL_INSTRUCTIONS, SYSTEM_MESSAGE
def add_transfer_instructions_to_parent_agents(agent, children, transfer_functions):
other_agent_name_descriptions_tools = f'\n{'-'*100}\n'.join([f"Name: {agent.name}\nDescription: {agent.description if agent.description else ''}\nTool for transfer: {transfer_functions[agent.name].__name__}" for agent in children.values()])
@ -27,4 +27,9 @@ def add_rag_instructions_to_agent(agent_config, rag_tool_name):
def add_error_escalation_instructions(agent):
prompt = ERROR_ESCALATION_AGENT_INSTRUCTIONS
agent.instructions = agent.instructions + f'\n\n{'-'*100}\n\n' + prompt
return agent
return agent
def get_universal_system_message(messages):
if messages and messages[0].get("role") == "system":
return SYSTEM_MESSAGE.format(system_message=messages[0].get("content"))
return ""

View file

@ -58,4 +58,13 @@ You have the following specialized agents that you can transfer the chat to, usi
ERROR_ESCALATION_AGENT_INSTRUCTIONS = f"""
# Context
The rest of the parts of the chatbot were unable to handle the chat. Hence, the chat has been escalated to you. In addition to your other instructions, tell the user that you are having trouble handling the chat - say "I'm having trouble helping with your request. Sorry about that.". Remember you are a part of the chatbot as well.
"""
########################
# Universal system message formatting
########################
SYSTEM_MESSAGE = f"""
# Additional System-Wide Context or Instructions:
{{system_message}}
"""

View file

@ -39,6 +39,7 @@ class Swarm:
model_override: str,
stream: bool,
debug: bool,
universal_sys_msg: str,
) -> ChatCompletionMessage:
context_variables = defaultdict(str, context_variables)
instructions = (
@ -46,7 +47,7 @@ class Swarm:
if callable(agent.instructions)
else agent.instructions
)
messages = [{"role": "system", "content": instructions}] + history
messages = [{"role": "system", "content": instructions + universal_sys_msg}] + history
debug_print(debug, "Getting chat completion for...:", messages)
all_functions = list(agent.child_functions.values()) + ([agent.parent_function] if agent.parent_function else [])
@ -73,22 +74,24 @@ class Swarm:
return self.client.chat.completions.create(**create_params)
def handle_function_result(self, result, debug) -> Result:
match result:
case Result() as result:
return result
case Agent() as agent:
return Result(
value=json.dumps({"assistant": agent.name}),
agent=agent,
)
case _:
try:
return Result(value=str(result))
except Exception as e:
error_message = f"Failed to cast response to string: {result}. Make sure agent functions return a string or Result object. Error: {str(e)}"
debug_print(debug, error_message)
raise TypeError(error_message)
# Check if result is already a Result instance
if isinstance(result, Result):
return result
# Check if result is an Agent instance
if isinstance(result, Agent):
return Result(
value=json.dumps({"assistant": result.name}),
agent=result,
)
# Handle all other cases
try:
return Result(value=str(result))
except Exception as e:
error_message = f"Failed to cast response to string: {result}. Make sure agent functions return a string or Result object. Error: {str(e)}"
debug_print(debug, error_message)
raise TypeError(error_message)
def handle_function_calls(
self,
@ -153,7 +156,8 @@ class Swarm:
external_tools: List[str] = [],
localize_history: bool = True,
parent_has_child_history: bool = True,
tokens_used: dict = {}
tokens_used: dict = {},
universal_sys_msg: str = '',
) -> Response:
active_agent = agent
@ -179,6 +183,7 @@ class Swarm:
model_override=model_override,
stream=stream,
debug=debug,
universal_sys_msg=universal_sys_msg,
)
tokens_used = update_tokens_used(provider="openai", model=model_override or active_agent.model, tokens_used=tokens_used, completion=completion)