mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
feat: merge geekan:main
This commit is contained in:
commit
c6b9e234bf
24 changed files with 326 additions and 90 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -1,7 +1,7 @@
|
|||
### Python template
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
__pycache__
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,5 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/01/15
|
||||
@Author : mannaandpoem
|
||||
@File : imitate_webpage.py
|
||||
"""
|
||||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ async def main():
|
|||
question = "What are the most interesting human facts?"
|
||||
|
||||
search = Config.default().search
|
||||
kwargs = {"api_key": search.api_key, "cse_id": search.cse_id, "proxy": None}
|
||||
kwargs = search.model_dump()
|
||||
await Searcher(search_engine=SearchEngine(engine=search.api_type, **kwargs)).run(question)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -331,7 +331,7 @@ class ActionNode:
|
|||
|
||||
def compile_to(self, i: Dict, schema, kv_sep) -> str:
|
||||
if schema == "json":
|
||||
return json.dumps(i, indent=4)
|
||||
return json.dumps(i, indent=4, ensure_ascii=False)
|
||||
elif schema == "markdown":
|
||||
return dict_to_markdown(i, kv_sep=kv_sep)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from metagpt.prompts.di.write_analysis_code import (
|
|||
STRUCTUAL_PROMPT,
|
||||
)
|
||||
from metagpt.schema import Message, Plan
|
||||
from metagpt.utils.common import CodeParser, process_message, remove_comments
|
||||
from metagpt.utils.common import CodeParser, remove_comments
|
||||
|
||||
|
||||
class WriteAnalysisCode(Action):
|
||||
|
|
@ -50,7 +50,7 @@ class WriteAnalysisCode(Action):
|
|||
)
|
||||
|
||||
working_memory = working_memory or []
|
||||
context = process_message([Message(content=structual_prompt, role="user")] + working_memory)
|
||||
context = self.llm.format_msg([Message(content=structual_prompt, role="user")] + working_memory)
|
||||
|
||||
# LLM call
|
||||
if use_reflection:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@
|
|||
"""
|
||||
from typing import Callable, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
|
@ -18,3 +20,11 @@ class SearchConfig(YamlModel):
|
|||
api_key: str = ""
|
||||
cse_id: str = "" # for google
|
||||
search_func: Optional[Callable] = None
|
||||
params: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"engine": "google",
|
||||
"google_domain": "google.com",
|
||||
"gl": "us",
|
||||
"hl": "en",
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,11 +9,11 @@
|
|||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import aiofiles
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.context import Context
|
||||
from metagpt.utils.common import aread
|
||||
|
||||
|
||||
class Example(BaseModel):
|
||||
|
|
@ -68,8 +68,7 @@ class SkillsDeclaration(BaseModel):
|
|||
async def load(skill_yaml_file_name: Path = None) -> "SkillsDeclaration":
|
||||
if not skill_yaml_file_name:
|
||||
skill_yaml_file_name = Path(__file__).parent.parent.parent / "docs/.well-known/skills.yaml"
|
||||
async with aiofiles.open(str(skill_yaml_file_name), mode="r") as reader:
|
||||
data = await reader.read(-1)
|
||||
data = await aread(filename=skill_yaml_file_name)
|
||||
skill_data = yaml.safe_load(data)
|
||||
return SkillsDeclaration(**skill_data)
|
||||
|
||||
|
|
|
|||
|
|
@ -74,6 +74,28 @@ class BaseLLM(ABC):
|
|||
def _system_msg(self, msg: str) -> dict[str, str]:
|
||||
return {"role": "system", "content": msg}
|
||||
|
||||
def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
from metagpt.schema import Message
|
||||
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
||||
processed_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, str):
|
||||
processed_messages.append({"role": "user", "content": msg})
|
||||
elif isinstance(msg, dict):
|
||||
assert set(msg.keys()) == set(["role", "content"])
|
||||
processed_messages.append(msg)
|
||||
elif isinstance(msg, Message):
|
||||
processed_messages.append(msg.to_dict())
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!"
|
||||
)
|
||||
return processed_messages
|
||||
|
||||
def _system_msgs(self, msgs: list[str]) -> list[dict[str, str]]:
|
||||
return [self._system_msg(msg) for msg in msgs]
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart
|
||||
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import google.generativeai as genai
|
||||
|
|
@ -16,9 +17,10 @@ from google.generativeai.types.generation_types import (
|
|||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class GeminiGenerativeModel(GenerativeModel):
|
||||
|
|
@ -52,6 +54,10 @@ class GeminiLLM(BaseLLM):
|
|||
self.llm = GeminiGenerativeModel(model_name=self.model)
|
||||
|
||||
def __init_gemini(self, config: LLMConfig):
|
||||
if config.proxy:
|
||||
logger.info(f"Use proxy: {config.proxy}")
|
||||
os.environ["HTTP_PROXY"] = config.proxy
|
||||
os.environ["HTTP_PROXYS"] = config.proxy
|
||||
genai.configure(api_key=config.api_key)
|
||||
|
||||
def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, str]:
|
||||
|
|
@ -62,6 +68,35 @@ class GeminiLLM(BaseLLM):
|
|||
def _assistant_msg(self, msg: str) -> dict[str, str]:
|
||||
return {"role": "model", "parts": [msg]}
|
||||
|
||||
def _system_msg(self, msg: str) -> dict[str, str]:
|
||||
return {"role": "user", "parts": [msg]}
|
||||
|
||||
def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
from metagpt.schema import Message
|
||||
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
||||
# REF: https://ai.google.dev/tutorials/python_quickstart
|
||||
# As a dictionary, the message requires `role` and `parts` keys.
|
||||
# The role in a conversation can either be the `user`, which provides the prompts,
|
||||
# or `model`, which provides the responses.
|
||||
processed_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, str):
|
||||
processed_messages.append({"role": "user", "parts": [msg]})
|
||||
elif isinstance(msg, dict):
|
||||
assert set(msg.keys()) == set(["role", "parts"])
|
||||
processed_messages.append(msg)
|
||||
elif isinstance(msg, Message):
|
||||
processed_messages.append({"role": "user" if msg.role == "user" else "model", "parts": [msg.content]})
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!"
|
||||
)
|
||||
return processed_messages
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
|
||||
return kwargs
|
||||
|
|
|
|||
|
|
@ -30,12 +30,7 @@ from metagpt.logs import log_llm_stream, logger
|
|||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.utils.common import (
|
||||
CodeParser,
|
||||
decode_image,
|
||||
log_and_reraise,
|
||||
process_message,
|
||||
)
|
||||
from metagpt.utils.common import CodeParser, decode_image, log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.token_counter import (
|
||||
|
|
@ -151,7 +146,7 @@ class OpenAILLM(BaseLLM):
|
|||
async def _achat_completion_function(
|
||||
self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, **chat_configs
|
||||
) -> ChatCompletion:
|
||||
messages = process_message(messages)
|
||||
messages = self.format_msg(messages)
|
||||
kwargs = self._cons_kwargs(messages=messages, timeout=self.get_timeout(timeout), **chat_configs)
|
||||
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
|
||||
self._update_costs(rsp.usage)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from typing import Any, Callable, List, Literal, Tuple, Union
|
|||
from urllib.parse import quote, unquote
|
||||
|
||||
import aiofiles
|
||||
import chardet
|
||||
import loguru
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
|
@ -663,14 +664,21 @@ def role_raise_decorator(func):
|
|||
|
||||
|
||||
@handle_exception
|
||||
async def aread(filename: str | Path, encoding=None) -> str:
|
||||
async def aread(filename: str | Path, encoding="utf-8") -> str:
|
||||
"""Read file asynchronously."""
|
||||
async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
|
||||
content = await reader.read()
|
||||
try:
|
||||
async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
|
||||
content = await reader.read()
|
||||
except UnicodeDecodeError:
|
||||
async with aiofiles.open(str(filename), mode="rb") as reader:
|
||||
raw = await reader.read()
|
||||
result = chardet.detect(raw)
|
||||
detected_encoding = result["encoding"]
|
||||
content = raw.decode(detected_encoding)
|
||||
return content
|
||||
|
||||
|
||||
async def awrite(filename: str | Path, data: str, encoding=None):
|
||||
async def awrite(filename: str | Path, data: str, encoding="utf-8"):
|
||||
"""Write file asynchronously."""
|
||||
pathname = Path(filename)
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -802,29 +810,6 @@ def decode_image(img_url_or_b64: str) -> Image:
|
|||
return img
|
||||
|
||||
|
||||
def process_message(messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
from metagpt.schema import Message
|
||||
|
||||
# 全部转成list
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
||||
# 转成list[dict]
|
||||
processed_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, str):
|
||||
processed_messages.append({"role": "user", "content": msg})
|
||||
elif isinstance(msg, dict):
|
||||
assert set(msg.keys()) == set(["role", "content"])
|
||||
processed_messages.append(msg)
|
||||
elif isinstance(msg, Message):
|
||||
processed_messages.append(msg.to_dict())
|
||||
else:
|
||||
raise ValueError(f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!")
|
||||
return processed_messages
|
||||
|
||||
|
||||
def log_and_reraise(retry_state: RetryCallState):
|
||||
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
|
||||
logger.warning(
|
||||
|
|
|
|||
|
|
@ -13,9 +13,7 @@ import re
|
|||
from pathlib import Path
|
||||
from typing import Set
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.common import aread, awrite
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
|
|
@ -45,8 +43,7 @@ class DependencyFile:
|
|||
async def save(self):
|
||||
"""Save dependencies to the file asynchronously."""
|
||||
data = json.dumps(self._dependencies)
|
||||
async with aiofiles.open(str(self._filename), mode="w") as writer:
|
||||
await writer.write(data)
|
||||
await awrite(filename=self._filename, data=data)
|
||||
|
||||
async def update(self, filename: Path | str, dependencies: Set[Path | str], persist=True):
|
||||
"""Update dependencies for a file asynchronously.
|
||||
|
|
|
|||
|
|
@ -14,11 +14,9 @@ from datetime import datetime
|
|||
from pathlib import Path
|
||||
from typing import Dict, List, Set
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Document
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.common import aread, awrite
|
||||
from metagpt.utils.json_to_markdown import json_to_markdown
|
||||
|
||||
|
||||
|
|
@ -55,8 +53,7 @@ class FileRepository:
|
|||
pathname = self.workdir / filename
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
content = content if content else "" # avoid `argument must be str, not None` to make it continue
|
||||
async with aiofiles.open(str(pathname), mode="w") as writer:
|
||||
await writer.write(content)
|
||||
await awrite(filename=str(pathname), data=content)
|
||||
logger.info(f"save to: {str(pathname)}")
|
||||
|
||||
if dependencies is not None:
|
||||
|
|
|
|||
|
|
@ -9,11 +9,9 @@ import asyncio
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import check_cmd_exists
|
||||
from metagpt.utils.common import awrite, check_cmd_exists
|
||||
|
||||
|
||||
async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
|
||||
|
|
@ -30,9 +28,7 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt
|
|||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name)
|
||||
tmp = Path(f"{output_file_without_suffix}.mmd")
|
||||
async with aiofiles.open(tmp, "w", encoding="utf-8") as f:
|
||||
await f.write(mermaid_code)
|
||||
# tmp.write_text(mermaid_code, encoding="utf-8")
|
||||
await awrite(filename=tmp, data=mermaid_code)
|
||||
|
||||
if engine == "nodejs":
|
||||
if check_cmd_exists(config.mermaid.path) != 0:
|
||||
|
|
|
|||
|
|
@ -340,7 +340,9 @@ def extract_state_value_from_output(content: str) -> str:
|
|||
content (str): llm's output from `Role._think`
|
||||
"""
|
||||
content = content.strip() # deal the output cases like " 0", "0\n" and so on.
|
||||
pattern = r"([0-9])" # TODO find the number using a more proper method not just extract from content using pattern
|
||||
pattern = (
|
||||
r"(?<!-)[0-9]" # TODO find the number using a more proper method not just extract from content using pattern
|
||||
)
|
||||
matches = re.findall(pattern, content, re.DOTALL)
|
||||
matches = list(set(matches))
|
||||
state = matches[0] if len(matches) > 0 else "-1"
|
||||
|
|
|
|||
140
metagpt/utils/tree.py
Normal file
140
metagpt/utils/tree.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/3/11
|
||||
@Author : mashenquan
|
||||
@File : tree.py
|
||||
@Desc : Implement the same functionality as the `tree` command.
|
||||
Example:
|
||||
>>> print_tree(".")
|
||||
utils
|
||||
+-- serialize.py
|
||||
+-- project_repo.py
|
||||
+-- tree.py
|
||||
+-- mmdc_playwright.py
|
||||
+-- cost_manager.py
|
||||
+-- __pycache__
|
||||
| +-- __init__.cpython-39.pyc
|
||||
| +-- redis.cpython-39.pyc
|
||||
| +-- singleton.cpython-39.pyc
|
||||
| +-- embedding.cpython-39.pyc
|
||||
| +-- make_sk_kernel.cpython-39.pyc
|
||||
| +-- file_repository.cpython-39.pyc
|
||||
+-- file.py
|
||||
+-- save_code.py
|
||||
+-- common.py
|
||||
+-- redis.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from gitignore_parser import parse_gitignore
|
||||
|
||||
|
||||
def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str:
|
||||
"""
|
||||
Recursively traverses the directory structure and prints it out in a tree-like format.
|
||||
|
||||
Args:
|
||||
root (str or Path): The root directory from which to start traversing.
|
||||
gitignore (str or Path): The filename of gitignore file.
|
||||
run_command (bool): Whether to execute `tree` command. Execute the `tree` command and return the result if True,
|
||||
otherwise execute python code instead.
|
||||
|
||||
Returns:
|
||||
str: A string representation of the directory tree.
|
||||
|
||||
Example:
|
||||
>>> tree(".")
|
||||
utils
|
||||
+-- serialize.py
|
||||
+-- project_repo.py
|
||||
+-- tree.py
|
||||
+-- mmdc_playwright.py
|
||||
+-- __pycache__
|
||||
| +-- __init__.cpython-39.pyc
|
||||
| +-- redis.cpython-39.pyc
|
||||
| +-- singleton.cpython-39.pyc
|
||||
+-- parse_docstring.py
|
||||
|
||||
>>> tree(".", gitignore="../../.gitignore")
|
||||
utils
|
||||
+-- serialize.py
|
||||
+-- project_repo.py
|
||||
+-- tree.py
|
||||
+-- mmdc_playwright.py
|
||||
+-- parse_docstring.py
|
||||
|
||||
>>> tree(".", gitignore="../../.gitignore", run_command=True)
|
||||
utils
|
||||
├── serialize.py
|
||||
├── project_repo.py
|
||||
├── tree.py
|
||||
├── mmdc_playwright.py
|
||||
└── parse_docstring.py
|
||||
|
||||
|
||||
"""
|
||||
root = Path(root).resolve()
|
||||
if run_command:
|
||||
return _execute_tree(root, gitignore)
|
||||
|
||||
git_ignore_rules = parse_gitignore(gitignore) if gitignore else None
|
||||
dir_ = {root.name: _list_children(root=root, git_ignore_rules=git_ignore_rules)}
|
||||
v = _print_tree(dir_)
|
||||
return "\n".join(v)
|
||||
|
||||
|
||||
def _list_children(root: Path, git_ignore_rules: Callable) -> Dict[str, Dict]:
|
||||
dir_ = {}
|
||||
for i in root.iterdir():
|
||||
if git_ignore_rules and git_ignore_rules(str(i)):
|
||||
continue
|
||||
try:
|
||||
if i.is_file():
|
||||
dir_[i.name] = {}
|
||||
else:
|
||||
dir_[i.name] = _list_children(root=i, git_ignore_rules=git_ignore_rules)
|
||||
except (FileNotFoundError, PermissionError, OSError):
|
||||
dir_[i.name] = {}
|
||||
return dir_
|
||||
|
||||
|
||||
def _print_tree(dir_: Dict[str:Dict]) -> List[str]:
|
||||
ret = []
|
||||
for name, children in dir_.items():
|
||||
ret.append(name)
|
||||
if not children:
|
||||
continue
|
||||
lines = _print_tree(children)
|
||||
for j, v in enumerate(lines):
|
||||
if v[0] not in ["+", " ", "|"]:
|
||||
ret = _add_line(ret)
|
||||
row = f"+-- {v}"
|
||||
else:
|
||||
row = f" {v}"
|
||||
ret.append(row)
|
||||
return ret
|
||||
|
||||
|
||||
def _add_line(rows: List[str]) -> List[str]:
|
||||
for i in range(len(rows) - 1, -1, -1):
|
||||
v = rows[i]
|
||||
if v[0] != " ":
|
||||
return rows
|
||||
rows[i] = "|" + v[1:]
|
||||
return rows
|
||||
|
||||
|
||||
def _execute_tree(root: Path, gitignore: str | Path) -> str:
|
||||
args = ["--gitfile", str(gitignore)] if gitignore else []
|
||||
try:
|
||||
result = subprocess.run(["tree"] + args + [str(root)], capture_output=True, text=True, check=True)
|
||||
if result.returncode != 0:
|
||||
raise ValueError(f"tree exits with code {result.returncode}")
|
||||
return result.stdout
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise e
|
||||
2
setup.py
2
setup.py
|
|
@ -57,7 +57,7 @@ extras_require["dev"] = (["pylint~=3.0.3", "black~=23.3.0", "isort~=5.12.0", "pr
|
|||
|
||||
setup(
|
||||
name="metagpt",
|
||||
version="0.7.4",
|
||||
version="0.7.6",
|
||||
description="The Multi-Agent Framework",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@
|
|||
@File : test_tutorial_assistant.py
|
||||
"""
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.const import TUTORIAL_PATH
|
||||
from metagpt.roles.tutorial_assistant import TutorialAssistant
|
||||
from metagpt.utils.common import aread
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -20,9 +20,8 @@ async def test_tutorial_assistant(language: str, topic: str, context):
|
|||
msg = await role.run(topic)
|
||||
assert TUTORIAL_PATH.exists()
|
||||
filename = msg.content
|
||||
async with aiofiles.open(filename, mode="r", encoding="utf-8") as reader:
|
||||
content = await reader.read()
|
||||
assert "pip" in content
|
||||
content = await aread(filename=filename)
|
||||
assert "pip" in content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from typing import Callable
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.search_config import SearchConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
|
@ -53,14 +52,11 @@ async def test_search_engine(
|
|||
search_engine_config = {"engine": search_engine_type, "run_func": run_func}
|
||||
|
||||
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["api_key"] = "mock-serpapi-key"
|
||||
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["api_key"] = "mock-google-key"
|
||||
search_engine_config["cse_id"] = "mock-google-cse"
|
||||
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["api_key"] = "mock-serper-key"
|
||||
|
||||
async def test(search_engine):
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import uuid
|
|||
from pathlib import Path
|
||||
from typing import Any, Set
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -125,9 +124,7 @@ class TestGetProjectRoot:
|
|||
async def test_parse_data_exception(self, filename, want):
|
||||
pathname = Path(__file__).parent.parent.parent / "data/output_parser" / filename
|
||||
assert pathname.exists()
|
||||
async with aiofiles.open(str(pathname), mode="r") as reader:
|
||||
data = await reader.read()
|
||||
|
||||
data = await aread(filename=pathname)
|
||||
result = OutputParser.parse_data(data=data)
|
||||
assert want in result
|
||||
|
||||
|
|
@ -198,12 +195,25 @@ class TestGetProjectRoot:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_write(self):
|
||||
pathname = Path(__file__).parent / uuid.uuid4().hex / "test.tmp"
|
||||
pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.tmp"
|
||||
await awrite(pathname, "ABC")
|
||||
data = await aread(pathname)
|
||||
assert data == "ABC"
|
||||
pathname.unlink(missing_ok=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_write_error_charset(self):
|
||||
pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.txt"
|
||||
content = "中国abc123\u27f6"
|
||||
await awrite(filename=pathname, data=content)
|
||||
data = await aread(filename=pathname)
|
||||
assert data == content
|
||||
|
||||
content = "GB18030 是中国国家标准局发布的新一代中文字符集标准,是 GBK 的升级版,支持更广泛的字符范围。"
|
||||
await awrite(filename=pathname, data=content, encoding="gb2312")
|
||||
data = await aread(filename=pathname, encoding="utf-8")
|
||||
assert data == content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -10,15 +10,14 @@
|
|||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.common import awrite
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
||||
async def mock_file(filename, content=""):
|
||||
async with aiofiles.open(str(filename), mode="w") as file:
|
||||
await file.write(content)
|
||||
await awrite(filename=filename, data=content)
|
||||
|
||||
|
||||
async def mock_repo(local_path) -> (GitRepository, Path):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import uuid
|
|||
from pathlib import Path
|
||||
|
||||
import aioboto3
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import Config
|
||||
|
|
@ -46,7 +45,7 @@ async def test_s3(mocker):
|
|||
conn = S3(s3)
|
||||
object_name = "unittest.bak"
|
||||
await conn.upload_file(bucket=s3.bucket, local_path=__file__, object_name=object_name)
|
||||
pathname = (Path(__file__).parent / uuid.uuid4().hex).with_suffix(".bak")
|
||||
pathname = (Path(__file__).parent / "../../../workspace/unittest" / uuid.uuid4().hex).with_suffix(".bak")
|
||||
pathname.unlink(missing_ok=True)
|
||||
await conn.download_file(bucket=s3.bucket, object_name=object_name, local_path=str(pathname))
|
||||
assert pathname.exists()
|
||||
|
|
@ -54,8 +53,7 @@ async def test_s3(mocker):
|
|||
assert url
|
||||
bin_data = await conn.get_object(bucket=s3.bucket, object_name=object_name)
|
||||
assert bin_data
|
||||
async with aiofiles.open(__file__, mode="r", encoding="utf-8") as reader:
|
||||
data = await reader.read()
|
||||
data = await aread(filename=__file__)
|
||||
res = await conn.cache(data, ".bak", "script")
|
||||
assert "http" in res
|
||||
|
||||
|
|
@ -69,8 +67,6 @@ async def test_s3(mocker):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
await reader.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
64
tests/metagpt/utils/test_tree.py
Normal file
64
tests/metagpt/utils/test_tree.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.tree import _print_tree, tree
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("root", "rules"),
|
||||
[
|
||||
(str(Path(__file__).parent / "../.."), None),
|
||||
(str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
|
||||
],
|
||||
)
|
||||
def test_tree(root: str, rules: str):
|
||||
v = tree(root=root, gitignore=rules)
|
||||
assert v
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("root", "rules"),
|
||||
[
|
||||
(str(Path(__file__).parent / "../.."), None),
|
||||
(str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
|
||||
],
|
||||
)
|
||||
def test_tree_command(root: str, rules: str):
|
||||
v = tree(root=root, gitignore=rules, run_command=True)
|
||||
assert v
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tree", "want"),
|
||||
[
|
||||
({"a": {"b": {}, "c": {}}}, ["a", "+-- b", "+-- c"]),
|
||||
({"a": {"b": {}, "c": {"d": {}}}}, ["a", "+-- b", "+-- c", " +-- d"]),
|
||||
(
|
||||
{"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}},
|
||||
["a", "+-- b", "| +-- e", "| +-- f", "| +-- g", "+-- c", " +-- d"],
|
||||
),
|
||||
(
|
||||
{"h": {"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}, "i": {}}},
|
||||
[
|
||||
"h",
|
||||
"+-- a",
|
||||
"| +-- b",
|
||||
"| | +-- e",
|
||||
"| | +-- f",
|
||||
"| | +-- g",
|
||||
"| +-- c",
|
||||
"| +-- d",
|
||||
"+-- i",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test__print_tree(tree: dict, want: List[str]):
|
||||
v = _print_tree(tree)
|
||||
assert v == want
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -8,7 +8,6 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
|||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import process_message
|
||||
|
||||
OriginalLLM = OpenAILLM if config.llm.api_type == LLMType.OPENAI else AzureOpenAILLM
|
||||
|
||||
|
|
@ -105,7 +104,7 @@ class MockLLM(OriginalLLM):
|
|||
return rsp
|
||||
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
msg_key = json.dumps(process_message(messages), ensure_ascii=False)
|
||||
msg_key = json.dumps(self.format_msg(messages), ensure_ascii=False)
|
||||
rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs)
|
||||
return rsp
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue