This commit is contained in:
geekan 2024-01-10 22:02:44 +08:00
parent 8af1488613
commit 4de8fa3682
15 changed files with 33 additions and 24 deletions

Binary file not shown.

View file

@ -63,7 +63,7 @@ class DebugError(Action):
logger.info(f"Debug and rewrite {self.i_context.test_filename}")
code_doc = await self.file_repo.get_file(
filename=self.i_context.code_filename, relative_path=self.i_context.src_workspace
filename=self.i_context.code_filename, relative_path=self.context.src_workspace
)
if not code_doc:
return ""

View file

@ -178,7 +178,7 @@ class WebBrowseAndSummarize(Action):
i_context: Optional[str] = None
desc: str = "Explore the web and provide summaries of articles and webpages."
browse_func: Union[Callable[[list[str]], None], None] = None
web_browser_engine: Optional[WebBrowserEngine] = None
web_browser_engine: Optional[WebBrowserEngine] = WebBrowserEngineType.PLAYWRIGHT
def __init__(self, **kwargs):
super().__init__(**kwargs)

View file

@ -45,7 +45,7 @@ class TalkAction(Action):
language = self.language
prompt += (
f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.\n "
f"{self.context}"
f"{self.i_context}"
)
logger.debug(f"PROMPT: {prompt}")
return prompt
@ -57,7 +57,7 @@ class TalkAction(Action):
"{history}": self.history_summary or "",
"{knowledge}": self.knowledge or "",
"{language}": self.language,
"{ask}": self.context,
"{ask}": self.i_context,
}
prompt = TalkActionPrompt.FORMATION_LOOSE
for k, v in kvs.items():
@ -88,7 +88,7 @@ class TalkAction(Action):
format_msgs.append({"role": "assistant", "content": self.knowledge})
if self.history_summary:
format_msgs.append({"role": "assistant", "content": self.history_summary})
return self.context, format_msgs, system_msgs
return self.i_context, format_msgs, system_msgs
async def run(self, with_message=None, **kwargs) -> Message:
msg, format_msgs, system_msgs = self.aask_args

View file

@ -43,7 +43,7 @@ class ProductManager(Role):
self._set_state(1)
else:
self._set_state(0)
self.context.config.git_reinit = False
self.config.git_reinit = False
self.todo_action = any_to_name(WritePRD)
return bool(self.rc.todo)

View file

@ -17,7 +17,6 @@
from metagpt.actions import DebugError, RunCode, WriteTest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.config2 import Config
from metagpt.const import (
MESSAGE_ROUTE_TO_NONE,
TEST_CODES_FILE_REPO,
@ -48,10 +47,6 @@ class QaEngineer(Role):
self._watch([SummarizeCode, WriteTest, RunCode, DebugError])
self.test_round = 0
@property
def config(self) -> Config:
return self.context.config
async def _write_test(self, message: Message) -> None:
src_file_repo = self.context.git_repo.new_file_repository(self.context.src_workspace)
changed_files = set(src_file_repo.changed_files.keys())

View file

@ -42,7 +42,7 @@ class SearchEngine:
def __init__(
self,
engine: Optional[SearchEngineType] = None,
engine: Optional[SearchEngineType] = SearchEngineType.SERPER_GOOGLE,
run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None,
):
if engine == SearchEngineType.SERPAPI_GOOGLE:

View file

@ -4,6 +4,7 @@
import json
from pathlib import Path
from metagpt.config2 import config
from metagpt.provider.openai_api import OpenAILLM as GPTAPI
from metagpt.utils.common import awrite
@ -281,6 +282,6 @@ class UTGenerator:
"""Choose based on different calling methods"""
result = ""
if self.chatgpt_method == "API":
result = await GPTAPI().aask_code(messages=messages)
result = await GPTAPI(config.get_llm_config()).aask_code(messages=messages)
return result

View file

@ -15,7 +15,7 @@ from metagpt.utils.parse_html import WebPage
class WebBrowserEngine:
def __init__(
self,
engine: WebBrowserEngineType | None = None,
engine: WebBrowserEngineType | None = WebBrowserEngineType.PLAYWRIGHT,
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
):
if engine is None:

View file

@ -146,7 +146,8 @@ def setup_and_teardown_git_repo(request):
# Destroy git repo at the end of the test session.
def fin():
CONTEXT.git_repo.delete_repository()
if CONTEXT.git_repo:
CONTEXT.git_repo.delete_repository()
# Register the function for destroying the environment.
request.addfinalizer(fin)

File diff suppressed because one or more lines are too long

View file

@ -14,7 +14,6 @@ from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.context import CONTEXT
from metagpt.llm import LLM
from metagpt.utils.common import aread
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import ChangeType
@ -23,7 +22,8 @@ async def test_rebuild():
# Mock
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.json")
graph_db_filename = Path(CONTEXT.git_repo.workdir.name).with_suffix(".json")
await FileRepository.save_file(
repo = CONTEXT.file_repo
await repo.save_file(
filename=str(graph_db_filename),
relative_path=GRAPH_REPO_FILE_REPO,
content=data,

View file

@ -62,7 +62,7 @@ async def test_react():
"goal": "Test",
"constraints": "constraints",
"desc": "desc",
"subscription": "start",
"address": "start",
}
]
@ -93,8 +93,8 @@ async def test_react():
await env.run()
assert role.is_idle
tag = uuid.uuid4().hex
role.subscribe({tag})
assert env.get_subscription(role) == {tag}
role.set_addresses({tag})
assert env.get_addresses(role) == {tag}
@pytest.mark.asyncio
@ -131,7 +131,7 @@ async def test_recover():
role.recovered = True
role.latest_observed_msg = Message(content="recover_test")
role.rc.state = 0
assert role.todo == any_to_name(MockAction)
assert role.first_action == any_to_name(MockAction)
rsp = await role.run()
assert rsp.cause_by == any_to_str(MockAction)

View file

@ -102,7 +102,7 @@ def test_message_serdeser():
new_message = Message.model_validate(message_dict)
assert new_message.content == message.content
assert new_message.instruct_content.model_dump() == message.instruct_content.model_dump()
assert new_message.instruct_content != message.instruct_content # TODO
assert new_message.instruct_content == message.instruct_content # TODO
assert new_message.cause_by == message.cause_by
assert new_message.instruct_content.field3 == out_data["field3"]

View file

@ -22,7 +22,7 @@ async def async_mock_from_url(*args, **kwargs):
@pytest.mark.asyncio
@mock.patch("aioredis.from_url", return_value=async_mock_from_url())
async def test_redis():
async def test_redis(i):
redis = Config.default().redis
conn = Redis(redis)