mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-17 15:35:21 +02:00
update code and fix reference
This commit is contained in:
parent
0316d4dae9
commit
ee7de61025
34 changed files with 265 additions and 214 deletions
1
examples/st_game/.gitignore
vendored
Normal file
1
examples/st_game/.gitignore
vendored
Normal file
|
|
@ -0,0 +1 @@
|
|||
storage/test*
|
||||
|
|
@ -1,5 +1,9 @@
|
|||
## Stanford Town Game
|
||||
|
||||
执行入口为:`python3 run_st_game.py "Host a activity"`
|
||||
### 后端服务启动
|
||||
执行入口为:`python3 run_st_game.py "Host a open lunch party at 13:00 pm" "base_the_ville_isabella_maria_klaus" "test_sim" 10`
|
||||
|
||||
`idea`为用户给第一个Agent的用户心声,并通过这个心声进行传播,看最后多智能体是否达到举办、参加活动的目标。
|
||||
|
||||
### 前端服务启动
|
||||
进入`generative_agents/environment/frontend_server`,使用`~~python manage.py runserver~~`启动前端服务。
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
3
examples/st_game/actions/__init__.py
Normal file
3
examples/st_game/actions/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
|
@ -5,8 +5,8 @@
|
|||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
|
||||
from ..roles.st_role import STRole
|
||||
from ..actions.st_action import STAction
|
||||
from examples.st_game.roles.st_role import STRole
|
||||
from examples.st_game.actions.st_action import STAction
|
||||
|
||||
|
||||
class AgentChatSumRel(STAction):
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@
|
|||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
|
||||
from ..roles.st_role import STRole
|
||||
from ..actions.st_action import STAction
|
||||
from examples.st_game.roles.st_role import STRole
|
||||
from examples.st_game.actions.st_action import STAction
|
||||
|
||||
|
||||
class DecideToTalk(STAction):
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@
|
|||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
|
||||
from ..roles.st_role import STRole
|
||||
from ..actions.st_action import STAction
|
||||
from ..utils.utils import extract_first_json_dict
|
||||
from ..maze import Maze
|
||||
from examples.st_game.roles.st_role import STRole
|
||||
from examples.st_game.actions.st_action import STAction
|
||||
from examples.st_game.utils.utils import extract_first_json_dict
|
||||
from examples.st_game.maze import Maze
|
||||
|
||||
|
||||
class GenIterChatUTT(STAction):
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ import datetime
|
|||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
|
||||
from ..roles.st_role import STRole
|
||||
from ..actions.st_action import STAction
|
||||
from examples.st_game.roles.st_role import STRole
|
||||
from examples.st_game.actions.st_action import STAction
|
||||
|
||||
|
||||
class NewDecompSchedule(STAction):
|
||||
|
|
|
|||
|
|
@ -3,11 +3,12 @@
|
|||
# @Desc : Integration Reflect Action
|
||||
|
||||
import re
|
||||
from ..roles.st_role import STRole
|
||||
from ..actions.st_action import STAction
|
||||
from ..memory.agent_memory import BasicMemory
|
||||
from metagpt.logs import logger
|
||||
|
||||
from examples.st_game.actions.st_action import STAction
|
||||
from examples.st_game.memory.agent_memory import BasicMemory
|
||||
|
||||
|
||||
# Run GPT Prompt Focal Point method
|
||||
class AgentFocusPt(STAction):
|
||||
|
||||
|
|
@ -31,8 +32,8 @@ class AgentFocusPt(STAction):
|
|||
def _func_fail_default_resp(self) -> str:
|
||||
pass
|
||||
|
||||
async def run(self, role: STRole, statements: str, n: int, test_input=None) -> str:
|
||||
def create_prompt_input(role: STRole, statements, n, test_input=None):
|
||||
async def run(self, role: "STRole", statements: str, n: int, test_input=None) -> str:
|
||||
def create_prompt_input(role: "STRole", statements, n, test_input=None):
|
||||
prompt_input = [statements, str(n)]
|
||||
return prompt_input
|
||||
|
||||
|
|
@ -77,8 +78,8 @@ class AgentInsightAndGuidance(STAction):
|
|||
def _func_fail_default_resp(self) -> str:
|
||||
pass
|
||||
|
||||
async def run(self, role: STRole, statements: str, n: int, test_input=None) -> str:
|
||||
def create_prompt_input(role: STRole, statements, n, test_input=None):
|
||||
async def run(self, role: "STRole", statements: str, n: int, test_input=None) -> str:
|
||||
def create_prompt_input(role: "STRole", statements, n, test_input=None):
|
||||
prompt_input = [statements, str(n)]
|
||||
return prompt_input
|
||||
|
||||
|
|
@ -113,7 +114,7 @@ class AgentEventTriple(STAction):
|
|||
def _func_fail_default_resp(self) -> str:
|
||||
pass
|
||||
|
||||
async def run(self, statements: str, role: STRole, verbose=False) -> str:
|
||||
async def run(self, statements: str, role: "STRole", verbose=False) -> str:
|
||||
def create_prompt_input(statements, role):
|
||||
if "(" in statements:
|
||||
statements = statements.split("(")[-1].split(")")[0]
|
||||
|
|
@ -151,8 +152,8 @@ class AgentEventPoignancy(STAction):
|
|||
def _func_fail_default_resp(self) -> str:
|
||||
pass
|
||||
|
||||
async def run(self, role: STRole, statements: str, test_input=None, verbose=False) -> str:
|
||||
def create_prompt_input(role: STRole, statements: str, test_input=None):
|
||||
async def run(self, role: "STRole", statements: str, test_input=None, verbose=False) -> str:
|
||||
def create_prompt_input(role: "STRole", statements: str, test_input=None):
|
||||
prompt_input = [role._rc.scratch.name,
|
||||
role._rc.scratch.get_str_iss(),
|
||||
role._rc.scratch.name,
|
||||
|
|
@ -192,8 +193,8 @@ class AgentChatPoignancy(STAction):
|
|||
def _func_fail_default_resp(self) -> str:
|
||||
pass
|
||||
|
||||
async def run(self, role: STRole, statements: str, test_input=None, verbose=False) -> str:
|
||||
def create_prompt_input(role: STRole, statements, test_input=None):
|
||||
async def run(self, role: "STRole", statements: str, test_input=None, verbose=False) -> str:
|
||||
def create_prompt_input(role: "STRole", statements, test_input=None):
|
||||
prompt_input = [role._rc.scratch.name,
|
||||
role._rc.scratch.get_str_iss(),
|
||||
role._rc.scratch.name,
|
||||
|
|
@ -232,7 +233,7 @@ class AgentPlanThoughtOnConvo(STAction):
|
|||
def _func_fail_default_resp(self) -> str:
|
||||
pass
|
||||
|
||||
async def run(self, role: STRole, statements: str, test_input=None, verbose=False) -> str:
|
||||
async def run(self, role: "STRole", statements: str, test_input=None, verbose=False) -> str:
|
||||
def create_prompt_input(role, statements, test_input=None):
|
||||
prompt_input = [statements,
|
||||
role._rc.scratch.name,
|
||||
|
|
@ -268,7 +269,7 @@ class AgentMemoryOnConvo(STAction):
|
|||
def _func_fail_default_resp(self) -> str:
|
||||
pass
|
||||
|
||||
async def run(self, role: STRole, statements: str, test_input=None, verbose=False) -> str:
|
||||
async def run(self, role: "STRole", statements: str, test_input=None, verbose=False) -> str:
|
||||
def create_prompt_input(role, statements, test_input=None):
|
||||
prompt_input = [statements,
|
||||
role._rc.scratch.name,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import json
|
|||
from metagpt.actions.action import Action
|
||||
from metagpt.schema import Message
|
||||
|
||||
from ..utils.const import PROMPTS_DIR
|
||||
from examples.st_game.utils.const import PROMPTS_DIR
|
||||
|
||||
|
||||
class STAction(Action):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
from metagpt.schema import Message
|
||||
|
||||
from ..actions.st_action import STAction
|
||||
from examples.st_game.actions.st_action import STAction
|
||||
|
||||
|
||||
class SummarizeConv(STAction):
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ import json
|
|||
import math
|
||||
from pathlib import Path
|
||||
import networkx as nx
|
||||
from .utils.const import MAZE_ASSET_PATH
|
||||
from .utils.utils import read_csv_to_list
|
||||
from utils.const import MAZE_ASSET_PATH
|
||||
from utils.utils import read_csv_to_list
|
||||
|
||||
|
||||
class Maze:
|
||||
|
|
@ -212,13 +212,13 @@ class Maze:
|
|||
else:
|
||||
self.address_tiles[add] = set([(j, i)])
|
||||
|
||||
# Build an nx.Graph.
|
||||
grid_graph = nx.grid_2d_graph(m=self.maze_width, n=self.maze_height)
|
||||
for i in range(self.maze_height):
|
||||
for j in range(self.maze_width):
|
||||
if self.collision_maze[i][j] != 0:
|
||||
grid_graph.remove_node((i, j))
|
||||
self.nx_graph = grid_graph
|
||||
# # Build an nx.Graph.
|
||||
# grid_graph = nx.grid_2d_graph(m=self.maze_width, n=self.maze_height)
|
||||
# for i in range(self.maze_height):
|
||||
# for j in range(self.maze_width):
|
||||
# if self.collision_maze[i][j] != 0:
|
||||
# grid_graph.remove_node((i, j))
|
||||
# self.nx_graph = grid_graph
|
||||
|
||||
def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]:
|
||||
"""
|
||||
|
|
@ -389,24 +389,24 @@ class Maze:
|
|||
if event[0] == subject:
|
||||
self.tiles[tile[1]][tile[0]]["events"].remove(event)
|
||||
|
||||
def _find_closest_node(self, coords: tuple[int, int]) -> tuple[int, int]:
|
||||
target_coords = self.nx_graph.nodes
|
||||
min_dist = None
|
||||
closest_coordinate = None
|
||||
for target in target_coords:
|
||||
dist = math.dist(coords, target)
|
||||
if not closest_coordinate:
|
||||
min_dist = dist
|
||||
closest_coordinate = target
|
||||
else:
|
||||
if min_dist > dist:
|
||||
min_dist = dist
|
||||
closest_coordinate = target
|
||||
return closest_coordinate
|
||||
|
||||
def find_path(self, start: tuple[int, int], end: tuple[int, int]) -> list[tuple[int, int]]:
|
||||
if start not in self.nx_graph.nodes:
|
||||
start = self._find_closest_node(start)
|
||||
if end not in self.nx_graph.nodes:
|
||||
end = self._find_closest_node(end)
|
||||
return nx.shortest_path(self.nx_graph, start, end)
|
||||
# def _find_closest_node(self, coords: tuple[int, int]) -> tuple[int, int]:
|
||||
# target_coords = self.nx_graph.nodes
|
||||
# min_dist = None
|
||||
# closest_coordinate = None
|
||||
# for target in target_coords:
|
||||
# dist = math.dist(coords, target)
|
||||
# if not closest_coordinate:
|
||||
# min_dist = dist
|
||||
# closest_coordinate = target
|
||||
# else:
|
||||
# if min_dist > dist:
|
||||
# min_dist = dist
|
||||
# closest_coordinate = target
|
||||
# return closest_coordinate
|
||||
#
|
||||
# def find_path(self, start: tuple[int, int], end: tuple[int, int]) -> list[tuple[int, int]]:
|
||||
# if start not in self.nx_graph.nodes:
|
||||
# start = self._find_closest_node(start)
|
||||
# if end not in self.nx_graph.nodes:
|
||||
# end = self._find_closest_node(end)
|
||||
# return nx.shortest_path(self.nx_graph, start, end)
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@ from pydantic import Field
|
|||
from metagpt.environment import Environment
|
||||
from metagpt.roles.role import Role
|
||||
|
||||
from .maze import Maze
|
||||
from maze import Maze
|
||||
|
||||
|
||||
class MazeEnvironment(Environment):
|
||||
|
||||
maze: Maze = Field(default=Maze)
|
||||
maze: Maze = Field(default_factory=Maze)
|
||||
|
||||
def add_role(self, role: Role):
|
||||
role.set_env(self)
|
||||
|
|
|
|||
|
|
@ -2,11 +2,12 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : BasicMemory,AgentMemory实现
|
||||
|
||||
from metagpt.memory.memory import Memory
|
||||
from metagpt.schema import Message
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from metagpt.memory.memory import Memory
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class BasicMemory(Message):
|
||||
|
||||
|
|
@ -93,13 +94,13 @@ class AgentMemory(Memory):
|
|||
3. kw_strength.json
|
||||
"""
|
||||
|
||||
def __init__(self, memory_saved: str):
|
||||
def __init__(self):
|
||||
"""
|
||||
AgentMemory类继承自Memory类,重写storage替代GA中id_to_node,一方面存储所有信息,一方面作为JSON转化
|
||||
index存储与不同Agent的chat信息
|
||||
@李嵩@张凯 这里的storage是List,你们需要写一个JSON转化器,将List修改为node.json一致的格式
|
||||
"""
|
||||
super.__init__()
|
||||
super(AgentMemory, self).__init__()
|
||||
self.id_to_node = dict() # TODO jiayi add
|
||||
self.storage: list[BasicMemory] = [] # 重写Stroage,存储BasicMemory所有节点
|
||||
self.event_list = [] # 存储event记忆
|
||||
|
|
@ -113,6 +114,10 @@ class AgentMemory(Memory):
|
|||
self.kw_strength_event = dict() # 关键词影响存储
|
||||
self.kw_strength_thought = dict()
|
||||
|
||||
# self.load(memory_saved)
|
||||
|
||||
def set_mem_path(self, memory_saved: str):
|
||||
self.memory_saved = memory_saved
|
||||
self.load(memory_saved)
|
||||
|
||||
def save(self, memory_saved: str):
|
||||
|
|
|
|||
|
|
@ -8,11 +8,12 @@ from typing import Union
|
|||
from numpy import dot
|
||||
from numpy.linalg import norm
|
||||
|
||||
from ..memory.agent_memory import AgentMemory, BasicMemory
|
||||
from ..utils.utils import get_embedding
|
||||
from ..roles.st_role import STRole
|
||||
from examples.st_game.memory.agent_memory import AgentMemory, BasicMemory
|
||||
from examples.st_game.utils.utils import get_embedding
|
||||
|
||||
def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str, topk: int = 4) -> list[BasicMemory]:
|
||||
|
||||
def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memory_forget: float, query: str,
|
||||
topk: int = 4) -> list[BasicMemory]:
|
||||
"""
|
||||
Retrieve需要集合Role使用,原因在于Role才具有AgentMemory,scratch
|
||||
逻辑:Role调用该函数,self._rc.AgentMemory,self._rc.scratch.curr_time,self._rc.scratch.memory_forget
|
||||
|
|
@ -46,25 +47,27 @@ def agent_retrieve(agent_memory: AgentMemory, curr_time: datetime.datetime, memo
|
|||
|
||||
result = top_highest_x_values(total_dict, topk)
|
||||
|
||||
return result # 返回的是一个BasicMemory列表
|
||||
return result # 返回的是一个BasicMemory列表
|
||||
|
||||
def new_agent_retrieve(strole: STRole, focus_points: list, n_count = 30):
|
||||
|
||||
def new_agent_retrieve(strole: "STRole", focus_points: list, n_count=30):
|
||||
"""
|
||||
输入为Strole,关注点列表,返回记忆数量
|
||||
输出为字典,键为focus_point,值为对应的记忆列表
|
||||
"""
|
||||
retrieved = dict()
|
||||
for focal_pt in focus_points:
|
||||
retrieved = dict()
|
||||
for focal_pt in focus_points:
|
||||
nodes = [[i.last_accessed, i]
|
||||
for i in strole._rc.memory.event_list + strole._rc.memory.thought_list
|
||||
if "idle" not in i.embedding_key]
|
||||
for i in strole._rc.memory.event_list + strole._rc.memory.thought_list
|
||||
if "idle" not in i.embedding_key]
|
||||
nodes = sorted(nodes, key=lambda x: x[0])
|
||||
nodes = [i for created, i in nodes]
|
||||
results = agent_retrieve(strole._rc.memory, strole._rc.scratch.curr_time, strole._rc.scratch.recency_decay, focal_pt, n_count)
|
||||
for n in results:
|
||||
results = agent_retrieve(strole._rc.memory, strole._rc.scratch.curr_time, strole._rc.scratch.recency_decay,
|
||||
focal_pt, n_count)
|
||||
for n in results:
|
||||
n.last_accessed = strole._rc.scratch.curr_time
|
||||
|
||||
retrieved[focal_pt] = results
|
||||
retrieved[focal_pt] = results
|
||||
|
||||
|
||||
def top_highest_x_values(d, x):
|
||||
|
|
@ -159,4 +162,4 @@ def normalize_score_floats(score_list, target_min, target_max):
|
|||
score_list[i]['relevance'] = relevance_list[i]
|
||||
score_list[i]['recency'] = recency_list[i]
|
||||
|
||||
return score_list
|
||||
return score_list
|
||||
|
|
|
|||
|
|
@ -5,11 +5,11 @@
|
|||
import datetime
|
||||
import json
|
||||
|
||||
from ..utils.utils import check_if_file_exists
|
||||
from examples.st_game.utils.utils import check_if_file_exists
|
||||
|
||||
|
||||
class Scratch:
|
||||
def __init__(self, f_saved):
|
||||
def __init__(self):
|
||||
# 类别1:人物超参
|
||||
self.vision_r = 4
|
||||
self.att_bandwidth = 3
|
||||
|
|
@ -77,6 +77,7 @@ class Scratch:
|
|||
self.act_path_set = False
|
||||
self.planned_path = []
|
||||
|
||||
def set_scratch_path(self, f_saved: str):
|
||||
if check_if_file_exists(f_saved):
|
||||
# If we have a bootstrap file, load that here.
|
||||
scratch_load = json.load(open(f_saved))
|
||||
|
|
@ -510,4 +511,3 @@ class Scratch:
|
|||
minute = curr_min_sum % 60
|
||||
ret += f"{hour:02}:{minute:02} || {row[0]}\n"
|
||||
return ret
|
||||
|
||||
|
|
@ -10,8 +10,10 @@ import os
|
|||
|
||||
|
||||
class MemoryTree:
|
||||
def __init__(self, f_saved: str) -> None:
|
||||
def __init__(self) -> None:
|
||||
self.tree = {}
|
||||
|
||||
def set_mem_path(self, f_saved: str):
|
||||
if os.path.isfile(f_saved) and os.path.exists(f_saved):
|
||||
with open(f_saved) as f:
|
||||
self.tree = json.load(f)
|
||||
|
|
|
|||
3
examples/st_game/plan/__init__.py
Normal file
3
examples/st_game/plan/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
|
@ -6,11 +6,11 @@ from typing import Union, Tuple
|
|||
|
||||
from metagpt.logs import logger
|
||||
|
||||
from ..maze import Maze
|
||||
from ..roles.st_role import STRole
|
||||
from ..memory.retrieve import new_retrieve
|
||||
from ..actions.agent_chat_sum_rel import AgentChatSumRel
|
||||
from ..actions.gen_iter_chat_utt import GenIterChatUTT
|
||||
from examples.st_game.maze import Maze
|
||||
from examples.st_game.roles.st_role import STRole
|
||||
from examples.st_game.memory.retrieve import new_agent_retrieve
|
||||
from examples.st_game.actions.agent_chat_sum_rel import AgentChatSumRel
|
||||
from examples.st_game.actions.gen_iter_chat_utt import GenIterChatUTT
|
||||
|
||||
|
||||
def agent_conversation(maze: Maze, init_role: STRole, target_role: STRole) -> list[str]:
|
||||
|
|
@ -23,7 +23,7 @@ def agent_conversation(maze: Maze, init_role: STRole, target_role: STRole) -> li
|
|||
target_scratch = target_role._rc.scratch
|
||||
|
||||
focal_points = [f"{target_scratch.name}"]
|
||||
retrieved = new_retrieve(init_role, focal_points, 50)
|
||||
retrieved = new_agent_retrieve(init_role, focal_points, 50)
|
||||
relationship = generate_summarize_agent_relationship(init_role, target_role, retrieved)
|
||||
print("-------- relationship: ", relationship)
|
||||
last_chat = ""
|
||||
|
|
@ -36,7 +36,7 @@ def agent_conversation(maze: Maze, init_role: STRole, target_role: STRole) -> li
|
|||
else:
|
||||
focal_points = [f"{relationship}",
|
||||
f"{target_scratch.name} is {target_scratch.act_description}"]
|
||||
retrieved = new_retrieve(init_role, focal_points, 15)
|
||||
retrieved = new_agent_retrieve(init_role, focal_points, 15)
|
||||
utt, end = generate_one_utterance(maze, init_role, target_role, retrieved, curr_chat)
|
||||
|
||||
curr_chat += [[scratch.name, utt]]
|
||||
|
|
@ -44,7 +44,7 @@ def agent_conversation(maze: Maze, init_role: STRole, target_role: STRole) -> li
|
|||
break
|
||||
|
||||
focal_points = [f"{scratch.name}"]
|
||||
retrieved = new_retrieve(target_role, focal_points, 50)
|
||||
retrieved = new_agent_retrieve(target_role, focal_points, 50)
|
||||
relationship = generate_summarize_agent_relationship(target_role, init_role, retrieved)
|
||||
print("-------- relationship: ", relationship)
|
||||
last_chat = ""
|
||||
|
|
@ -57,7 +57,7 @@ def agent_conversation(maze: Maze, init_role: STRole, target_role: STRole) -> li
|
|||
else:
|
||||
focal_points = [f"{relationship}",
|
||||
f"{scratch.name} is {scratch.act_description}"]
|
||||
retrieved = new_retrieve(target_role, focal_points, 15)
|
||||
retrieved = new_agent_retrieve(target_role, focal_points, 15)
|
||||
utt, end = generate_one_utterance(maze, target_role, init_role, retrieved, curr_chat)
|
||||
|
||||
curr_chat += [[target_scratch.name, utt]]
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@ from typing import Union, Tuple
|
|||
from datetime import datetime
|
||||
import math
|
||||
|
||||
from ..maze import Maze
|
||||
from ..plan.converse import agent_conversation
|
||||
from ..roles.st_role import STRole
|
||||
from ..actions.decide_to_talk import DecideToTalk
|
||||
from ..actions.summarize_conv import SummarizeConv
|
||||
from ..actions.new_decomp_schedule import NewDecompSchedule
|
||||
from examples.st_game.maze import Maze
|
||||
from examples.st_game.plan.converse import agent_conversation
|
||||
from examples.st_game.roles.st_role import STRole
|
||||
from examples.st_game.actions.decide_to_talk import DecideToTalk
|
||||
from examples.st_game.actions.summarize_conv import SummarizeConv
|
||||
from examples.st_game.actions.new_decomp_schedule import NewDecompSchedule
|
||||
|
||||
|
||||
def plan(role: STRole, maze: Maze, roles: list[STRole], new_day: bool, retrieved: dict):
|
||||
|
|
|
|||
|
|
@ -2,11 +2,12 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : 调用Prompts中模板,实现相关Action
|
||||
|
||||
from wrapper_prompt import special_response_generate, prompt_generate
|
||||
from memory.scratch import Scratch
|
||||
from examples.st_game.memory.agent_memory import BasicMemory
|
||||
import json
|
||||
|
||||
from examples.st_game.prompts.wrapper_prompt import special_response_generate, prompt_generate
|
||||
from examples.st_game.memory.scratch import Scratch
|
||||
from examples.st_game.memory.agent_memory import BasicMemory
|
||||
|
||||
|
||||
def get_poignancy_action(scratch: Scratch, content: BasicMemory.content) -> str:
|
||||
"""
|
||||
|
|
@ -31,7 +32,8 @@ def get_poignancy_action(scratch: Scratch, content: BasicMemory.content) -> str:
|
|||
return str(poi_dict['poignancy']) # 将返回值强制转换为字符串
|
||||
except json.JSONDecodeError as e:
|
||||
return poignancy
|
||||
|
||||
|
||||
|
||||
def get_poignancy_chat(scratch: Scratch, content: BasicMemory.content) -> str:
|
||||
"""
|
||||
衡量会话心酸度
|
||||
|
|
|
|||
|
|
@ -1,12 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : reflection module
|
||||
|
||||
from metagpt.reflect import agent_reflect
|
||||
from metagpt.reflect import ga_prompt_generator
|
||||
__all__ = [
|
||||
"agent_reflect",
|
||||
"LongTermMemory",
|
||||
"ga_po"
|
||||
"ga_prompt_generator"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -37,19 +37,19 @@ async def final_response(prompt, special_instruction, example_output=None):
|
|||
|
||||
def prompt_generate(curr_input, prompt_lib_file):
|
||||
"""
|
||||
Takes in the current input (e.g. comment that you want to classifiy) and
|
||||
Takes in the current input (e.g. comment that you want to classifiy) and
|
||||
the path to a prompt file. The prompt file contains the raw str prompt that
|
||||
will be used, which contains the following substr: !<INPUT>! -- this
|
||||
function replaces this substr with the actual curr_input to produce the
|
||||
final promopt that will be sent to the GPT3 server.
|
||||
will be used, which contains the following substr: !<INPUT>! -- this
|
||||
function replaces this substr with the actual curr_input to produce the
|
||||
final promopt that will be sent to the GPT3 server.
|
||||
ARGS:
|
||||
curr_input: the input we want to feed in (IF THERE ARE MORE THAN ONE
|
||||
INPUT, THIS CAN BE A LIST.)
|
||||
prompt_lib_file: the path to the promopt file.
|
||||
RETURNS:
|
||||
a str prompt that will be sent to OpenAI's GPT server.
|
||||
prompt_lib_file: the path to the promopt file.
|
||||
RETURNS:
|
||||
a str prompt that will be sent to OpenAI's GPT server.
|
||||
"""
|
||||
if type(curr_input) is type("string"):
|
||||
if isinstance(curr_input, str):
|
||||
curr_input = [curr_input]
|
||||
curr_input = [str(i) for i in curr_input]
|
||||
|
||||
|
|
@ -62,23 +62,3 @@ def prompt_generate(curr_input, prompt_lib_file):
|
|||
prompt = prompt.split(
|
||||
"<commentblockmarker>###</commentblockmarker>")[1]
|
||||
return prompt.strip()
|
||||
|
||||
# 使用OpenAI embedding库进行存储
|
||||
|
||||
|
||||
def embedding(query):
|
||||
"""
|
||||
Generates an embedding for the given query.
|
||||
|
||||
Args:
|
||||
query (str): The text query to be embedded.
|
||||
|
||||
Returns:
|
||||
str: The embedding key generated for the query.
|
||||
"""
|
||||
embedding_result = openai.Embedding.create(
|
||||
model="text-embedding-ada-002",
|
||||
input=query
|
||||
)
|
||||
embedding_key = embedding_result['data'][0]["embedding"]
|
||||
return embedding_key
|
||||
|
|
|
|||
|
|
@ -1,23 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : Reflect function
|
||||
import datetime
|
||||
import random
|
||||
import sys
|
||||
|
||||
from numpy import dot
|
||||
from numpy.linalg import norm
|
||||
from ..roles.st_role import STRole
|
||||
from ..utils.utils import get_embedding
|
||||
import datetime
|
||||
|
||||
from metagpt.logs import logger
|
||||
from ..actions.run_reflect_action import (
|
||||
|
||||
from examples.st_game.utils.utils import get_embedding
|
||||
from examples.st_game.actions.run_reflect_action import (
|
||||
AgentFocusPt, AgentInsightAndGuidance, AgentEventTriple,
|
||||
AgentEventPoignancy, AgentChatPoignancy, AgentPlanThoughtOnConvo,
|
||||
AgentMemoryOnConvo
|
||||
)
|
||||
|
||||
|
||||
def generate_focal_points(role: STRole, n=3):
|
||||
def generate_focal_points(role: "STRole", n=3):
|
||||
nodes = [
|
||||
[i.last_accessed, i] for i in
|
||||
role._rc.memory.event_list + role._rc.memory.thought_list
|
||||
|
|
@ -65,7 +62,7 @@ def generate_action_event_triple(act_desp, role):
|
|||
return AgentEventTriple(act_desp, role)
|
||||
|
||||
|
||||
def generate_poig_score(role: STRole, event_type, description):
|
||||
def generate_poig_score(role: "STRole", event_type, description):
|
||||
if "is idle" in description:
|
||||
return 1
|
||||
|
||||
|
|
@ -89,7 +86,7 @@ def generate_memo_on_convo(role, all_utt):
|
|||
|
||||
|
||||
# Done
|
||||
def run_reflect(role: STRole):
|
||||
def run_reflect(role: "STRole"):
|
||||
"""
|
||||
Run the actual reflection. We generate the focal points, retrieve any
|
||||
relevant nodes, and generate thoughts and insights.
|
||||
|
|
@ -128,7 +125,7 @@ def run_reflect(role: STRole):
|
|||
|
||||
|
||||
# Done
|
||||
def reflection_trigger(role: STRole):
|
||||
def reflection_trigger(role: "STRole"):
|
||||
"""
|
||||
Given the current role, determine whether the role should run a
|
||||
reflection.
|
||||
|
|
@ -155,7 +152,7 @@ def reflection_trigger(role: STRole):
|
|||
|
||||
|
||||
# Done
|
||||
def reset_reflection_counter(role: STRole):
|
||||
def reset_reflection_counter(role: "STRole"):
|
||||
"""
|
||||
We reset the counters used for the reflection trigger.
|
||||
|
||||
|
|
@ -170,7 +167,7 @@ def reset_reflection_counter(role: STRole):
|
|||
|
||||
|
||||
# Question 1 chat函数
|
||||
def reflect(role: STRole):
|
||||
def reflect(role: "STRole"):
|
||||
"""
|
||||
The main reflection module for the role. We first check if the trigger
|
||||
conditions are met, and if so, run the reflection and reset any of the
|
||||
|
|
|
|||
|
|
@ -6,8 +6,9 @@ import asyncio
|
|||
import json
|
||||
import time
|
||||
from metagpt.logs import logger
|
||||
from ..prompts.wrapper_prompt import special_response_generate
|
||||
from ..memory.agent_memory import BasicMemory
|
||||
|
||||
from examples.st_game.prompts.wrapper_prompt import special_response_generate
|
||||
from examples.st_game.memory.agent_memory import BasicMemory
|
||||
|
||||
|
||||
async def agent_reflect(memories_list):
|
||||
|
|
@ -21,7 +22,7 @@ async def agent_reflect(memories_list):
|
|||
B = await generate_insights_and_evidence(memories_list, question=i)
|
||||
|
||||
|
||||
async def generate_focus_point(memories_list: list[MemoryBasic], n=3):
|
||||
async def generate_focus_point(memories_list: list[BasicMemory], n=3):
|
||||
"""
|
||||
生成关注点函数:根据记忆列表生成关注点
|
||||
"""
|
||||
|
|
@ -47,7 +48,7 @@ async def generate_focus_point(memories_list: list[MemoryBasic], n=3):
|
|||
return out
|
||||
|
||||
|
||||
async def generate_insights_and_evidence(memories_list: list[MemoryBasic], question: str, n=5):
|
||||
async def generate_insights_and_evidence(memories_list: list[BasicMemory], question: str, n=5):
|
||||
"""
|
||||
生成洞察和证据函数:根据问题生成洞察和证据
|
||||
"""
|
||||
|
|
@ -68,7 +69,7 @@ async def generate_insights_and_evidence(memories_list: list[MemoryBasic], quest
|
|||
try:
|
||||
insight_list = json.loads(ret)
|
||||
for insight, index in insight_list:
|
||||
agent.memory_list.append(MemoryBasic(
|
||||
agent.memory_list.append(BasicMemory(
|
||||
time.time(), None, insight, None, None))
|
||||
return insight_list
|
||||
except:
|
||||
|
|
|
|||
3
examples/st_game/roles/__init__.py
Normal file
3
examples/st_game/roles/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
|
@ -23,24 +23,24 @@ from metagpt.roles.role import Role, RoleContext
|
|||
from metagpt.schema import Message
|
||||
from metagpt.logs import logger
|
||||
|
||||
from ..memory.agent_memory import AgentMemory, BasicMemory
|
||||
from ..memory.spatial_memory import MemoryTree
|
||||
from ..actions.dummy_action import DummyAction
|
||||
from ..actions.user_requirement import UserRequirement
|
||||
from ..maze_environment import MazeEnvironment
|
||||
from ..memory.retrieve import new_agent_retrieve
|
||||
from ..memory.scratch import Scratch
|
||||
from ..utils.utils import get_embedding, generate_poig_score, path_finder
|
||||
from ..utils.const import collision_block_id
|
||||
from ..reflect.st_reflect import agent_reflect
|
||||
from ..utils.mg_ga_transform import save_movement, get_role_environment
|
||||
from examples.st_game.memory.agent_memory import AgentMemory, BasicMemory
|
||||
from examples.st_game.memory.spatial_memory import MemoryTree
|
||||
from examples.st_game.actions.dummy_action import DummyAction, DummyMessage
|
||||
from examples.st_game.actions.user_requirement import UserRequirement
|
||||
from examples.st_game.maze_environment import MazeEnvironment
|
||||
from examples.st_game.memory.retrieve import new_agent_retrieve
|
||||
from examples.st_game.memory.scratch import Scratch
|
||||
from examples.st_game.utils.utils import get_embedding, path_finder
|
||||
from examples.st_game.utils.const import collision_block_id, STORAGE_PATH
|
||||
from examples.st_game.reflect.reflect import generate_poig_score
|
||||
from examples.st_game.utils.mg_ga_transform import save_movement, get_role_environment
|
||||
|
||||
|
||||
class STRoleContext(RoleContext):
|
||||
env: 'MazeEnvironment' = Field(default=MazeEnvironment)
|
||||
memory: AgentMemory = Field(default=AgentMemory)
|
||||
scratch: Scratch = Field(default=Scratch)
|
||||
spatial_memory: MemoryTree = Field(default=MemoryTree)
|
||||
env: 'MazeEnvironment' = Field(default_factory=MazeEnvironment)
|
||||
memory: AgentMemory = Field(default_factory=AgentMemory)
|
||||
scratch: Scratch = Field(default_factory=Scratch)
|
||||
spatial_memory: MemoryTree = Field(default_factory=MemoryTree)
|
||||
|
||||
|
||||
class STRole(Role):
|
||||
|
|
@ -65,9 +65,18 @@ class STRole(Role):
|
|||
self.role_tile = (0, 0)
|
||||
self.game_obj_cleanup = dict()
|
||||
|
||||
self._rc = STRoleContext()
|
||||
super(STRole, self).__init__(name=name,
|
||||
profile=profile)
|
||||
self._rc = STRoleContext()
|
||||
memory_saved = str(STORAGE_PATH.joinpath(f"{sim_code}/personas/{self.name}/"
|
||||
f"bootstrap_memory/associative_memory"))
|
||||
self._rc.memory.set_mem_path(memory_saved)
|
||||
sp_mem_saved = str(STORAGE_PATH.joinpath(f"{sim_code}/personas/{self.name}/"
|
||||
f"bootstrap_memory/spatial_memory.json"))
|
||||
self._rc.spatial_memory.set_mem_path(f_saved=sp_mem_saved)
|
||||
scratch_f_saved = str(STORAGE_PATH.joinpath(f"{sim_code}/personas/{self.name}/"
|
||||
f"bootstrap_memory/scratch.json"))
|
||||
self._rc.scratch.set_scratch_path(f_saved=scratch_f_saved)
|
||||
|
||||
self._init_actions([])
|
||||
|
||||
|
|
@ -104,6 +113,19 @@ class STRole(Role):
|
|||
"""
|
||||
pass
|
||||
|
||||
async def _observe(self) -> int:
|
||||
if not self._rc.env:
|
||||
return 0
|
||||
|
||||
observed = self._rc.env.memory.get_by_actions(self._rc.watch)
|
||||
self._rc.news = self._rc.memory.remember(observed)
|
||||
if len(self._rc.news) == 1 and isinstance(self._rc.news[0], UserRequirement):
|
||||
# add inner voice
|
||||
# TODO
|
||||
logger.warning(f"Role: {self.name} add inner voice: {self._rc.news[0].content}")
|
||||
|
||||
return 1 # always return 1 to execute role's `_react`
|
||||
|
||||
async def observe(self) -> list[BasicMemory]:
|
||||
# TODO observe info from maze_env
|
||||
"""
|
||||
|
|
@ -247,7 +269,7 @@ class STRole(Role):
|
|||
|
||||
async def retrieve(self, focus_points, n=30):
|
||||
# TODO retrieve memories from agent_memory
|
||||
retrieve_memories = new_agent_retrieve(self,focus_points,n)
|
||||
retrieve_memories = new_agent_retrieve(self, focus_points, n)
|
||||
return retrieve_memories
|
||||
|
||||
async def plan(self):
|
||||
|
|
@ -435,7 +457,7 @@ class STRole(Role):
|
|||
ret = self.update_role_env()
|
||||
if not ret:
|
||||
# TODO add message
|
||||
return
|
||||
return DummyMessage()
|
||||
|
||||
# TODO observe
|
||||
# get maze_env from self._rc.env, and observe env info
|
||||
|
|
@ -443,20 +465,23 @@ class STRole(Role):
|
|||
# TODO retrieve, use self._rc.memory 's retrieve functions
|
||||
|
||||
# TODO plan
|
||||
plan = self.plan()
|
||||
|
||||
# TODO reflect
|
||||
|
||||
# TODO execute(feed-back into maze_env)
|
||||
next_tile, pronunciatio, description = self.execute(plan)
|
||||
role_move = {
|
||||
"movement": next_tile,
|
||||
"pronunciatio": pronunciatio,
|
||||
"description": description,
|
||||
"chat": self.scratch.chat
|
||||
}
|
||||
save_movement(self.name, role_move, step=self.step, sim_code=self.sim_code, curr_time=self.curr_time)
|
||||
# plan = self.plan()
|
||||
#
|
||||
# # TODO reflect
|
||||
#
|
||||
# # TODO execute(feed-back into maze_env)
|
||||
# next_tile, pronunciatio, description = self.execute(plan)
|
||||
# role_move = {
|
||||
# "movement": next_tile,
|
||||
# "pronunciatio": pronunciatio,
|
||||
# "description": description,
|
||||
# "chat": self.scratch.chat
|
||||
# }
|
||||
# save_movement(self.name, role_move, step=self.step, sim_code=self.sim_code, curr_time=self.curr_time)
|
||||
|
||||
# step update
|
||||
logger.info(f"Role: {self.name} run at {self.step} step on {self.curr_time}")
|
||||
self.step += 1
|
||||
self.curr_time += datetime.timedelta(seconds=self.sec_per_step)
|
||||
|
||||
return DummyMessage()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from stanford_town import StanfordTown
|
|||
from roles.st_role import STRole
|
||||
from utils.mg_ga_transform import get_reverie_meta
|
||||
from utils.const import STORAGE_PATH
|
||||
from utils.utils import copy_folder
|
||||
|
||||
|
||||
async def startup(idea: str,
|
||||
|
|
@ -16,6 +17,8 @@ async def startup(idea: str,
|
|||
sim_code: str,
|
||||
investment: float = 30.0,
|
||||
n_round: int = 500):
|
||||
# copy `storage/{fork_sim_code}` to `storage/{sim_code}`
|
||||
copy_folder(str(STORAGE_PATH.joinpath(fork_sim_code)), str(STORAGE_PATH.joinpath(sim_code)))
|
||||
|
||||
# get role names from `storage/{simulation_name}/reverie/meta.json` and then init roles
|
||||
reverie_meta = get_reverie_meta(fork_sim_code)
|
||||
|
|
|
|||
3
examples/st_game/tests/actions/__init__.py
Normal file
3
examples/st_game/tests/actions/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
|
@ -2,6 +2,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of actions/summarize_conv
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.abspath(os.path.dirname(__file__) + "./../../../"))
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
||||
from st_game.actions.summarize_conv import SummarizeConv
|
||||
|
|
|
|||
3
examples/st_game/utils/__init__.py
Normal file
3
examples/st_game/utils/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
|
@ -6,8 +6,8 @@ import json
|
|||
|
||||
from metagpt.logs import logger
|
||||
|
||||
from .const import STORAGE_PATH
|
||||
from .utils import read_json_file, write_json_file
|
||||
from examples.st_game.utils.const import STORAGE_PATH
|
||||
from examples.st_game.utils.utils import read_json_file, write_json_file
|
||||
|
||||
|
||||
def get_reverie_meta(sim_code: str) -> dict:
|
||||
|
|
|
|||
|
|
@ -2,13 +2,17 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : utils
|
||||
|
||||
from typing import Any, Union
|
||||
import os
|
||||
import json
|
||||
import openai
|
||||
from pathlib import Path
|
||||
import csv
|
||||
from ..prompts.run_gpt_prompts import get_poignancy_action, get_poignancy_chat
|
||||
import errno
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
import openai
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def read_json_file(json_file: str, encoding=None) -> list[Any]:
|
||||
|
|
@ -37,23 +41,24 @@ def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
|
|||
RETURNS:
|
||||
List of list where the component lists are the rows of the file.
|
||||
"""
|
||||
logger.info(f"start read csv: {curr_file}")
|
||||
if not header:
|
||||
analysis_list = []
|
||||
with open(curr_file) as f_analysis_file:
|
||||
data_reader = csv.reader(f_analysis_file, delimiter=",")
|
||||
for count, row in enumerate(data_reader):
|
||||
if strip_trail:
|
||||
row = [i.strip() for i in row]
|
||||
analysis_list += [row]
|
||||
for count, row in enumerate(data_reader):
|
||||
if strip_trail:
|
||||
row = [i.strip() for i in row]
|
||||
analysis_list += [row]
|
||||
return analysis_list
|
||||
else:
|
||||
analysis_list = []
|
||||
with open(curr_file) as f_analysis_file:
|
||||
data_reader = csv.reader(f_analysis_file, delimiter=",")
|
||||
for count, row in enumerate(data_reader):
|
||||
if strip_trail:
|
||||
row = [i.strip() for i in row]
|
||||
analysis_list += [row]
|
||||
for count, row in enumerate(data_reader):
|
||||
if strip_trail:
|
||||
row = [i.strip() for i in row]
|
||||
analysis_list += [row]
|
||||
return analysis_list[0], analysis_list[1:]
|
||||
|
||||
|
||||
|
|
@ -65,15 +70,6 @@ def get_embedding(text, model: str = "text-embedding-ada-002"):
|
|||
input=[text], model=model)['data'][0]['embedding']
|
||||
|
||||
|
||||
def generate_poig_score(scratch, event_type, description):
|
||||
if "is idle" in description:
|
||||
return 1
|
||||
if event_type == "action":
|
||||
return get_poignancy_action(scratch, description)[0]
|
||||
elif event_type == "chat":
|
||||
return get_poignancy_chat(scratch, description)[0]
|
||||
|
||||
|
||||
def extract_first_json_dict(data_str: str) -> Union[None, dict]:
|
||||
# Find the first occurrence of a JSON object within the string
|
||||
start_idx = data_str.find("{")
|
||||
|
|
@ -178,6 +174,7 @@ def path_finder(maze: "Maze", start: list[int], end: list[int], collision_block_
|
|||
|
||||
return path
|
||||
|
||||
|
||||
def create_folder_if_not_there(curr_path):
|
||||
"""
|
||||
Checks if a folder in the curr_path exists. If it does not exist, creates
|
||||
|
|
@ -207,6 +204,7 @@ def create_folder_if_not_there(curr_path):
|
|||
|
||||
return False
|
||||
|
||||
|
||||
def find_filenames(path_to_dir, suffix=".csv"):
|
||||
"""
|
||||
Given a directory, find all files that end with the provided suffix and
|
||||
|
|
@ -219,4 +217,21 @@ def find_filenames(path_to_dir, suffix=".csv"):
|
|||
"""
|
||||
filenames = os.listdir(path_to_dir)
|
||||
return [path_to_dir + "/" + filename
|
||||
for filename in filenames if filename.endswith(suffix)]
|
||||
for filename in filenames if filename.endswith(suffix)]
|
||||
|
||||
|
||||
def check_if_file_exists(curr_file: str):
|
||||
return Path(curr_file).exists()
|
||||
|
||||
|
||||
def copy_folder(src_folder: str, dest_folder: str):
|
||||
try:
|
||||
if Path(dest_folder).exists():
|
||||
logger.warning(f"{dest_folder} exist, start to remove.")
|
||||
shutil.rmtree(dest_folder)
|
||||
shutil.copytree(src_folder, dest_folder)
|
||||
except OSError as exc: # python >2.5
|
||||
if exc.errno in (errno.ENOTDIR, errno.EINVAL):
|
||||
shutil.copy(src_folder, dest_folder)
|
||||
else:
|
||||
raise
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue