Merge pull request #357 from garylin2099/werewolf_game

Werewolf game
This commit is contained in:
garylin2099 2023-09-23 22:23:52 +08:00 committed by GitHub
commit cf365c8e82
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 464 additions and 24 deletions

View file

@ -33,7 +33,7 @@ class Environment(BaseModel):
Add a role in the current environment
"""
role.set_env(self)
self.roles[role.profile] = role
self.roles[str(role._setting)] = role
def add_roles(self, roles: Iterable[Role]):
"""增加一批在当前环境的角色
@ -72,8 +72,8 @@ class Environment(BaseModel):
"""
return self.roles
def get_role(self, name: str) -> Role:
def get_role(self, role_setting: str) -> Role:
"""获得环境内的指定角色
get all the environment roles
"""
return self.roles.get(name, None)
return self.roles.get(role_setting, None)

View file

@ -42,21 +42,21 @@ class LongTermMemory(Memory):
# and ignore adding messages from recover repeatedly
self.memory_storage.add(message)
def remember(self, observed: list[Message], k=0) -> list[Message]:
def find_news(self, observed: list[Message], k=0) -> list[Message]:
"""
remember the most similar k memories from observed Messages, return all when k=0
1. remember the short-term memory(stm) news
2. integrate the stm news with ltm(long-term memory) news
find news (previously unseen messages) from the the most recent k memories, from all memories when k=0
1. find the short-term memory(stm) news
2. furthermore, filter out similar messages based on ltm(long-term memory), get the final news
"""
stm_news = super(LongTermMemory, self).remember(observed, k=k) # shot-term memory news
stm_news = super(LongTermMemory, self).find_news(observed, k=k) # shot-term memory news
if not self.memory_storage.is_initialized:
# memory_storage hasn't initialized, use default `remember` to get stm_news
# memory_storage hasn't initialized, use default `find_news` to get stm_news
return stm_news
ltm_news: list[Message] = []
for mem in stm_news:
# integrate stm & ltm
mem_searched = self.memory_storage.search(mem)
# filter out messages similar to those seen previously in ltm, only keep fresh news
mem_searched = self.memory_storage.search_dissimilar(mem)
if len(mem_searched) > 0:
ltm_news.append(mem)
return ltm_news[-k:]

View file

@ -63,8 +63,8 @@ class Memory:
"""Return the most recent k memories, return all when k=0"""
return self.storage[-k:]
def remember(self, observed: list[Message], k=0) -> list[Message]:
"""remember the most recent k memories from observed Messages, return all when k=0"""
def find_news(self, observed: list[Message], k=0) -> list[Message]:
"""find news (previously unseen messages) from the the most recent k memories, from all memories when k=0"""
already_observed = self.get(k)
news: list[Message] = []
for i in observed:

View file

@ -74,7 +74,7 @@ class MemoryStorage(FaissStore):
self.persist()
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
def search(self, message: Message, k=4) -> List[Message]:
def search_dissimilar(self, message: Message, k=4) -> List[Message]:
"""search for dissimilar messages"""
if not self.store:
return []

View file

@ -137,6 +137,11 @@ class Role:
"""Get the role description (position)"""
return self._setting.profile
@property
def name(self):
"""Get the role name"""
return self._setting.name
def _get_prefix(self):
"""Get the role prefix"""
if self._setting.desc:
@ -185,9 +190,13 @@ class Role:
observed = self._rc.env.memory.get_by_actions(self._rc.watch)
self._rc.news = self._rc.memory.remember(observed) # remember recent exact or similar memories
self._rc.news = self._rc.memory.find_news(observed) # find news (previously unseen messages) from observed messages
for i in env_msgs:
if i.restricted_to != "" and self.profile not in i.restricted_to and self.name not in i.restricted_to:
# if the msg is not send to the whole audience ("") nor this role (self.profile or self.name),
# then this role should not be able to receive it and record it into its memory
continue
self.recv(i)
news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news]

View file

@ -29,6 +29,7 @@ class Message:
cause_by: Type["Action"] = field(default="")
sent_from: str = field(default="")
send_to: str = field(default="")
restricted_to: str = field(default="")
def __str__(self):
# prefix = '-'.join([self.role, str(self.cause_by)])