modify add action to set action

This commit is contained in:
geekan 2024-01-10 17:54:13 +08:00
parent 07d34bda7a
commit 91e6564586
19 changed files with 34 additions and 33 deletions

View file

@ -61,7 +61,7 @@ class AgentCreator(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_actions([CreateAgent])
self.set_actions([CreateAgent])
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")

View file

@ -57,7 +57,7 @@ class SimpleCoder(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_actions([SimpleWriteCode])
self.set_actions([SimpleWriteCode])
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
@ -76,7 +76,7 @@ class RunnableCoder(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_actions([SimpleWriteCode, SimpleRunCode])
self.set_actions([SimpleWriteCode, SimpleRunCode])
self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value)
async def _act(self) -> Message:

View file

@ -46,7 +46,7 @@ class SimpleCoder(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._watch([UserRequirement])
self.add_actions([SimpleWriteCode])
self.set_actions([SimpleWriteCode])
class SimpleWriteTest(Action):
@ -75,7 +75,7 @@ class SimpleTester(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_actions([SimpleWriteTest])
self.set_actions([SimpleWriteTest])
# self._watch([SimpleWriteCode])
self._watch([SimpleWriteCode, SimpleWriteReview]) # feel free to try this too
@ -114,7 +114,7 @@ class SimpleReviewer(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_actions([SimpleWriteReview])
self.set_actions([SimpleWriteReview])
self._watch([SimpleWriteTest])

View file

@ -49,7 +49,7 @@ class Debator(Role):
def __init__(self, **data: Any):
super().__init__(**data)
self.add_actions([SpeakAloud])
self.set_actions([SpeakAloud])
self._watch([UserRequirement, SpeakAloud])
async def _observe(self) -> int:

View file

@ -33,7 +33,7 @@ class Architect(Role):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
# Initialize actions specific to the Architect role
self.add_actions([WriteDesign])
self.set_actions([WriteDesign])
# Set events or actions the Architect should watch or be aware of
self._watch({WritePRD})

View file

@ -84,7 +84,7 @@ class Engineer(Role):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.add_actions([WriteCode])
self.set_actions([WriteCode])
self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug])
self.code_todos = []
self.summarize_todos = []

View file

@ -60,7 +60,7 @@ class InvoiceOCRAssistant(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_actions([InvoiceOCR])
self.set_actions([InvoiceOCR])
self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value)
async def _act(self) -> Message:
@ -82,10 +82,10 @@ class InvoiceOCRAssistant(Role):
resp = await todo.run(file_path)
if len(resp) == 1:
# Single file support for questioning based on OCR recognition results
self.add_actions([GenerateTable, ReplyQuestion])
self.set_actions([GenerateTable, ReplyQuestion])
self.orc_data = resp[0]
else:
self.add_actions([GenerateTable])
self.set_actions([GenerateTable])
self.set_todo(None)
content = INVOICE_OCR_SUCCESS

View file

@ -33,7 +33,7 @@ class ProductManager(Role):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.add_actions([PrepareDocuments, WritePRD])
self.set_actions([PrepareDocuments, WritePRD])
self._watch([UserRequirement, PrepareDocuments])
self.todo_action = any_to_name(PrepareDocuments)

View file

@ -33,5 +33,5 @@ class ProjectManager(Role):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.add_actions([WriteTasks])
self.set_actions([WriteTasks])
self._watch([WriteDesign])

View file

@ -44,7 +44,7 @@ class QaEngineer(Role):
# FIXME: a bit hack here, only init one action to circumvent _think() logic,
# will overwrite _think() in future updates
self.add_actions([WriteTest])
self.set_actions([WriteTest])
self._watch([SummarizeCode, WriteTest, RunCode, DebugError])
self.test_round = 0

View file

@ -34,7 +34,7 @@ class Researcher(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_actions(
self.set_actions(
[CollectLinks(name=self.name), WebBrowseAndSummarize(name=self.name), ConductResearch(name=self.name)]
)
self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value)

View file

@ -222,16 +222,17 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
def _init_action_system_message(self, action: Action):
action.set_prefix(self._get_prefix())
def add_action(self, action: Action):
def set_action(self, action: Action):
"""Add action to the role."""
self.add_actions([action])
self.set_actions([action])
def add_actions(self, actions: list[Union[Action, Type[Action]]]):
def set_actions(self, actions: list[Union[Action, Type[Action]]]):
"""Add actions to the role.
Args:
actions: list of Action classes or instances
"""
self._reset()
for action in actions:
if not isinstance(action, Action):
i = action(name="", llm=self.llm)

View file

@ -38,5 +38,5 @@ class Sales(Role):
action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch)
else:
action = SearchAndSummarize()
self.add_actions([action])
self.set_actions([action])
self._watch([UserRequirement])

View file

@ -48,12 +48,12 @@ class Searcher(Role):
engine (SearchEngineType): The type of search engine to use.
"""
super().__init__(**kwargs)
self.add_actions([SearchAndSummarize(engine=self.engine)])
self.set_actions([SearchAndSummarize(engine=self.engine)])
def set_search_func(self, search_func):
"""Sets a custom search function for the searcher."""
action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=search_func)
self.add_actions([action])
self.set_actions([action])
async def _act_sp(self) -> Message:
"""Performs the search action in a single process."""

View file

@ -49,7 +49,7 @@ class SkAgent(Role):
def __init__(self, **data: Any) -> None:
"""Initializes the Engineer role with given attributes."""
super().__init__(**data)
self.add_actions([ExecuteTask()])
self.set_actions([ExecuteTask()])
self._watch([UserRequirement])
self.kernel = make_sk_kernel()

View file

@ -47,7 +47,7 @@ class Teacher(Role):
for topic in TeachingPlanBlock.TOPICS:
act = WriteTeachingPlanPart(i_context=self.rc.news[0].content, topic=topic, llm=self.llm)
actions.append(act)
self.add_actions(actions)
self.set_actions(actions)
if self.rc.todo is None:
self._set_state(0)

View file

@ -40,7 +40,7 @@ class TutorialAssistant(Role):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_actions([WriteDirectory(language=self.language)])
self.set_actions([WriteDirectory(language=self.language)])
self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value)
async def _handle_directory(self, titles: Dict) -> Message:
@ -63,7 +63,7 @@ class TutorialAssistant(Role):
directory += f"- {key}\n"
for second_dir in first_dir[key]:
directory += f" - {second_dir}\n"
self.add_actions(actions)
self.set_actions(actions)
async def _act(self) -> Message:
"""Perform an action as determined by the role.

View file

@ -67,7 +67,7 @@ class RoleA(Role):
def __init__(self, **kwargs):
super(RoleA, self).__init__(**kwargs)
self.add_actions([ActionPass])
self.set_actions([ActionPass])
self._watch([UserRequirement])
@ -79,7 +79,7 @@ class RoleB(Role):
def __init__(self, **kwargs):
super(RoleB, self).__init__(**kwargs)
self.add_actions([ActionOK, ActionRaise])
self.set_actions([ActionOK, ActionRaise])
self._watch([ActionPass])
self.rc.react_mode = RoleReactMode.BY_ORDER
@ -92,7 +92,7 @@ class RoleC(Role):
def __init__(self, **kwargs):
super(RoleC, self).__init__(**kwargs)
self.add_actions([ActionOK, ActionRaise])
self.set_actions([ActionOK, ActionRaise])
self._watch([UserRequirement])
self.rc.react_mode = RoleReactMode.BY_ORDER
self.rc.memory.ignore_id = True

View file

@ -33,7 +33,7 @@ class MockAction(Action):
class MockRole(Role):
def __init__(self, name="", profile="", goal="", constraints="", desc=""):
super().__init__(name=name, profile=profile, goal=goal, constraints=constraints, desc=desc)
self.add_actions([MockAction()])
self.set_actions([MockAction()])
def test_basic():
@ -111,7 +111,7 @@ async def test_send_to():
def test_init_action():
role = Role()
role.add_actions([MockAction, MockAction])
role.set_actions([MockAction, MockAction])
assert len(role.actions) == 2
@ -127,7 +127,7 @@ async def test_recover():
role.publish_message(None)
role.llm = mock_llm
role.add_actions([MockAction, MockAction])
role.set_actions([MockAction, MockAction])
role.recovered = True
role.latest_observed_msg = Message(content="recover_test")
role.rc.state = 0
@ -144,7 +144,7 @@ async def test_think_act():
mock_llm.aask.side_effect = ["ok"]
role = Role()
role.add_actions([MockAction])
role.set_actions([MockAction])
await role.think()
role.rc.memory.add(Message("run"))
assert len(role.get_memories()) == 1