mirror of
https://github.com/rowboatlabs/rowboat.git
synced 2026-06-09 19:45:17 +02:00
Add support for greeting message in agents
This commit is contained in:
parent
80e410888b
commit
d15eddc951
3 changed files with 108 additions and 36 deletions
|
|
@ -14,6 +14,7 @@ from .helpers.state import add_recent_messages_to_history, construct_state_from_
|
|||
from .helpers.instructions import add_transfer_instructions_to_child_agents, add_transfer_instructions_to_parent_agents, add_rag_instructions_to_agent, add_error_escalation_instructions, get_universal_system_message, add_universal_system_message_to_agent
|
||||
from .helpers.control import get_latest_assistant_msg, get_latest_non_assistant_messages, get_last_agent_name
|
||||
from src.swarm.types import Response
|
||||
from datetime import datetime
|
||||
|
||||
from src.utils.common import common_logger
|
||||
logger = common_logger
|
||||
|
|
@ -283,9 +284,16 @@ def handle_error(error_tool_call, error_msg, return_diff_messages, messages, tur
|
|||
return resp_messages, tokens_used, state
|
||||
else:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def create_final_response(response, turn_messages, messages, tokens_used, all_agents, return_diff_messages):
|
||||
response.messages = turn_messages if return_diff_messages else messages + turn_messages
|
||||
response.tokens_used = tokens_used
|
||||
new_state = construct_state_from_response(response, all_agents)
|
||||
return response.messages, response.tokens_used, new_state
|
||||
|
||||
def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_tool_mappings={}, localize_history=True, return_diff_messages=True, prompt_configs=[], start_turn_with_start_agent=False, children_aware_of_parent=False, parent_has_child_history=True, state={}, additional_tool_configs=[], error_tool_call=True, max_messages_per_turn=10, max_messages_per_error_escalation_turn=4, escalate_errors=True, max_overall_turns=10):
|
||||
|
||||
|
||||
greeting_turn = True if not any(msg.get("role") != "system" for msg in messages) else False
|
||||
logger.info("Running stateless turn")
|
||||
turn_messages = []
|
||||
tokens_used = {}
|
||||
|
|
@ -310,40 +318,42 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_
|
|||
turn_messages=turn_messages,
|
||||
state=state,
|
||||
tokens_used=tokens_used
|
||||
)
|
||||
)
|
||||
|
||||
post_processing_agent_config, agent_configs = pop_agent_config_by_type(agent_configs, AgentRole.POST_PROCESSING.value)
|
||||
guardrails_agent_config, agent_configs = pop_agent_config_by_type(agent_configs, AgentRole.GUARDRAILS.value)
|
||||
|
||||
latest_assistant_msg = get_latest_assistant_msg(messages)
|
||||
universal_sys_msg = get_universal_system_message(messages)
|
||||
latest_non_assistant_msgs = get_latest_non_assistant_messages(messages)
|
||||
msg_type = latest_non_assistant_msgs[-1]["role"]
|
||||
|
||||
last_agent_name = get_last_agent_name(
|
||||
state=state,
|
||||
agent_configs=agent_configs,
|
||||
start_agent_name=start_agent_name,
|
||||
msg_type=msg_type,
|
||||
latest_assistant_msg=latest_assistant_msg,
|
||||
start_turn_with_start_agent=start_turn_with_start_agent
|
||||
)
|
||||
|
||||
logger.info("Localizing message history")
|
||||
agent_data = state.get("agent_data", [])
|
||||
if msg_type == "user":
|
||||
messages = reset_current_turn(messages)
|
||||
agent_data = reset_current_turn_agent_history(agent_data, [last_agent_name])
|
||||
agent_data = clean_up_history(agent_data)
|
||||
agent_data = add_recent_messages_to_history(
|
||||
recent_messages=latest_non_assistant_msgs,
|
||||
last_agent_name=last_agent_name,
|
||||
agent_data=agent_data,
|
||||
messages=messages,
|
||||
parent_has_child_history=parent_has_child_history
|
||||
)
|
||||
state["agent_data"] = agent_data
|
||||
universal_sys_msg = ""
|
||||
|
||||
if not greeting_turn:
|
||||
latest_assistant_msg = get_latest_assistant_msg(messages)
|
||||
universal_sys_msg = get_universal_system_message(messages)
|
||||
latest_non_assistant_msgs = get_latest_non_assistant_messages(messages)
|
||||
msg_type = latest_non_assistant_msgs[-1]["role"]
|
||||
|
||||
last_agent_name = get_last_agent_name(
|
||||
state=state,
|
||||
agent_configs=agent_configs,
|
||||
start_agent_name=start_agent_name,
|
||||
msg_type=msg_type,
|
||||
latest_assistant_msg=latest_assistant_msg,
|
||||
start_turn_with_start_agent=start_turn_with_start_agent
|
||||
)
|
||||
|
||||
logger.info("Localizing message history")
|
||||
if msg_type == "user":
|
||||
messages = reset_current_turn(messages)
|
||||
agent_data = reset_current_turn_agent_history(agent_data, [last_agent_name])
|
||||
agent_data = clean_up_history(agent_data)
|
||||
agent_data = add_recent_messages_to_history(
|
||||
recent_messages=latest_non_assistant_msgs,
|
||||
last_agent_name=last_agent_name,
|
||||
agent_data=agent_data,
|
||||
messages=messages,
|
||||
parent_has_child_history=parent_has_child_history
|
||||
)
|
||||
|
||||
state["agent_data"] = agent_data
|
||||
logger.info("Initializing agents")
|
||||
all_agents = get_agents(
|
||||
agent_configs=agent_configs,
|
||||
|
|
@ -361,6 +371,49 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_
|
|||
error_tool_call=error_tool_call,
|
||||
error_msg="No agents initialized"
|
||||
)
|
||||
|
||||
if greeting_turn:
|
||||
greeting_msg = get_prompt_by_type(prompt_configs, PromptType.GREETING.value)
|
||||
if not greeting_msg:
|
||||
logger.error("Greeting prompt not found and messages is empty")
|
||||
return handle_error(
|
||||
error_tool_call=error_tool_call,
|
||||
error_msg="Greeting prompt not found and messages is empty",
|
||||
return_diff_messages=return_diff_messages,
|
||||
messages=messages,
|
||||
turn_messages=turn_messages,
|
||||
state=state,
|
||||
tokens_used=tokens_used
|
||||
)
|
||||
|
||||
greeting_msg_internal = {
|
||||
"content": greeting_msg,
|
||||
"role": "assistant",
|
||||
"sender": start_agent_name,
|
||||
"response_type": "internal",
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"current_turn": True
|
||||
}
|
||||
greeting_msg_external = deepcopy(greeting_msg_internal)
|
||||
greeting_msg_external["response_type"] = "external"
|
||||
greeting_msg_external["sender"] = greeting_msg_external["sender"] + ' >> External'
|
||||
turn_messages.extend([greeting_msg_internal, greeting_msg_external])
|
||||
|
||||
response = Response(
|
||||
messages=turn_messages,
|
||||
tokens_used={},
|
||||
agent=get_agent_by_name(start_agent_name, all_agents),
|
||||
error_msg=''
|
||||
)
|
||||
|
||||
return create_final_response(
|
||||
response=response,
|
||||
turn_messages=turn_messages,
|
||||
messages=messages,
|
||||
tokens_used=tokens_used,
|
||||
all_agents=all_agents,
|
||||
return_diff_messages=return_diff_messages
|
||||
)
|
||||
|
||||
error_escalation_agent = deepcopy(get_agent_by_type(all_agents, AgentRole.ESCALATION.value))
|
||||
if not error_escalation_agent:
|
||||
|
|
@ -498,7 +551,11 @@ def run_turn(messages, start_agent_name, agent_configs, tool_configs, available_
|
|||
logger.error("State is empty or last agent name is not set")
|
||||
raise ValueError("State is empty or last agent name is not set")
|
||||
|
||||
response.messages = turn_messages if return_diff_messages else messages + turn_messages
|
||||
response.tokens_used = tokens_used
|
||||
new_state = construct_state_from_response(response, all_agents)
|
||||
return response.messages, response.tokens_used, new_state
|
||||
return create_final_response(
|
||||
response=response,
|
||||
turn_messages=turn_messages,
|
||||
messages=messages,
|
||||
tokens_used=tokens_used,
|
||||
all_agents=all_agents,
|
||||
return_diff_messages=return_diff_messages
|
||||
)
|
||||
|
|
@ -11,8 +11,8 @@ class ControlType(Enum):
|
|||
|
||||
class PromptType(Enum):
|
||||
STYLE = "style_prompt"
|
||||
GREETING = "greeting"
|
||||
|
||||
class ErrorType(Enum):
|
||||
FATAL = "fatal"
|
||||
ESCALATE = "escalate"
|
||||
|
||||
ESCALATE = "escalate"
|
||||
|
|
@ -75,6 +75,11 @@ You are an helpful customer support assistant
|
|||
type: "style_prompt",
|
||||
prompt: "You should be empathetic and helpful.",
|
||||
},
|
||||
{
|
||||
name: "Greeting",
|
||||
type: "greeting",
|
||||
prompt: "Hello! How can I help you?"
|
||||
}
|
||||
],
|
||||
tools: [],
|
||||
},
|
||||
|
|
@ -131,6 +136,11 @@ You are an helpful customer support assistant
|
|||
"name": "Style prompt",
|
||||
"type": "style_prompt",
|
||||
"prompt": "You should be empathetic and helpful."
|
||||
},
|
||||
{
|
||||
"name": "Greeting",
|
||||
"type": "greeting",
|
||||
"prompt": "Hello! How can I help you?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
|
|
@ -259,6 +269,11 @@ You are an helpful customer support assistant
|
|||
"type": "style_prompt",
|
||||
"prompt": "---\n\nmake this more friendly. Keep it to 5-7 sentences. Use these as example references:\n\n---"
|
||||
},
|
||||
{
|
||||
"name": "Greeting",
|
||||
"type": "greeting",
|
||||
"prompt": "Hello! How can I help you?"
|
||||
},
|
||||
{
|
||||
"name": "structured_output",
|
||||
"type": "base_prompt",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue