mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-06-06 19:35:44 +02:00
476 lines
21 KiB
Python
476 lines
21 KiB
Python
|
|
import os
|
||
|
|
import sys
|
||
|
|
|
||
|
|
from src.swarm.types import Agent
|
||
|
|
from src.swarm.core import Swarm
|
||
|
|
|
||
|
|
from .guardrails import post_process_response
|
||
|
|
from .tools import create_error_tool_call
|
||
|
|
from .types import AgentRole, PromptType, ErrorType
|
||
|
|
from .helpers.access import get_agent_data_by_name, get_agent_by_name, get_agent_config_by_name, get_tool_config_by_name, get_tool_config_by_type, get_external_tools, get_prompt_by_type, pop_agent_config_by_type, get_agent_by_type
|
||
|
|
from .helpers.transfer import create_transfer_function_to_agent, create_transfer_function_to_parent_agent
|
||
|
|
from .helpers.state import add_recent_messages_to_history, construct_state_from_response, reset_current_turn, reset_current_turn_agent_history
|
||
|
|
from .helpers.instructions import add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents, add_rag_instructions_to_agent, add_error_escalation_instructions
|
||
|
|
from .helpers.control import get_latest_assistant_msg, get_latest_non_assistant_messages, get_last_agent_name
|
||
|
|
|
||
|
|
from src.utils.common import common_logger
|
||
|
|
from copy import deepcopy
|
||
|
|
logger = common_logger
|
||
|
|
|
||
|
|
def order_messages(messages):
|
||
|
|
# Arrange keys in specified order
|
||
|
|
ordered_messages = []
|
||
|
|
for msg in messages:
|
||
|
|
ordered = {}
|
||
|
|
msg = {k: v for k, v in msg.items() if v is not None}
|
||
|
|
# Add keys in specified order if they exist
|
||
|
|
for key in ['role', 'sender', 'content', 'created_at', 'timestamp']:
|
||
|
|
if key in msg:
|
||
|
|
ordered[key] = msg[key]
|
||
|
|
# Add remaining keys in alphabetical order
|
||
|
|
for key in sorted(msg.keys()):
|
||
|
|
if key not in ['role', 'sender', 'content', 'created_at', 'timestamp']:
|
||
|
|
ordered[key] = msg[key]
|
||
|
|
ordered_messages.append(ordered)
|
||
|
|
|
||
|
|
return ordered_messages
|
||
|
|
|
||
|
|
def clean_up_history(agent_data):
|
||
|
|
for data in agent_data:
|
||
|
|
data["history"] = order_messages(data["history"])
|
||
|
|
return agent_data
|
||
|
|
|
||
|
|
def clear_agent_fields(agent):
|
||
|
|
agent.children = {}
|
||
|
|
agent.parent_function = None
|
||
|
|
agent.candidate_parent_functions = {}
|
||
|
|
agent.child_functions = {}
|
||
|
|
if agent.most_recent_parent:
|
||
|
|
agent.history = []
|
||
|
|
|
||
|
|
return agent
|
||
|
|
|
||
|
|
def get_agents(agent_configs, tool_configs, localize_history, available_tool_mappings, agent_data, start_turn_with_start_agent, children_aware_of_parent):
|
||
|
|
# Create Agent objects
|
||
|
|
agents = []
|
||
|
|
|
||
|
|
if not isinstance(agent_configs, list):
|
||
|
|
raise ValueError("Agents config is not a list in get_agents")
|
||
|
|
|
||
|
|
if not isinstance(tool_configs, list):
|
||
|
|
raise ValueError("Tools config is not a list in get_agents")
|
||
|
|
|
||
|
|
for agent_config in agent_configs:
|
||
|
|
logger.debug(f"Processing config for agent: {agent_config['name']}")
|
||
|
|
|
||
|
|
# Get tools for this agent
|
||
|
|
external_tools = []
|
||
|
|
internal_tools = []
|
||
|
|
candidate_parent_functions = {}
|
||
|
|
child_functions = {}
|
||
|
|
|
||
|
|
logger.debug(f"Finding tools for agent {agent_config['name']}")
|
||
|
|
logger.debug(f"Agent {agent_config['name']} has {len(agent_config['tools'])} configured tools")
|
||
|
|
|
||
|
|
if agent_config.get("hasRagSources", False):
|
||
|
|
rag_tool_name = get_tool_config_by_type(tool_configs, "rag").get("name", "")
|
||
|
|
agent_config["tools"].append(rag_tool_name)
|
||
|
|
agent_config = add_rag_instructions_to_agent(agent_config, rag_tool_name)
|
||
|
|
|
||
|
|
for tool_name in agent_config["tools"]:
|
||
|
|
logger.debug(f"Looking for tool config: {tool_name}")
|
||
|
|
tool_config = get_tool_config_by_name(tool_configs, tool_name)
|
||
|
|
if tool_config:
|
||
|
|
if tool_name in available_tool_mappings:
|
||
|
|
internal_tools.append(available_tool_mappings[tool_name])
|
||
|
|
else:
|
||
|
|
external_tools.append({
|
||
|
|
"type": "function",
|
||
|
|
"function": tool_config
|
||
|
|
})
|
||
|
|
logger.debug(f"Added tool {tool_name} to agent {agent_config['name']}")
|
||
|
|
else:
|
||
|
|
logger.warning(f"Tool {tool_name} not found in tool_configs")
|
||
|
|
|
||
|
|
history = []
|
||
|
|
this_agent_data = get_agent_data_by_name(agent_config["name"], agent_data)
|
||
|
|
if this_agent_data:
|
||
|
|
if localize_history:
|
||
|
|
history = this_agent_data.get("history", [])
|
||
|
|
|
||
|
|
# Create agent
|
||
|
|
logger.debug(f"Creating Agent object for {agent_config['name']}")
|
||
|
|
logger.debug(f"Using model: {agent_config['model']}")
|
||
|
|
logger.debug(f"Number of tools being added: Internal - {len(internal_tools)} | External - {len(external_tools)}")
|
||
|
|
|
||
|
|
try:
|
||
|
|
agent = Agent(
|
||
|
|
name=agent_config["name"],
|
||
|
|
type=agent_config.get("type", "default"),
|
||
|
|
instructions=agent_config["instructions"],
|
||
|
|
description=agent_config.get("description", ""),
|
||
|
|
internal_tools=internal_tools,
|
||
|
|
external_tools=external_tools,
|
||
|
|
candidate_parent_functions=candidate_parent_functions,
|
||
|
|
child_functions=child_functions,
|
||
|
|
model=agent_config["model"],
|
||
|
|
respond_to_user=agent_config.get("respond_to_user", False),
|
||
|
|
history=history,
|
||
|
|
children_names=agent_config.get("connectedAgents", []),
|
||
|
|
most_recent_parent=None
|
||
|
|
)
|
||
|
|
|
||
|
|
agents.append(agent)
|
||
|
|
logger.debug(f"Successfully created agent: {agent_config['name']}")
|
||
|
|
except Exception as e:
|
||
|
|
logger.error(f"Failed to create agent {agent_config['name']}: {str(e)}")
|
||
|
|
raise
|
||
|
|
|
||
|
|
# Adding most recent parents to agents
|
||
|
|
for agent in agents:
|
||
|
|
most_recent_parent = None
|
||
|
|
this_agent_data = get_agent_data_by_name(agent.name, agent_data)
|
||
|
|
if this_agent_data:
|
||
|
|
most_recent_parent_name = this_agent_data.get("most_recent_parent_name", "")
|
||
|
|
if most_recent_parent_name:
|
||
|
|
most_recent_parent = get_agent_by_name(most_recent_parent_name, agents) if most_recent_parent_name else None
|
||
|
|
if most_recent_parent:
|
||
|
|
agent.most_recent_parent = most_recent_parent
|
||
|
|
|
||
|
|
# Adding children agents to parent agents
|
||
|
|
logger.info("Adding children agents to parent agents")
|
||
|
|
for agent in agents:
|
||
|
|
agent.children = {agent_.name: agent_ for agent_ in agents if agent_.name in agent.children_names}
|
||
|
|
|
||
|
|
# Generate transfer functions for transferring to children agents
|
||
|
|
logger.info("Generating transfer functions for transferring to children agents")
|
||
|
|
transfer_functions = {
|
||
|
|
agent.name: create_transfer_function_to_agent(agent)
|
||
|
|
for agent in agents
|
||
|
|
}
|
||
|
|
|
||
|
|
# Add transfer functions for parents to transfer to children
|
||
|
|
logger.info("Adding transfer functions for parents to transfer to children")
|
||
|
|
for agent in agents:
|
||
|
|
for child in agent.children.values():
|
||
|
|
agent.child_functions[child.name] = transfer_functions[child.name]
|
||
|
|
|
||
|
|
# Add transfer-related instructions to parent agents
|
||
|
|
logger.info("Adding child transfer-related instructions to parent agents")
|
||
|
|
for agent in agents:
|
||
|
|
if agent.children:
|
||
|
|
agent = add_transfer_instructions_to_parent_agents(agent, agent.children, transfer_functions)
|
||
|
|
|
||
|
|
# Generate and append duplicate transfer functions for children to transfer to parent agents
|
||
|
|
logger.info("Generating duplicate transfer functions for children to transfer to parent agents")
|
||
|
|
for agent in agents:
|
||
|
|
for child in agent.children.values():
|
||
|
|
func = create_transfer_function_to_parent_agent(
|
||
|
|
parent_agent=agent,
|
||
|
|
children_aware_of_parent=children_aware_of_parent,
|
||
|
|
transfer_functions=transfer_functions
|
||
|
|
)
|
||
|
|
child.candidate_parent_functions[agent.name] = func
|
||
|
|
|
||
|
|
for agent in agents:
|
||
|
|
if agent.candidate_parent_functions:
|
||
|
|
agent = add_transfer_instructions_to_child_agents(
|
||
|
|
child=agent,
|
||
|
|
children_aware_of_parent=children_aware_of_parent
|
||
|
|
)
|
||
|
|
|
||
|
|
for agent in agents:
|
||
|
|
if agent.most_recent_parent:
|
||
|
|
assert agent.most_recent_parent.name in agent.candidate_parent_functions, f"Most recent parent {agent.most_recent_parent.name} not found in candidate parent functions for agent {agent.name}"
|
||
|
|
agent.parent_function = agent.candidate_parent_functions[agent.most_recent_parent.name]
|
||
|
|
|
||
|
|
return agents
|
||
|
|
|
||
|
|
def check_request_validity(messages, agent_configs, tool_configs, prompt_configs, max_overall_turns):
|
||
|
|
|
||
|
|
error_msg = ""
|
||
|
|
error_type = ErrorType.ESCALATE.value
|
||
|
|
|
||
|
|
# Limits checks
|
||
|
|
external_messages_count = sum(1 for msg in messages if msg.get("response_type") == "external")
|
||
|
|
if external_messages_count >= max_overall_turns:
|
||
|
|
error_msg = f"Max overall turns reached: {max_overall_turns}"
|
||
|
|
|
||
|
|
# Empty checks
|
||
|
|
if not messages:
|
||
|
|
error_msg = "Messages list is empty"
|
||
|
|
|
||
|
|
# Empty checks --> Fatal
|
||
|
|
if not agent_configs:
|
||
|
|
error_msg = "Agent configs list is empty"
|
||
|
|
error_type = ErrorType.FATAL.value
|
||
|
|
|
||
|
|
# Type checks --> Fatal
|
||
|
|
for arg in [messages, agent_configs, tool_configs, prompt_configs]:
|
||
|
|
if not isinstance(arg, list):
|
||
|
|
error_msg = f"{arg} is not a list"
|
||
|
|
error_type = ErrorType.FATAL.value
|
||
|
|
|
||
|
|
# Post processing agent, guardrails and escalation agent check - there should be at max one agent with type "post_processing_agent", "guardrails_agent" and "escalation_agent" respectively --> Fatal
|
||
|
|
post_processing_agent_count = sum(1 for ac in agent_configs if ac.get("type", "") == AgentRole.POST_PROCESSING.value)
|
||
|
|
guardrails_agent_count = sum(1 for ac in agent_configs if ac.get("type", "") == AgentRole.GUARDRAILS.value)
|
||
|
|
escalation_agent_count = sum(1 for ac in agent_configs if ac.get("type", "") == AgentRole.ESCALATION.value)
|
||
|
|
if post_processing_agent_count > 1 or guardrails_agent_count > 1 or escalation_agent_count > 1:
|
||
|
|
error_msg = "Invalid post processing agent or guardrails agent count - expected at most 1"
|
||
|
|
error_type = ErrorType.FATAL.value
|
||
|
|
|
||
|
|
# All agent config should have: name, instructions, model --> Fatal
|
||
|
|
for agent_config in agent_configs:
|
||
|
|
if not all(key in agent_config for key in ["name", "instructions", "model"]):
|
||
|
|
missing_keys = [key for key in ["name", "instructions", "tools", "model"] if key not in agent_config]
|
||
|
|
error_msg = f"Invalid agent config - missing keys: {missing_keys}"
|
||
|
|
error_type = ErrorType.FATAL.value
|
||
|
|
|
||
|
|
# Check for cycles in the agent config graph. Raise error if cycle is found, along with the agents involved in the cycle.
|
||
|
|
def find_cycles(agent_name, agent_configs, visited=None, path=None):
|
||
|
|
if visited is None:
|
||
|
|
visited = set()
|
||
|
|
if path is None:
|
||
|
|
path = []
|
||
|
|
|
||
|
|
visited.add(agent_name)
|
||
|
|
path.append(agent_name)
|
||
|
|
|
||
|
|
agent_config = get_agent_config_by_name(agent_name, agent_configs)
|
||
|
|
if not agent_config:
|
||
|
|
return None
|
||
|
|
|
||
|
|
for child_name in agent_config.get("connectedAgents", []):
|
||
|
|
if child_name in path:
|
||
|
|
cycle = path[path.index(child_name):]
|
||
|
|
cycle.append(child_name)
|
||
|
|
return cycle
|
||
|
|
|
||
|
|
if child_name not in visited:
|
||
|
|
cycle = find_cycles(child_name, agent_configs, visited, path)
|
||
|
|
if cycle:
|
||
|
|
return cycle
|
||
|
|
|
||
|
|
path.pop()
|
||
|
|
return None
|
||
|
|
|
||
|
|
for agent_config in agent_configs:
|
||
|
|
if agent_config.get("name") in agent_config.get("connectedAgents", []):
|
||
|
|
error_msg = f"Cycle detected in agent config graph - agent {agent_config.get('name')} is connected to itself"
|
||
|
|
|
||
|
|
cycle = find_cycles(agent_config.get("name"), agent_configs)
|
||
|
|
if cycle:
|
||
|
|
cycle_str = " -> ".join(cycle)
|
||
|
|
error_msg = f"Cycle detected in agent config graph: {cycle_str}"
|
||
|
|
|
||
|
|
return error_msg, error_type
|
||
|
|
|
||
|
|
def handle_error(error_tool_call, error_msg, return_diff_messages, messages, turn_messages, state, tokens_used):
|
||
|
|
resp_messages = turn_messages if return_diff_messages else messages + turn_messages
|
||
|
|
resp_messages.extend([create_error_tool_call(error_msg)])
|
||
|
|
if error_tool_call:
|
||
|
|
return resp_messages, tokens_used, state
|
||
|
|
else:
|
||
|
|
raise ValueError(error_msg)
|
||
|
|
|
||
|
|
def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_tool_mappings={}, localize_history=True, return_diff_messages=True, prompt_configs=[], start_turn_with_start_agent=False, children_aware_of_parent=False, parent_has_child_history=True, state={}, additional_tool_configs=[], error_tool_call=True, max_messages_per_turn=10, max_messages_per_error_escalation_turn=4, escalate_errors=True, max_overall_turns=10):
|
||
|
|
|
||
|
|
logger.info("Running stateless turn")
|
||
|
|
turn_messages = []
|
||
|
|
tokens_used = {}
|
||
|
|
messages = order_messages(messages)
|
||
|
|
tool_configs = tool_configs + additional_tool_configs
|
||
|
|
|
||
|
|
validation_error_msg, validation_error_type = check_request_validity(
|
||
|
|
messages=messages,
|
||
|
|
agent_configs=agent_configs,
|
||
|
|
tool_configs=tool_configs,
|
||
|
|
prompt_configs=prompt_configs,
|
||
|
|
max_overall_turns=max_overall_turns
|
||
|
|
)
|
||
|
|
|
||
|
|
if validation_error_msg and validation_error_type == ErrorType.FATAL.value:
|
||
|
|
logger.error(validation_error_msg)
|
||
|
|
return handle_error(
|
||
|
|
error_tool_call=error_tool_call,
|
||
|
|
error_msg=validation_error_msg,
|
||
|
|
return_diff_messages=return_diff_messages,
|
||
|
|
messages=messages,
|
||
|
|
turn_messages=turn_messages,
|
||
|
|
state=state,
|
||
|
|
tokens_used=tokens_used
|
||
|
|
)
|
||
|
|
|
||
|
|
post_processing_agent_config, agent_configs = pop_agent_config_by_type(agent_configs, AgentRole.POST_PROCESSING.value)
|
||
|
|
guardrails_agent_config, agent_configs = pop_agent_config_by_type(agent_configs, AgentRole.GUARDRAILS.value)
|
||
|
|
|
||
|
|
latest_assistant_msg = get_latest_assistant_msg(messages)
|
||
|
|
latest_non_assistant_msgs = get_latest_non_assistant_messages(messages)
|
||
|
|
msg_type = latest_non_assistant_msgs[-1]["role"]
|
||
|
|
|
||
|
|
last_agent_name = get_last_agent_name(
|
||
|
|
state=state,
|
||
|
|
agent_configs=agent_configs,
|
||
|
|
start_agent_name=start_agent_name,
|
||
|
|
msg_type=msg_type,
|
||
|
|
latest_assistant_msg=latest_assistant_msg,
|
||
|
|
start_turn_with_start_agent=start_turn_with_start_agent
|
||
|
|
)
|
||
|
|
|
||
|
|
logger.info("Localizing message history")
|
||
|
|
agent_data = state.get("agent_data", [])
|
||
|
|
if msg_type == "user":
|
||
|
|
messages = reset_current_turn(messages)
|
||
|
|
agent_data = reset_current_turn_agent_history(agent_data, [last_agent_name])
|
||
|
|
agent_data = clean_up_history(agent_data)
|
||
|
|
agent_data = add_recent_messages_to_history(
|
||
|
|
recent_messages=latest_non_assistant_msgs,
|
||
|
|
last_agent_name=last_agent_name,
|
||
|
|
agent_data=agent_data,
|
||
|
|
messages=messages,
|
||
|
|
parent_has_child_history=parent_has_child_history
|
||
|
|
)
|
||
|
|
state["agent_data"] = agent_data
|
||
|
|
|
||
|
|
logger.info("Initializing agents")
|
||
|
|
all_agents = get_agents(
|
||
|
|
agent_configs=agent_configs,
|
||
|
|
tool_configs=tool_configs,
|
||
|
|
available_tool_mappings=available_tool_mappings,
|
||
|
|
agent_data=state.get("agent_data", []),
|
||
|
|
localize_history=localize_history,
|
||
|
|
start_turn_with_start_agent=start_turn_with_start_agent,
|
||
|
|
children_aware_of_parent=children_aware_of_parent,
|
||
|
|
)
|
||
|
|
if not all_agents:
|
||
|
|
logger.error("No agents initialized")
|
||
|
|
return handle_error(
|
||
|
|
error_tool_call=error_tool_call,
|
||
|
|
error_msg="No agents initialized"
|
||
|
|
)
|
||
|
|
|
||
|
|
error_escalation_agent = deepcopy(get_agent_by_type(all_agents, AgentRole.ESCALATION.value))
|
||
|
|
if not error_escalation_agent:
|
||
|
|
logger.error("Escalation agent not found")
|
||
|
|
return handle_error(
|
||
|
|
error_tool_call=error_tool_call,
|
||
|
|
error_msg="Escalation agent not found",
|
||
|
|
return_diff_messages=return_diff_messages,
|
||
|
|
messages=messages,
|
||
|
|
turn_messages=turn_messages,
|
||
|
|
state=state,
|
||
|
|
tokens_used=tokens_used
|
||
|
|
)
|
||
|
|
|
||
|
|
error_escalation_agent = clear_agent_fields(error_escalation_agent)
|
||
|
|
error_escalation_agent = add_error_escalation_instructions(error_escalation_agent)
|
||
|
|
|
||
|
|
logger.info(f"Initialized {len(all_agents)} agents")
|
||
|
|
|
||
|
|
logger.debug("Getting last agent")
|
||
|
|
last_agent = get_agent_by_name(last_agent_name, all_agents)
|
||
|
|
|
||
|
|
if not last_agent:
|
||
|
|
logger.error("Last agent not found")
|
||
|
|
return handle_error(
|
||
|
|
error_tool_call=error_tool_call,
|
||
|
|
error_msg="Last agent not found",
|
||
|
|
return_diff_messages=return_diff_messages,
|
||
|
|
messages=messages,
|
||
|
|
state=state
|
||
|
|
)
|
||
|
|
|
||
|
|
external_tools = get_external_tools(tool_configs)
|
||
|
|
logger.info(f"Found {len(external_tools)} external tools")
|
||
|
|
|
||
|
|
logger.debug("Initializing Swarm client")
|
||
|
|
client = Swarm()
|
||
|
|
|
||
|
|
if not validation_error_msg:
|
||
|
|
response = client.run(
|
||
|
|
agent=last_agent,
|
||
|
|
messages=messages,
|
||
|
|
execute_tools=True,
|
||
|
|
external_tools=external_tools,
|
||
|
|
localize_history=localize_history,
|
||
|
|
parent_has_child_history=parent_has_child_history,
|
||
|
|
max_messages_per_turn=max_messages_per_turn,
|
||
|
|
tokens_used=tokens_used
|
||
|
|
)
|
||
|
|
tokens_used = response.tokens_used
|
||
|
|
last_agent = response.agent
|
||
|
|
response.messages = order_messages(response.messages)
|
||
|
|
turn_messages.extend(response.messages)
|
||
|
|
logger.info(f"Completed run of agent: {last_agent.name}")
|
||
|
|
|
||
|
|
if validation_error_msg and validation_error_type == ErrorType.ESCALATE.value or response.error_msg:
|
||
|
|
logger.info(f"Error raised in turn: {response.error_msg}")
|
||
|
|
response_sender_agent_name = response.agent.name
|
||
|
|
if escalate_errors and response_sender_agent_name != error_escalation_agent.name:
|
||
|
|
response = client.run(
|
||
|
|
agent=error_escalation_agent,
|
||
|
|
messages=[],
|
||
|
|
execute_tools=True,
|
||
|
|
external_tools=external_tools,
|
||
|
|
localize_history=False,
|
||
|
|
parent_has_child_history=False,
|
||
|
|
max_messages_per_turn=max_messages_per_error_escalation_turn,
|
||
|
|
tokens_used=tokens_used
|
||
|
|
)
|
||
|
|
tokens_used = response.tokens_used
|
||
|
|
last_agent = response.agent
|
||
|
|
response.messages = order_messages(response.messages)
|
||
|
|
turn_messages.extend(response.messages)
|
||
|
|
logger.info(f"Completed run of escalation agent: {error_escalation_agent.name}")
|
||
|
|
|
||
|
|
if response.error_msg:
|
||
|
|
logger.info(f"Error raised in escalation turn: {response.error_msg}")
|
||
|
|
return handle_error(
|
||
|
|
error_tool_call=error_tool_call,
|
||
|
|
error_msg=response.error_msg,
|
||
|
|
return_diff_messages=return_diff_messages,
|
||
|
|
messages=messages,
|
||
|
|
turn_messages=turn_messages,
|
||
|
|
state=state,
|
||
|
|
tokens_used=tokens_used
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
logger.info(f"Error raised in turn: {response.error_msg}")
|
||
|
|
return handle_error(
|
||
|
|
error_tool_call=error_tool_call,
|
||
|
|
error_msg=response.error_msg,
|
||
|
|
return_diff_messages=return_diff_messages,
|
||
|
|
messages=messages,
|
||
|
|
turn_messages=turn_messages,
|
||
|
|
state=state,
|
||
|
|
tokens_used=tokens_used
|
||
|
|
)
|
||
|
|
|
||
|
|
if post_processing_agent_config:
|
||
|
|
response = post_process_response(
|
||
|
|
messages=turn_messages,
|
||
|
|
post_processing_agent_name=post_processing_agent_config.get("name", "Post Processing agent"),
|
||
|
|
post_process_instructions=post_processing_agent_config.get("instructions", ""),
|
||
|
|
style_prompt=get_prompt_by_type(prompt_configs, PromptType.STYLE.value),
|
||
|
|
context='',
|
||
|
|
model=post_processing_agent_config.get("model", "gpt-4o"),
|
||
|
|
tokens_used=tokens_used,
|
||
|
|
last_agent=last_agent
|
||
|
|
)
|
||
|
|
tokens_used = response.tokens_used
|
||
|
|
response.messages = order_messages(response.messages)
|
||
|
|
turn_messages.extend(response.messages)
|
||
|
|
logger.info("Response post-processed")
|
||
|
|
|
||
|
|
if guardrails_agent_config:
|
||
|
|
logger.info("Guardrails agent not implemented (ignoring)")
|
||
|
|
pass
|
||
|
|
|
||
|
|
if not state or not state.get("last_agent_name"):
|
||
|
|
logger.error("State is empty or last agent name is not set")
|
||
|
|
raise ValueError("State is empty or last agent name is not set")
|
||
|
|
|
||
|
|
response.messages = turn_messages if return_diff_messages else messages + turn_messages
|
||
|
|
response.tokens_used = tokens_used
|
||
|
|
new_state = construct_state_from_response(response, all_agents)
|
||
|
|
return response.messages, response.tokens_used, new_state
|