mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-20 15:38:09 +02:00
Merge pull request #939 from better629/feat_mock
Feat mock openai embed for document_store and memory UTs
This commit is contained in:
commit
b17cc3fa03
18 changed files with 134 additions and 70 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -6,6 +6,9 @@
|
|||
@File : test_faiss_store.py
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from metagpt.const import EXAMPLE_PATH
|
||||
|
|
@ -14,8 +17,17 @@ from metagpt.logs import logger
|
|||
from metagpt.roles import Sales
|
||||
|
||||
|
||||
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
|
||||
num = len(texts)
|
||||
embeds = np.random.randint(1, 100, size=(num, 1536)) # 1536: openai embedding dim
|
||||
embeds = (embeds - embeds.mean(axis=0)) / (embeds.std(axis=0))
|
||||
return embeds
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_json():
|
||||
async def test_search_json(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.json")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
|
|
@ -24,7 +36,9 @@ async def test_search_json():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_xlsx():
|
||||
async def test_search_xlsx(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
|
|
@ -33,7 +47,9 @@ async def test_search_xlsx():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write():
|
||||
async def test_write(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
|
||||
_faiss_store = store.write()
|
||||
assert _faiss_store.docstore
|
||||
|
|
|
|||
33
tests/metagpt/memory/mock_text_embed.py
Normal file
33
tests/metagpt/memory/mock_text_embed.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
dim = 1536 # openai embedding dim
|
||||
|
||||
text_embed_arr = [
|
||||
{"text": "Write a cli snake game", "embed": np.zeros(shape=[1, dim])}, # mock data, same as below
|
||||
{"text": "Write a game of cli snake", "embed": np.zeros(shape=[1, dim])},
|
||||
{"text": "Write a 2048 web game", "embed": np.ones(shape=[1, dim])},
|
||||
{"text": "Write a Battle City", "embed": np.ones(shape=[1, dim])},
|
||||
{
|
||||
"text": "The user has requested the creation of a command-line interface (CLI) snake game",
|
||||
"embed": np.zeros(shape=[1, dim]),
|
||||
},
|
||||
{"text": "The request is command-line interface (CLI) snake game", "embed": np.zeros(shape=[1, dim])},
|
||||
{
|
||||
"text": "Incorporate basic features of a snake game such as scoring and increasing difficulty",
|
||||
"embed": np.ones(shape=[1, dim]),
|
||||
},
|
||||
]
|
||||
|
||||
text_idx_dict = {item["text"]: idx for idx, item in enumerate(text_embed_arr)}
|
||||
|
||||
|
||||
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
|
||||
idx = text_idx_dict.get(texts[0])
|
||||
embed = text_embed_arr[idx].get("embed")
|
||||
return embed
|
||||
|
|
@ -4,20 +4,22 @@
|
|||
@Desc : unittest of `metagpt/memory/longterm_memory.py`
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.config2 import config
|
||||
from metagpt.memory.longterm_memory import LongTermMemory
|
||||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.schema import Message
|
||||
|
||||
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
|
||||
from tests.metagpt.memory.mock_text_embed import (
|
||||
mock_openai_embed_documents,
|
||||
text_embed_arr,
|
||||
)
|
||||
|
||||
|
||||
def test_ltm_search():
|
||||
def test_ltm_search(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
from metagpt.environment import Environment
|
||||
|
||||
|
|
@ -27,20 +29,20 @@ def test_ltm_search():
|
|||
ltm = LongTermMemory()
|
||||
ltm.recover_memory(role_id, rc)
|
||||
|
||||
idea = "Write a cli snake game"
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([message])
|
||||
assert len(news) == 1
|
||||
ltm.add(message)
|
||||
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
ltm.add(sim_message)
|
||||
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
|
|
@ -56,7 +58,7 @@ def test_ltm_search():
|
|||
news = ltm_new.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
|
||||
new_idea = "Write a Battle City"
|
||||
new_idea = text_embed_arr[3].get("text", "Write a Battle City")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm_new.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
|
|
|
|||
|
|
@ -4,23 +4,25 @@
|
|||
@Desc : the unittests of metagpt/memory/memory_storage.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
|
||||
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
|
||||
from tests.metagpt.memory.mock_text_embed import (
|
||||
mock_openai_embed_documents,
|
||||
text_embed_arr,
|
||||
)
|
||||
|
||||
|
||||
def test_idea_message():
|
||||
idea = "Write a cli snake game"
|
||||
def test_idea_message(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
role_id = "UTUser1(Product Manager)"
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
|
||||
|
|
@ -33,12 +35,12 @@ def test_idea_message():
|
|||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
|
@ -47,13 +49,17 @@ def test_idea_message():
|
|||
assert memory_storage.is_initialized is False
|
||||
|
||||
|
||||
def test_actionout_message():
|
||||
def test_actionout_message(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("prd", out_mapping)
|
||||
|
||||
role_id = "UTUser2(Architect)"
|
||||
content = "The user has requested the creation of a command-line interface (CLI) snake game"
|
||||
content = text_embed_arr[4].get(
|
||||
"text", "The user has requested the creation of a command-line interface (CLI) snake game"
|
||||
)
|
||||
message = Message(
|
||||
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
) # WritePRD as test action
|
||||
|
|
@ -67,12 +73,14 @@ def test_actionout_message():
|
|||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_conent = "The request is command-line interface (CLI) snake game"
|
||||
sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game")
|
||||
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
new_conent = text_embed_arr[6].get(
|
||||
"text", "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
)
|
||||
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
|
|
|||
|
|
@ -142,6 +142,9 @@ def check_or_create_base_tag(project_path):
|
|||
# Initialize a Git repository
|
||||
subprocess.run(["git", "init"], check=True)
|
||||
|
||||
# Check if the .gitignore exists. If it doesn't exist, create .gitignore and add the comment
|
||||
subprocess.run(f"echo # Ignore these files or directories > {'.gitignore'}", shell=True)
|
||||
|
||||
# Check if the 'base' tag exists
|
||||
check_base_tag_cmd = ["git", "show-ref", "--verify", "--quiet", "refs/tags/base"]
|
||||
if subprocess.run(check_base_tag_cmd).returncode == 0:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue