mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-01 11:56:24 +02:00
Merge branch 'main' into main
This commit is contained in:
commit
e44410b3ad
17 changed files with 532 additions and 10 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -153,6 +153,7 @@ allure-results
|
|||
docs/scripts/set_env.sh
|
||||
key.yaml
|
||||
output.json
|
||||
data
|
||||
data/output_add.json
|
||||
data.ms
|
||||
examples/nb/
|
||||
|
|
|
|||
|
|
@ -52,3 +52,6 @@ RPM: 10
|
|||
## Use SD service, based on https://github.com/AUTOMATIC1111/stable-diffusion-webui
|
||||
SD_URL: "YOUR_SD_URL"
|
||||
SD_T2I_API: "/sdapi/v1/txt2img"
|
||||
|
||||
#### for Execution
|
||||
#LONG_TERM_MEMORY: false
|
||||
|
|
|
|||
|
|
@ -42,7 +42,10 @@ FORMAT_EXAMPLE = '''
|
|||
---
|
||||
## Required Python third-party packages
|
||||
```python
|
||||
"""
|
||||
flask==1.1.2
|
||||
bcrypt==3.2.0
|
||||
"""
|
||||
```
|
||||
|
||||
## Required Other language third-party packages
|
||||
|
|
@ -110,7 +113,7 @@ class WriteTasks(Action):
|
|||
|
||||
# Write requirements.txt
|
||||
requirements_path = WORKSPACE_ROOT / ws_name / 'requirements.txt'
|
||||
requirements_path.write_text(rsp.instruct_content.dict().get("Required Python third-party packages"))
|
||||
requirements_path.write_text(rsp.instruct_content.dict().get("Required Python third-party packages").strip('"\n'))
|
||||
|
||||
async def run(self, context):
|
||||
prompt = PROMPT_TEMPLATE.format(context=context, format_example=FORMAT_EXAMPLE)
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ class Config(metaclass=Singleton):
|
|||
self.openai_api_key = self._get("OPENAI_API_KEY")
|
||||
if not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key:
|
||||
raise NotConfiguredException("Set OPENAI_API_KEY first")
|
||||
|
||||
self.openai_api_base = self._get("OPENAI_API_BASE")
|
||||
if not self.openai_api_base or "YOUR_API_BASE" == self.openai_api_base:
|
||||
openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
|
||||
|
|
@ -60,14 +61,20 @@ class Config(metaclass=Singleton):
|
|||
self.max_tokens_rsp = self._get("MAX_TOKENS", 2048)
|
||||
self.deployment_id = self._get("DEPLOYMENT_ID")
|
||||
|
||||
self.claude_api_key = self._get('Anthropic_API_KEY')
|
||||
self.serpapi_api_key = self._get("SERPAPI_API_KEY")
|
||||
self.serper_api_key = self._get("SERPER_API_KEY")
|
||||
self.google_api_key = self._get("GOOGLE_API_KEY")
|
||||
self.google_cse_id = self._get("GOOGLE_CSE_ID")
|
||||
self.search_engine = self._get("SEARCH_ENGINE", SearchEngineType.SERPAPI_GOOGLE)
|
||||
|
||||
self.web_browser_engine = WebBrowserEngineType(self._get("WEB_BROWSER_ENGINE", "playwright"))
|
||||
self.playwright_browser_type = self._get("PLAYWRIGHT_BROWSER_TYPE", "chromium")
|
||||
self.selenium_browser_type = self._get("SELENIUM_BROWSER_TYPE", "chrome")
|
||||
|
||||
self.long_term_memory = self._get('LONG_TERM_MEMORY', False)
|
||||
if self.long_term_memory:
|
||||
logger.warning("LONG_TERM_MEMORY is True")
|
||||
self.max_budget = self._get("MAX_BUDGET", 10.0)
|
||||
self.total_cost = 0.0
|
||||
|
||||
|
|
|
|||
|
|
@ -32,3 +32,5 @@ UT_PY_PATH = UT_PATH / "files/ut/"
|
|||
API_QUESTIONS_PATH = UT_PATH / "files/question/"
|
||||
YAPI_URL = "http://yapi.deepwisdomai.com/"
|
||||
TMP = PROJECT_ROOT / 'tmp'
|
||||
|
||||
MEM_TTL = 24 * 30 * 3600
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class FaissStore(LocalStore):
|
|||
def _load(self) -> Optional["FaissStore"]:
|
||||
index_file, store_file = self._get_index_and_store_fname()
|
||||
if not (index_file.exists() and store_file.exists()):
|
||||
logger.warning("Download data from http://pan.deepwisdomai.com/library/13ff7974-fbc7-40ab-bc10-041fdc97adbd/LLM/00_QCS-%E5%90%91%E9%87%8F%E6%95%B0%E6%8D%AE/qcs")
|
||||
logger.info("Missing at least one of index_file/store_file, load failed and return None")
|
||||
return None
|
||||
index = faiss.read_index(str(index_file))
|
||||
with open(str(store_file), "rb") as f:
|
||||
|
|
|
|||
|
|
@ -7,3 +7,5 @@
|
|||
"""
|
||||
|
||||
from metagpt.memory.memory import Memory
|
||||
from metagpt.memory.longterm_memory import LongTermMemory
|
||||
|
||||
|
|
|
|||
71
metagpt/memory/longterm_memory.py
Normal file
71
metagpt/memory/longterm_memory.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the implement of Long-term memory
|
||||
|
||||
from typing import Iterable, Type
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
|
||||
|
||||
class LongTermMemory(Memory):
|
||||
"""
|
||||
The Long-term memory for Roles
|
||||
- recover memory when it staruped
|
||||
- update memory when it changed
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.memory_storage: MemoryStorage = MemoryStorage()
|
||||
super(LongTermMemory, self).__init__()
|
||||
self.rc = None # RoleContext
|
||||
self.msg_from_recover = False
|
||||
|
||||
def recover_memory(self, role_id: str, rc: "RoleContext"):
|
||||
messages = self.memory_storage.recover_memory(role_id)
|
||||
self.rc = rc
|
||||
if not self.memory_storage.is_initialized:
|
||||
logger.warning(f'It may the first time to run Agent {role_id}, the long-term memory is empty')
|
||||
else:
|
||||
logger.warning(f'Agent {role_id} has existed memory storage with {len(messages)} messages '
|
||||
f'and has recovered them.')
|
||||
self.msg_from_recover = True
|
||||
self.add_batch(messages)
|
||||
self.msg_from_recover = False
|
||||
|
||||
def add(self, message: Message):
|
||||
super(LongTermMemory, self).add(message)
|
||||
for action in self.rc.watch:
|
||||
if message.cause_by == action and not self.msg_from_recover:
|
||||
# currently, only add role's watching messages to its memory_storage
|
||||
# and ignore adding messages from recover repeatedly
|
||||
self.memory_storage.add(message)
|
||||
|
||||
def remember(self, observed: list[Message], k=10) -> list[Message]:
|
||||
"""
|
||||
remember the most similar k memories from observed Messages, return all when k=0
|
||||
1. remember the short-term memory(stm) news
|
||||
2. integrate the stm news with ltm(long-term memory) news
|
||||
"""
|
||||
stm_news = super(LongTermMemory, self).remember(observed) # shot-term memory news
|
||||
if not self.memory_storage.is_initialized:
|
||||
# memory_storage hasn't initialized, use default `remember` to get stm_news
|
||||
return stm_news
|
||||
|
||||
ltm_news: list[Message] = []
|
||||
for mem in stm_news:
|
||||
# integrate stm & ltm
|
||||
mem_searched = self.memory_storage.search(mem)
|
||||
if len(mem_searched) > 0:
|
||||
ltm_news.append(mem)
|
||||
return ltm_news[-k:]
|
||||
|
||||
def delete(self, message: Message):
|
||||
super(LongTermMemory, self).delete(message)
|
||||
# TODO delete message in memory_storage
|
||||
|
||||
def clear(self):
|
||||
super(LongTermMemory, self).clear()
|
||||
self.memory_storage.clean()
|
||||
|
|
@ -63,6 +63,16 @@ class Memory:
|
|||
"""Return the most recent k memories, return all when k=0"""
|
||||
return self.storage[-k:]
|
||||
|
||||
def remember(self, observed: list[Message], k=10) -> list[Message]:
|
||||
"""remember the most recent k memories from observed Messages, return all when k=0"""
|
||||
already_observed = self.get(k)
|
||||
news: list[Message] = []
|
||||
for i in observed:
|
||||
if i in already_observed:
|
||||
continue
|
||||
news.append(i)
|
||||
return news
|
||||
|
||||
def get_by_action(self, action: Type[Action]) -> list[Message]:
|
||||
"""Return all messages triggered by a specified Action"""
|
||||
return self.index[action]
|
||||
|
|
|
|||
106
metagpt/memory/memory_storage.py
Normal file
106
metagpt/memory/memory_storage.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the implement of memory storage
|
||||
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
|
||||
from metagpt.const import DATA_PATH, MEM_TTL
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.serialize import serialize_message, deserialize_message
|
||||
from metagpt.document_store.faiss_store import FaissStore
|
||||
|
||||
|
||||
class MemoryStorage(FaissStore):
|
||||
"""
|
||||
The memory storage with Faiss as ANN search engine
|
||||
"""
|
||||
|
||||
def __init__(self, mem_ttl: int = MEM_TTL):
|
||||
self.role_id: str = None
|
||||
self.role_mem_path: str = None
|
||||
self.mem_ttl: int = mem_ttl # later use
|
||||
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
|
||||
self._initialized: bool = False
|
||||
|
||||
self.store: FAISS = None # Faiss engine
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
def recover_memory(self, role_id: str) -> List[Message]:
|
||||
self.role_id = role_id
|
||||
self.role_mem_path = Path(DATA_PATH / f'role_mem/{self.role_id}/')
|
||||
self.role_mem_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.store = self._load()
|
||||
messages = []
|
||||
if not self.store:
|
||||
# TODO init `self.store` under here with raw faiss api instead under `add`
|
||||
pass
|
||||
else:
|
||||
for _id, document in self.store.docstore._dict.items():
|
||||
messages.append(deserialize_message(document.metadata.get("message_ser")))
|
||||
self._initialized = True
|
||||
|
||||
return messages
|
||||
|
||||
def _get_index_and_store_fname(self):
|
||||
if not self.role_mem_path:
|
||||
logger.error(f'You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory')
|
||||
return None, None
|
||||
index_fpath = Path(self.role_mem_path / f'{self.role_id}.index')
|
||||
storage_fpath = Path(self.role_mem_path / f'{self.role_id}.pkl')
|
||||
return index_fpath, storage_fpath
|
||||
|
||||
def persist(self):
|
||||
super(MemoryStorage, self).persist()
|
||||
logger.debug(f'Agent {self.role_id} persist memory into local')
|
||||
|
||||
def add(self, message: Message) -> bool:
|
||||
""" add message into memory storage"""
|
||||
docs = [message.content]
|
||||
metadatas = [{"message_ser": serialize_message(message)}]
|
||||
if not self.store:
|
||||
# init Faiss
|
||||
self.store = self._write(docs, metadatas)
|
||||
self._initialized = True
|
||||
else:
|
||||
self.store.add_texts(texts=docs, metadatas=metadatas)
|
||||
self.persist()
|
||||
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
|
||||
|
||||
def search(self, message: Message, k=4) -> List[Message]:
|
||||
"""search for dissimilar messages"""
|
||||
if not self.store:
|
||||
return []
|
||||
|
||||
resp = self.store.similarity_search_with_score(
|
||||
query=message.content,
|
||||
k=k
|
||||
)
|
||||
# filter the result which score is smaller than the threshold
|
||||
filtered_resp = []
|
||||
for item, score in resp:
|
||||
# the smaller score means more similar relation
|
||||
if score < self.threshold:
|
||||
continue
|
||||
# convert search result into Memory
|
||||
metadata = item.metadata
|
||||
new_mem = deserialize_message(metadata.get("message_ser"))
|
||||
filtered_resp.append(new_mem)
|
||||
return filtered_resp
|
||||
|
||||
def clean(self):
|
||||
index_fpath, storage_fpath = self._get_index_and_store_fname()
|
||||
if index_fpath and index_fpath.exists():
|
||||
index_fpath.unlink(missing_ok=True)
|
||||
if storage_fpath and storage_fpath.exists():
|
||||
storage_fpath.unlink(missing_ok=True)
|
||||
|
||||
self.store = None
|
||||
self._initialized = False
|
||||
|
|
@ -12,10 +12,11 @@ from typing import Iterable, Type
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
# from metagpt.environment import Environment
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.memory import Memory, LongTermMemory
|
||||
from metagpt.schema import Message
|
||||
|
||||
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """
|
||||
|
|
@ -65,6 +66,7 @@ class RoleContext(BaseModel):
|
|||
"""角色运行时上下文"""
|
||||
env: 'Environment' = Field(default=None)
|
||||
memory: Memory = Field(default_factory=Memory)
|
||||
long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory)
|
||||
state: int = Field(default=0)
|
||||
todo: Action = Field(default=None)
|
||||
watch: set[Type[Action]] = Field(default_factory=set)
|
||||
|
|
@ -72,6 +74,11 @@ class RoleContext(BaseModel):
|
|||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def check(self, role_id: str):
|
||||
if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory:
|
||||
self.long_term_memory.recover_memory(role_id, self)
|
||||
self.memory = self.long_term_memory # use memory to act as long_term_memory for unify operation
|
||||
|
||||
@property
|
||||
def important_memory(self) -> list[Message]:
|
||||
"""获得关注动作对应的信息"""
|
||||
|
|
@ -90,6 +97,7 @@ class Role:
|
|||
self._setting = RoleSetting(name=name, profile=profile, goal=goal, constraints=constraints, desc=desc)
|
||||
self._states = []
|
||||
self._actions = []
|
||||
self._role_id = str(self._setting)
|
||||
self._rc = RoleContext()
|
||||
|
||||
def _reset(self):
|
||||
|
|
@ -110,6 +118,8 @@ class Role:
|
|||
def _watch(self, actions: Iterable[Type[Action]]):
|
||||
"""监听对应的行为"""
|
||||
self._rc.watch.update(actions)
|
||||
# check RoleContext after adding watch actions
|
||||
self._rc.check(self._role_id)
|
||||
|
||||
def _set_state(self, state):
|
||||
"""Update the current state."""
|
||||
|
|
@ -174,13 +184,7 @@ class Role:
|
|||
|
||||
observed = self._rc.env.memory.get_by_actions(self._rc.watch)
|
||||
|
||||
already_observed = self._rc.memory.get()
|
||||
|
||||
news: list[Message] = []
|
||||
for i in observed:
|
||||
if i in already_observed:
|
||||
continue
|
||||
news.append(i)
|
||||
news = self._rc.memory.remember(observed) # remember recent exact or similar memories
|
||||
|
||||
for i in env_msgs:
|
||||
self.recv(i)
|
||||
|
|
|
|||
75
metagpt/utils/serialize.py
Normal file
75
metagpt/utils/serialize.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the implement of serialization and deserialization
|
||||
|
||||
import copy
|
||||
from typing import Tuple, List, Type, Union, Dict
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from pydantic import create_model
|
||||
|
||||
from metagpt.schema import Message
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
|
||||
|
||||
def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
|
||||
"""
|
||||
directly traverse the `properties` in the first level.
|
||||
schema structure likes
|
||||
```
|
||||
{
|
||||
"title":"prd",
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"Original Requirements":{
|
||||
"title":"Original Requirements",
|
||||
"type":"string"
|
||||
},
|
||||
},
|
||||
"required":[
|
||||
"Original Requirements",
|
||||
]
|
||||
}
|
||||
```
|
||||
"""
|
||||
mapping = dict()
|
||||
for field, property in schema['properties'].items():
|
||||
if property['type'] == 'string':
|
||||
mapping[field] = (str, ...)
|
||||
elif property['type'] == 'array' and property['items']['type'] == 'string':
|
||||
mapping[field] = (List[str], ...)
|
||||
elif property['type'] == 'array' and property['items']['type'] == 'array':
|
||||
# here only consider the `Tuple[str, str]` situation
|
||||
mapping[field] = (List[Tuple[str, str]], ...)
|
||||
return mapping
|
||||
|
||||
|
||||
def serialize_message(message: Message):
|
||||
message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference
|
||||
ic = message_cp.instruct_content
|
||||
if ic:
|
||||
# model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly
|
||||
schema = ic.schema()
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
|
||||
message_cp.instruct_content = {
|
||||
'class': schema['title'],
|
||||
'mapping': mapping,
|
||||
'value': ic.dict()
|
||||
}
|
||||
msg_ser = pickle.dumps(message_cp)
|
||||
|
||||
return msg_ser
|
||||
|
||||
|
||||
def deserialize_message(message_ser: str) -> Message:
|
||||
message = pickle.loads(message_ser)
|
||||
if message.instruct_content:
|
||||
ic = message.instruct_content
|
||||
ic_obj = ActionOutput.create_model_class(class_name=ic['class'],
|
||||
mapping=ic['mapping'])
|
||||
ic_new = ic_obj(**ic['value'])
|
||||
message.instruct_content = ic_new
|
||||
|
||||
return message
|
||||
|
|
@ -33,3 +33,5 @@ tqdm==4.64.0
|
|||
# selenium>4
|
||||
# webdriver_manager<3.9
|
||||
anthropic==0.3.6
|
||||
typing-inspect==0.8.0
|
||||
typing_extensions==4.5.0
|
||||
|
|
|
|||
3
tests/metagpt/memory/__init__.py
Normal file
3
tests/metagpt/memory/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
55
tests/metagpt/memory/test_longterm_memory.py
Normal file
55
tests/metagpt/memory/test_longterm_memory.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of `metagpt/memory/longterm_memory.py`
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.schema import Message
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.memory import LongTermMemory
|
||||
|
||||
|
||||
def test_ltm_search():
|
||||
assert hasattr(CONFIG, "long_term_memory") is True
|
||||
openai_api_key = CONFIG.openai_api_key
|
||||
assert len(openai_api_key) > 20
|
||||
|
||||
role_id = 'UTUserLtm(Product Manager)'
|
||||
rc = RoleContext(watch=[BossRequirement])
|
||||
ltm = LongTermMemory()
|
||||
ltm.recover_memory(role_id, rc)
|
||||
|
||||
idea = 'Write a cli snake game'
|
||||
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
|
||||
news = ltm.remember([message])
|
||||
assert len(news) == 1
|
||||
ltm.add(message)
|
||||
|
||||
sim_idea = 'Write a game of cli snake'
|
||||
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
|
||||
news = ltm.remember([sim_message])
|
||||
assert len(news) == 0
|
||||
ltm.add(sim_message)
|
||||
|
||||
new_idea = 'Write a 2048 web game'
|
||||
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
|
||||
news = ltm.remember([new_message])
|
||||
assert len(news) == 1
|
||||
ltm.add(new_message)
|
||||
|
||||
# restore from local index
|
||||
ltm_new = LongTermMemory()
|
||||
ltm_new.recover_memory(role_id, rc)
|
||||
news = ltm_new.remember([message])
|
||||
assert len(news) == 0
|
||||
|
||||
ltm_new.recover_memory(role_id, rc)
|
||||
news = ltm_new.remember([sim_message])
|
||||
assert len(news) == 0
|
||||
|
||||
new_idea = 'Write a Battle City'
|
||||
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
|
||||
news = ltm_new.remember([new_message])
|
||||
assert len(news) == 1
|
||||
|
||||
ltm_new.clear()
|
||||
82
tests/metagpt/memory/test_memory_storage.py
Normal file
82
tests/metagpt/memory/test_memory_storage.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittests of metagpt/memory/memory_storage.py
|
||||
|
||||
from typing import List
|
||||
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
|
||||
|
||||
def test_idea_message():
|
||||
idea = 'Write a cli snake game'
|
||||
role_id = 'UTUser1(Product Manager)'
|
||||
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
|
||||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_idea = 'Write a game of cli snake'
|
||||
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
|
||||
new_messages = memory_storage.search(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_idea = 'Write a 2048 web game'
|
||||
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
|
||||
new_messages = memory_storage.search(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
memory_storage.clean()
|
||||
assert memory_storage.is_initialized is False
|
||||
|
||||
|
||||
def test_actionout_message():
|
||||
out_mapping = {
|
||||
'field1': (str, ...),
|
||||
'field2': (List[str], ...)
|
||||
}
|
||||
out_data = {
|
||||
'field1': 'field1 value',
|
||||
'field2': ['field2 value1', 'field2 value2']
|
||||
}
|
||||
ic_obj = ActionOutput.create_model_class('prd', out_mapping)
|
||||
|
||||
role_id = 'UTUser2(Architect)'
|
||||
content = 'The boss has requested the creation of a command-line interface (CLI) snake game'
|
||||
message = Message(content=content,
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD) # WritePRD as test action
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
|
||||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_conent = 'The request is command-line interface (CLI) snake game'
|
||||
sim_message = Message(content=sim_conent,
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD)
|
||||
new_messages = memory_storage.search(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_conent = 'Incorporate basic features of a snake game such as scoring and increasing difficulty'
|
||||
new_message = Message(content=new_conent,
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD)
|
||||
new_messages = memory_storage.search(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
memory_storage.clean()
|
||||
assert memory_storage.is_initialized is False
|
||||
96
tests/metagpt/utils/test_serialize.py
Normal file
96
tests/metagpt/utils/test_serialize.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of serialize
|
||||
|
||||
from typing import List, Tuple
|
||||
import pytest
|
||||
|
||||
from pydantic import create_model
|
||||
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.serialize import actionoutout_schema_to_mapping, serialize_message, deserialize_message
|
||||
|
||||
|
||||
def test_actionoutout_schema_to_mapping():
|
||||
schema = {
|
||||
'title': 'test',
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'field': {
|
||||
'title': 'field',
|
||||
'type': 'string'
|
||||
}
|
||||
}
|
||||
}
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
assert mapping['field'] == (str, ...)
|
||||
|
||||
schema = {
|
||||
'title': 'test',
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'field': {
|
||||
'title': 'field',
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'type': 'string'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
assert mapping['field'] == (List[str], ...)
|
||||
|
||||
schema = {
|
||||
'title': 'test',
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'field': {
|
||||
'title': 'field',
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'type': 'array',
|
||||
'minItems': 2,
|
||||
'maxItems': 2,
|
||||
'items': [
|
||||
{
|
||||
'type': 'string'
|
||||
},
|
||||
{
|
||||
'type': 'string'
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
assert mapping['field'] == (List[Tuple[str, str]], ...)
|
||||
|
||||
assert True, True
|
||||
|
||||
|
||||
def test_serialize_and_deserialize_message():
|
||||
out_mapping = {
|
||||
'field1': (str, ...),
|
||||
'field2': (List[str], ...)
|
||||
}
|
||||
out_data = {
|
||||
'field1': 'field1 value',
|
||||
'field2': ['field2 value1', 'field2 value2']
|
||||
}
|
||||
ic_obj = ActionOutput.create_model_class('prd', out_mapping)
|
||||
|
||||
message = Message(content='prd demand',
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD) # WritePRD as test action
|
||||
|
||||
message_ser = serialize_message(message)
|
||||
|
||||
new_message = deserialize_message(message_ser)
|
||||
assert new_message.content == message.content
|
||||
assert new_message.cause_by == message.cause_by
|
||||
assert new_message.instruct_content.field1 == out_data['field1']
|
||||
Loading…
Add table
Add a link
Reference in a new issue