Move system message additions outside swarm

This commit is contained in:
akhisud3195 2025-01-24 22:48:44 +05:30 committed by ramnique
parent c1db2b7306
commit 27fd1a069a
3 changed files with 15 additions and 11 deletions

View file

@ -11,7 +11,7 @@ from .types import AgentRole, PromptType, ErrorType
from .helpers.access import get_agent_data_by_name, get_agent_by_name, get_agent_config_by_name, get_tool_config_by_name, get_tool_config_by_type, get_external_tools, get_prompt_by_type, pop_agent_config_by_type, get_agent_by_type
from .helpers.transfer import create_transfer_function_to_agent, create_transfer_function_to_parent_agent
from .helpers.state import add_recent_messages_to_history, construct_state_from_response, reset_current_turn, reset_current_turn_agent_history
from .helpers.instructions import add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents, add_rag_instructions_to_agent, add_error_escalation_instructions, get_universal_system_message
from .helpers.instructions import add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents, add_rag_instructions_to_agent, add_error_escalation_instructions, get_universal_system_message, add_universal_system_message_to_agent
from .helpers.control import get_latest_assistant_msg, get_latest_non_assistant_messages, get_last_agent_name
from src.utils.common import common_logger
@ -50,7 +50,7 @@ def clear_agent_fields(agent):
return agent
def get_agents(agent_configs, tool_configs, localize_history, available_tool_mappings, agent_data, start_turn_with_start_agent, children_aware_of_parent):
def get_agents(agent_configs, tool_configs, localize_history, available_tool_mappings, agent_data, start_turn_with_start_agent, children_aware_of_parent, universal_sys_msg):
# Create Agent objects
agents = []
@ -183,6 +183,9 @@ def get_agents(agent_configs, tool_configs, localize_history, available_tool_map
if agent.most_recent_parent:
assert agent.most_recent_parent.name in agent.candidate_parent_functions, f"Most recent parent {agent.most_recent_parent.name} not found in candidate parent functions for agent {agent.name}"
agent.parent_function = agent.candidate_parent_functions[agent.most_recent_parent.name]
for agent in agents:
agent = add_universal_system_message_to_agent(agent, universal_sys_msg)
return agents
@ -342,6 +345,7 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_
localize_history=localize_history,
start_turn_with_start_agent=start_turn_with_start_agent,
children_aware_of_parent=children_aware_of_parent,
universal_sys_msg=universal_sys_msg
)
if not all_agents:
logger.error("No agents initialized")
@ -396,8 +400,7 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_
localize_history=localize_history,
parent_has_child_history=parent_has_child_history,
max_messages_per_turn=max_messages_per_turn,
tokens_used=tokens_used,
universal_sys_msg=universal_sys_msg
tokens_used=tokens_used
)
tokens_used = response.tokens_used
last_agent = response.agent

View file

@ -32,4 +32,8 @@ def add_error_escalation_instructions(agent):
def get_universal_system_message(messages):
if messages and messages[0].get("role") == "system":
return SYSTEM_MESSAGE.format(system_message=messages[0].get("content"))
return ""
return ""
def add_universal_system_message_to_agent(agent, universal_sys_msg):
agent.instructions = agent.instructions + f'\n\n{'-'*100}\n\n' + universal_sys_msg
return agent

View file

@ -39,7 +39,6 @@ class Swarm:
model_override: str,
stream: bool,
debug: bool,
universal_sys_msg: str,
) -> ChatCompletionMessage:
context_variables = defaultdict(str, context_variables)
instructions = (
@ -47,7 +46,7 @@ class Swarm:
if callable(agent.instructions)
else agent.instructions
)
messages = [{"role": "system", "content": instructions + universal_sys_msg}] + history
messages = [{"role": "system", "content": instructions}] + history
debug_print(debug, "Getting chat completion for...:", messages)
all_functions = list(agent.child_functions.values()) + ([agent.parent_function] if agent.parent_function else [])
@ -156,8 +155,7 @@ class Swarm:
external_tools: List[str] = [],
localize_history: bool = True,
parent_has_child_history: bool = True,
tokens_used: dict = {},
universal_sys_msg: str = '',
tokens_used: dict = {}
) -> Response:
active_agent = agent
@ -182,8 +180,7 @@ class Swarm:
context_variables=context_variables,
model_override=model_override,
stream=stream,
debug=debug,
universal_sys_msg=universal_sys_msg,
debug=debug
)
tokens_used = update_tokens_used(provider="openai", model=model_override or active_agent.model, tokens_used=tokens_used, completion=completion)