Merge branch 'basic_ability' into 'mgx_ops'

change routing & add team info for quick question

See merge request pub/MetaGPT!241
This commit is contained in:
林义章 2024-07-19 06:08:14 +00:00
commit 43f316a4d4
5 changed files with 145 additions and 15 deletions

View file

@ -72,7 +72,8 @@ Help check if there are any formatting issues with the JSON data? If so, please
QUICK_THINK_PROMPT = """
Decide if the latest user message is a quick question.
Quick questions include common-sense, logical, math questions, greetings, or casual chat that you can answer directly, excluding software development tasks.
Respond with "#YES#, (then start your actual response to the question...)" if so, otherwise, simply respond with "#NO#".
Your response:
Quick questions include common-sense, logical, math, multiple-choice questions, greetings, or casual chat that you can answer directly.
Questions about you or your team info are also quick questions.
Programming or software development tasks are NOT quick questions except for filling a single function or class.
Respond with YES if so, otherwise, NO. Your response:
"""

View file

@ -24,6 +24,13 @@ Note:
7. If the requirement is writing a TRD and software framework, you should assign it to Architect. When publishing message to Architect, you should directly copy the full original user requirement.
"""
QUICK_THINK_SYSTEM_PROMPT = """
{role_info}
Your team member:
{team_info}
However, you MUST respond to the user message by yourself directly, DON'T ask your team members.
"""
FINISH_CURRENT_TASK_CMD = """
```json
[

View file

@ -236,20 +236,21 @@ class RoleZero(Role):
return rsp # return output from the last action
async def _quick_think(self) -> Message:
msg = self.rc.news[-1]
rsp_msg = None
if msg.cause_by != any_to_str(UserRequirement):
if self.rc.news[-1].cause_by != any_to_str(UserRequirement):
# Agents themselves won't generate quick questions, use this rule to reduce extra llm calls
return rsp_msg
context = self.llm.format_msg(self.get_memories(k=4) + [UserMessage(content=QUICK_THINK_PROMPT)])
async with ThoughtReporter(enable_llm_stream=True) as reporter:
await reporter.async_report({"type": "quick"})
rsp = await self.llm.aask(context)
# routing
memory = self.get_memories(k=4)
context = self.llm.format_msg(memory + [UserMessage(content=QUICK_THINK_PROMPT)])
rsp = await self.llm.aask(context)
pattern = r"#YES#,? ?"
if re.search(pattern, rsp):
answer = re.sub(pattern, "", rsp).strip()
if "yes" in rsp.lower():
# llm call with the original context
async with ThoughtReporter(enable_llm_stream=True) as reporter:
await reporter.async_report({"type": "quick"})
answer = await self.llm.aask(self.llm.format_msg(memory))
self.rc.memory.add(AIMessage(content=answer, cause_by=RunCommand))
await self.reply_to_human(content=answer)
rsp_msg = AIMessage(

View file

@ -3,6 +3,7 @@ from __future__ import annotations
from metagpt.actions.di.run_command import RunCommand
from metagpt.prompts.di.team_leader import (
FINISH_CURRENT_TASK_CMD,
QUICK_THINK_SYSTEM_PROMPT,
SYSTEM_PROMPT,
TL_INSTRUCTION,
)
@ -16,6 +17,7 @@ from metagpt.tools.tool_registry import register_tool
class TeamLeader(RoleZero):
name: str = "Mike"
profile: str = "Team Leader"
goal: str = "Manage a team to assist users"
system_msg: list[str] = [SYSTEM_PROMPT]
# TeamLeader only reacts once each time, but may encounter errors or need to ask human, thus allowing 2 more turns
@ -33,16 +35,26 @@ class TeamLeader(RoleZero):
}
)
def set_instruction(self):
def _get_team_info(self) -> str:
if not self.rc.env:
return ""
team_info = ""
for role in self.rc.env.roles.values():
# if role.profile == "Team Leader":
# continue
team_info += f"{role.name}: {role.profile}, {role.goal}\n"
self.instruction = TL_INSTRUCTION.format(team_info=team_info)
return team_info
async def _quick_think(self) -> Message:
# insert team info for quick question
self.llm.system_prompt = QUICK_THINK_SYSTEM_PROMPT.format(
role_info=super()._get_prefix(),
team_info=self._get_team_info(),
)
return await super()._quick_think()
async def _think(self) -> bool:
self.set_instruction()
self.instruction = TL_INSTRUCTION.format(team_info=self._get_team_info())
return await super()._think()
def publish_message(self, msg: Message, send_to="no one"):

View file

@ -0,0 +1,109 @@
import asyncio
from metagpt.environment.mgx.mgx_env import MGXEnv
from metagpt.roles import Architect, ProductManager, ProjectManager
from metagpt.roles.di.data_analyst import DataAnalyst
from metagpt.roles.di.engineer2 import Engineer2
from metagpt.roles.di.team_leader import TeamLeader
from metagpt.schema import Message
NORMAL_QUESTION = [
"create a 2048 game",
"write a snake game",
"Write a 2048 game using JavaScript without using any frameworks, user can play with keyboard.",
"print statistic summary of sklearn iris dataset",
"Run data analysis on sklearn Wine recognition dataset, and train a model to predict wine class (20% as validation), and show validation accuracy.",
"""
Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/,
and save it to a csv file. paper title must include `multiagent` or `large language model`. *notice: print key variables*
""",
"""
Write a fix for this issue: https://github.com/langchain-ai/langchain/issues/20453,
you can fix it on this repo https://github.com/garylin2099/langchain,
checkout a branch named test-fix, commit your changes, push, and create a PR to the master branch of https://github.com/iorisa/langchain
""",
]
QUICK_QUESTION = [
# general knowledge qa, logical, math
"""Who is the first man landing on Moon""",
"""In DNA adenine normally pairs with: A. cytosine. B. guanine. C. thymine. D. uracil. Answer:""",
"""________________ occur(s) where there is no prior history of exchange and no future exchanges are expected between a buyer and seller. A. Relationship marketing. B. Service mix. C. Market exchanges. D. Service failure. Answer:""",
"""Within American politics, the power to accord official recognition to other countries belongs to A. the Senate. B. the president. C. the Secretary of State. D. the chairman of the Joint Chiefs. Answer:""",
"""Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.""",
"""True or false? Statement 1 | A ring homomorphism is one to one if and only if the kernel is {{0}},. Statement 2 | Q is an ideal in R""",
"""Jean has 30 lollipops. Jean eats 2 of the lollipops. With the remaining lollipops, Jean wants to package 2 lollipops in one bag. How many bags can Jean fill?""",
"""Alisa biked 12 miles per hour for 4.5 hours. Stanley biked at 10 miles per hour for 2.5 hours. How many miles did Alisa and Stanley bike in total?""",
# function filling (humaneval)
"""
def has_close_elements(numbers: List[float], threshold: float) -> bool:
''' Check if in given list of numbers, are any two numbers closer to each other than
given threshold.
>>> has_close_elements([1.0, 2.0, 3.0], 0.5)
False
>>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)
True
'''
""",
"""
def is_palindrome(string: str) -> bool:
''' Test if given string is a palindrome '''
return string == string[::-1]
def make_palindrome(string: str) -> str:
''' Find the shortest palindrome that begins with a supplied string.
Algorithm idea is simple:
- Find the longest postfix of supplied string that is a palindrome.
- Append to the end of the string reverse of a string prefix that comes before the palindromic suffix.
>>> make_palindrome('')
''
>>> make_palindrome('cat')
'catac'
>>> make_palindrome('cata')
'catac'
'''
""",
# casual chat
"""What's your name?""",
"Who are you",
"What can you do",
"Hi",
"1+1",
]
async def test_routing_acc():
role = TeamLeader()
env = MGXEnv()
env.add_roles(
[
role,
ProductManager(),
Architect(),
ProjectManager(),
Engineer2(),
DataAnalyst(),
]
)
for q in QUICK_QUESTION:
msg = Message(content=q)
role.put_message(msg)
await role._observe()
rsp = await role._quick_think()
role.rc.memory.clear()
assert rsp
for q in NORMAL_QUESTION:
msg = Message(content=q)
role.put_message(msg)
await role._observe()
rsp = await role._quick_think()
role.rc.memory.clear()
assert not rsp
if __name__ == "__main__":
asyncio.run(test_routing_acc())