feat: merge geekan:main

This commit is contained in:
莘权 马 2024-03-21 22:08:13 +08:00
commit c6b9e234bf
24 changed files with 326 additions and 90 deletions

2
.gitignore vendored
View file

@ -1,7 +1,7 @@
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
__pycache__
*.py[cod]
*$py.class

View file

@ -1,10 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/01/15
@Author : mannaandpoem
@File : imitate_webpage.py
"""
from metagpt.roles.di.data_interpreter import DataInterpreter

View file

@ -13,7 +13,7 @@ async def main():
question = "What are the most interesting human facts?"
search = Config.default().search
kwargs = {"api_key": search.api_key, "cse_id": search.cse_id, "proxy": None}
kwargs = search.model_dump()
await Searcher(search_engine=SearchEngine(engine=search.api_type, **kwargs)).run(question)

View file

@ -331,7 +331,7 @@ class ActionNode:
def compile_to(self, i: Dict, schema, kv_sep) -> str:
if schema == "json":
return json.dumps(i, indent=4)
return json.dumps(i, indent=4, ensure_ascii=False)
elif schema == "markdown":
return dict_to_markdown(i, kv_sep=kv_sep)
else:

View file

@ -18,7 +18,7 @@ from metagpt.prompts.di.write_analysis_code import (
STRUCTUAL_PROMPT,
)
from metagpt.schema import Message, Plan
from metagpt.utils.common import CodeParser, process_message, remove_comments
from metagpt.utils.common import CodeParser, remove_comments
class WriteAnalysisCode(Action):
@ -50,7 +50,7 @@ class WriteAnalysisCode(Action):
)
working_memory = working_memory or []
context = process_message([Message(content=structual_prompt, role="user")] + working_memory)
context = self.llm.format_msg([Message(content=structual_prompt, role="user")] + working_memory)
# LLM call
if use_reflection:

View file

@ -7,6 +7,8 @@
"""
from typing import Callable, Optional
from pydantic import Field
from metagpt.tools import SearchEngineType
from metagpt.utils.yaml_model import YamlModel
@ -18,3 +20,11 @@ class SearchConfig(YamlModel):
api_key: str = ""
cse_id: str = "" # for google
search_func: Optional[Callable] = None
params: dict = Field(
default_factory=lambda: {
"engine": "google",
"google_domain": "google.com",
"gl": "us",
"hl": "en",
}
)

View file

@ -9,11 +9,11 @@
from pathlib import Path
from typing import Dict, List, Optional
import aiofiles
import yaml
from pydantic import BaseModel, Field
from metagpt.context import Context
from metagpt.utils.common import aread
class Example(BaseModel):
@ -68,8 +68,7 @@ class SkillsDeclaration(BaseModel):
async def load(skill_yaml_file_name: Path = None) -> "SkillsDeclaration":
if not skill_yaml_file_name:
skill_yaml_file_name = Path(__file__).parent.parent.parent / "docs/.well-known/skills.yaml"
async with aiofiles.open(str(skill_yaml_file_name), mode="r") as reader:
data = await reader.read(-1)
data = await aread(filename=skill_yaml_file_name)
skill_data = yaml.safe_load(data)
return SkillsDeclaration(**skill_data)

View file

@ -74,6 +74,28 @@ class BaseLLM(ABC):
def _system_msg(self, msg: str) -> dict[str, str]:
return {"role": "system", "content": msg}
def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
    """Convert ``messages`` to a list of ``{"role": ..., "content": ...}`` dicts.

    Args:
        messages: A single str / Message, or a list whose items are str, dict or Message.
            A non-list value is treated as a one-element list.

    Returns:
        One dict per input item: a str becomes a ``user`` message, a dict is passed
        through unchanged, and a Message is converted with ``Message.to_dict()``.

    Raises:
        AssertionError: If a dict item's keys are not exactly ``role`` and ``content``.
        ValueError: If an item is not a str, dict or Message.
    """
    from metagpt.schema import Message

    if not isinstance(messages, list):
        messages = [messages]

    processed_messages = []
    for msg in messages:
        if isinstance(msg, str):
            processed_messages.append({"role": "user", "content": msg})
        elif isinstance(msg, dict):
            # Kept as an assert so callers that expect AssertionError keep working.
            assert set(msg.keys()) == {"role", "content"}
            processed_messages.append(msg)
        elif isinstance(msg, Message):
            processed_messages.append(msg.to_dict())
        else:
            # Report the offending item's type; `messages` is always a list at this point.
            raise ValueError(
                f"Only support message type are: str, Message, dict, but got {type(msg).__name__}!"
            )
    return processed_messages
def _system_msgs(self, msgs: list[str]) -> list[dict[str, str]]:
return [self._system_msg(msg) for msg in msgs]

View file

@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart
import os
from typing import Optional, Union
import google.generativeai as genai
@ -16,9 +17,10 @@ from google.generativeai.types.generation_types import (
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
class GeminiGenerativeModel(GenerativeModel):
@ -52,6 +54,10 @@ class GeminiLLM(BaseLLM):
self.llm = GeminiGenerativeModel(model_name=self.model)
def __init_gemini(self, config: LLMConfig):
    """Configure the ``google.generativeai`` client from ``config``.

    When ``config.proxy`` is set, it is exported through the process-wide proxy
    environment variables before the client is configured with ``config.api_key``.

    Args:
        config: LLM configuration; ``proxy`` and ``api_key`` are read here.
    """
    if config.proxy:
        logger.info(f"Use proxy: {config.proxy}")
        os.environ["HTTP_PROXY"] = config.proxy
        # The variable for HTTPS traffic is HTTPS_PROXY; the former "HTTP_PROXYS" key was a typo.
        os.environ["HTTPS_PROXY"] = config.proxy
    genai.configure(api_key=config.api_key)
def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, str]:
@ -62,6 +68,35 @@ class GeminiLLM(BaseLLM):
def _assistant_msg(self, msg: str) -> dict[str, str]:
return {"role": "model", "parts": [msg]}
def _system_msg(self, msg: str) -> dict[str, str]:
return {"role": "user", "parts": [msg]}
def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
    """Convert ``messages`` to a list of ``{"role": ..., "parts": [...]}`` dicts.

    Args:
        messages: A single str / Message, or a list whose items are str, dict or Message.
            A non-list value is treated as a one-element list.

    Returns:
        One dict per input item: a str becomes a ``user`` message, a dict is passed
        through unchanged, and a Message keeps role ``user`` if it already has it and
        is mapped to ``model`` otherwise, with its content as the single part.

    Raises:
        AssertionError: If a dict item's keys are not exactly ``role`` and ``parts``.
        ValueError: If an item is not a str, dict or Message.
    """
    from metagpt.schema import Message

    if not isinstance(messages, list):
        messages = [messages]

    # REF: https://ai.google.dev/tutorials/python_quickstart
    # As a dictionary, the message requires `role` and `parts` keys.
    # The role in a conversation can either be the `user`, which provides the prompts,
    # or `model`, which provides the responses.
    processed_messages = []
    for msg in messages:
        if isinstance(msg, str):
            processed_messages.append({"role": "user", "parts": [msg]})
        elif isinstance(msg, dict):
            # Kept as an assert so callers that expect AssertionError keep working.
            assert set(msg.keys()) == {"role", "parts"}
            processed_messages.append(msg)
        elif isinstance(msg, Message):
            processed_messages.append({"role": "user" if msg.role == "user" else "model", "parts": [msg.content]})
        else:
            # Report the offending item's type; `messages` is always a list at this point.
            raise ValueError(
                f"Only support message type are: str, Message, dict, but got {type(msg).__name__}!"
            )
    return processed_messages
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
    """Assemble the keyword arguments for a generation request.

    Args:
        messages: Conversation messages, passed through as ``contents``.
        stream: Forwarded unchanged as the ``stream`` flag.

    Returns:
        A dict with ``contents``, a ``generation_config`` fixed at temperature 0.3,
        and ``stream``.
    """
    return {
        "contents": messages,
        "generation_config": GenerationConfig(temperature=0.3),
        "stream": stream,
    }

View file

@ -30,12 +30,7 @@ from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.common import (
CodeParser,
decode_image,
log_and_reraise,
process_message,
)
from metagpt.utils.common import CodeParser, decode_image, log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
@ -151,7 +146,7 @@ class OpenAILLM(BaseLLM):
async def _achat_completion_function(
self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, **chat_configs
) -> ChatCompletion:
messages = process_message(messages)
messages = self.format_msg(messages)
kwargs = self._cons_kwargs(messages=messages, timeout=self.get_timeout(timeout), **chat_configs)
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)

View file

@ -29,6 +29,7 @@ from typing import Any, Callable, List, Literal, Tuple, Union
from urllib.parse import quote, unquote
import aiofiles
import chardet
import loguru
import requests
from PIL import Image
@ -663,14 +664,21 @@ def role_raise_decorator(func):
@handle_exception
async def aread(filename: str | Path, encoding=None) -> str:
async def aread(filename: str | Path, encoding="utf-8") -> str:
"""Read file asynchronously."""
async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
content = await reader.read()
try:
async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
content = await reader.read()
except UnicodeDecodeError:
async with aiofiles.open(str(filename), mode="rb") as reader:
raw = await reader.read()
result = chardet.detect(raw)
detected_encoding = result["encoding"]
content = raw.decode(detected_encoding)
return content
async def awrite(filename: str | Path, data: str, encoding=None):
async def awrite(filename: str | Path, data: str, encoding="utf-8"):
"""Write file asynchronously."""
pathname = Path(filename)
pathname.parent.mkdir(parents=True, exist_ok=True)
@ -802,29 +810,6 @@ def decode_image(img_url_or_b64: str) -> Image:
return img
def process_message(messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
    """Convert ``messages`` to a list of ``{"role": ..., "content": ...}`` dicts.

    Args:
        messages: A single str / Message, or a list whose items are str, dict or Message.

    Returns:
        One dict per input item: a str becomes a ``user`` message, a dict is passed
        through unchanged, and a Message is converted with ``Message.to_dict()``.

    Raises:
        AssertionError: If a dict item's keys are not exactly ``role`` and ``content``.
        ValueError: If an item is not a str, dict or Message.
    """
    from metagpt.schema import Message

    # Normalize the input to a list.
    if not isinstance(messages, list):
        messages = [messages]

    # Convert every item to a dict.
    processed_messages = []
    for msg in messages:
        if isinstance(msg, str):
            processed_messages.append({"role": "user", "content": msg})
        elif isinstance(msg, dict):
            # Kept as an assert so callers that expect AssertionError keep working.
            assert set(msg.keys()) == {"role", "content"}
            processed_messages.append(msg)
        elif isinstance(msg, Message):
            processed_messages.append(msg.to_dict())
        else:
            # Report the offending item's type; `messages` is always a list at this point.
            raise ValueError(f"Only support message type are: str, Message, dict, but got {type(msg).__name__}!")
    return processed_messages
def log_and_reraise(retry_state: RetryCallState):
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
logger.warning(

View file

@ -13,9 +13,7 @@ import re
from pathlib import Path
from typing import Set
import aiofiles
from metagpt.utils.common import aread
from metagpt.utils.common import aread, awrite
from metagpt.utils.exceptions import handle_exception
@ -45,8 +43,7 @@ class DependencyFile:
async def save(self):
"""Save dependencies to the file asynchronously."""
data = json.dumps(self._dependencies)
async with aiofiles.open(str(self._filename), mode="w") as writer:
await writer.write(data)
await awrite(filename=self._filename, data=data)
async def update(self, filename: Path | str, dependencies: Set[Path | str], persist=True):
"""Update dependencies for a file asynchronously.

View file

@ -14,11 +14,9 @@ from datetime import datetime
from pathlib import Path
from typing import Dict, List, Set
import aiofiles
from metagpt.logs import logger
from metagpt.schema import Document
from metagpt.utils.common import aread
from metagpt.utils.common import aread, awrite
from metagpt.utils.json_to_markdown import json_to_markdown
@ -55,8 +53,7 @@ class FileRepository:
pathname = self.workdir / filename
pathname.parent.mkdir(parents=True, exist_ok=True)
content = content if content else "" # avoid `argument must be str, not None` to make it continue
async with aiofiles.open(str(pathname), mode="w") as writer:
await writer.write(content)
await awrite(filename=str(pathname), data=content)
logger.info(f"save to: {str(pathname)}")
if dependencies is not None:

View file

@ -9,11 +9,9 @@ import asyncio
import os
from pathlib import Path
import aiofiles
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.utils.common import check_cmd_exists
from metagpt.utils.common import awrite, check_cmd_exists
async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
@ -30,9 +28,7 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
tmp = Path(f"{output_file_without_suffix}.mmd")
async with aiofiles.open(tmp, "w", encoding="utf-8") as f:
await f.write(mermaid_code)
# tmp.write_text(mermaid_code, encoding="utf-8")
await awrite(filename=tmp, data=mermaid_code)
if engine == "nodejs":
if check_cmd_exists(config.mermaid.path) != 0:

View file

@ -340,7 +340,9 @@ def extract_state_value_from_output(content: str) -> str:
content (str): llm's output from `Role._think`
"""
content = content.strip() # deal the output cases like " 0", "0\n" and so on.
pattern = r"([0-9])" # TODO find the number using a more proper method not just extract from content using pattern
pattern = (
r"(?<!-)[0-9]" # TODO find the number using a more proper method not just extract from content using pattern
)
matches = re.findall(pattern, content, re.DOTALL)
matches = list(set(matches))
state = matches[0] if len(matches) > 0 else "-1"

140
metagpt/utils/tree.py Normal file
View file

@ -0,0 +1,140 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/3/11
@Author : mashenquan
@File : tree.py
@Desc : Implement the same functionality as the `tree` command.
Example:
>>> tree(".")
utils
+-- serialize.py
+-- project_repo.py
+-- tree.py
+-- mmdc_playwright.py
+-- cost_manager.py
+-- __pycache__
| +-- __init__.cpython-39.pyc
| +-- redis.cpython-39.pyc
| +-- singleton.cpython-39.pyc
| +-- embedding.cpython-39.pyc
| +-- make_sk_kernel.cpython-39.pyc
| +-- file_repository.cpython-39.pyc
+-- file.py
+-- save_code.py
+-- common.py
+-- redis.py
"""
from __future__ import annotations
import subprocess
from pathlib import Path
from typing import Callable, Dict, List
from gitignore_parser import parse_gitignore
def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str:
    """
    Build a tree-like text listing of the directory structure under ``root``.

    Args:
        root (str or Path): The directory to start from; it is resolved to an absolute path.
        gitignore (str or Path): Optional path of a gitignore file; matching entries are left out.
        run_command (bool): If True, delegate to the external `tree` command and return its output,
            otherwise build the listing with python code.

    Returns:
        str: The listing, one entry per line, starting with the name of ``root``.

    Example:
        >>> tree(".", gitignore="../../.gitignore")
        utils
        +-- serialize.py
        +-- project_repo.py
        +-- tree.py
    """
    top = Path(root).resolve()
    if run_command:
        return _execute_tree(top, gitignore)

    rules = None
    if gitignore:
        rules = parse_gitignore(gitignore)
    nested = {top.name: _list_children(root=top, git_ignore_rules=rules)}
    return "\n".join(_print_tree(nested))
def _list_children(root: Path, git_ignore_rules: Callable) -> Dict[str, Dict]:
dir_ = {}
for i in root.iterdir():
if git_ignore_rules and git_ignore_rules(str(i)):
continue
try:
if i.is_file():
dir_[i.name] = {}
else:
dir_[i.name] = _list_children(root=i, git_ignore_rules=git_ignore_rules)
except (FileNotFoundError, PermissionError, OSError):
dir_[i.name] = {}
return dir_
def _print_tree(dir_: Dict[str:Dict]) -> List[str]:
ret = []
for name, children in dir_.items():
ret.append(name)
if not children:
continue
lines = _print_tree(children)
for j, v in enumerate(lines):
if v[0] not in ["+", " ", "|"]:
ret = _add_line(ret)
row = f"+-- {v}"
else:
row = f" {v}"
ret.append(row)
return ret
def _add_line(rows: List[str]) -> List[str]:
for i in range(len(rows) - 1, -1, -1):
v = rows[i]
if v[0] != " ":
return rows
rows[i] = "|" + v[1:]
return rows
def _execute_tree(root: Path, gitignore: str | Path) -> str:
    """Run the external `tree` command on ``root`` and return its standard output.

    Args:
        root: Directory to list.
        gitignore: Optional gitignore file, passed to `tree` via ``--gitfile`` when truthy.

    Returns:
        The text written by `tree` to stdout.

    Raises:
        subprocess.CalledProcessError: If `tree` exits with a non-zero status.
    """
    args = ["--gitfile", str(gitignore)] if gitignore else []
    # check=True already raises on a non-zero exit status, so no manual returncode check is needed.
    result = subprocess.run(["tree", *args, str(root)], capture_output=True, text=True, check=True)
    return result.stdout

View file

@ -57,7 +57,7 @@ extras_require["dev"] = (["pylint~=3.0.3", "black~=23.3.0", "isort~=5.12.0", "pr
setup(
name="metagpt",
version="0.7.4",
version="0.7.6",
description="The Multi-Agent Framework",
long_description=long_description,
long_description_content_type="text/markdown",

View file

@ -6,11 +6,11 @@
@File : test_tutorial_assistant.py
"""
import aiofiles
import pytest
from metagpt.const import TUTORIAL_PATH
from metagpt.roles.tutorial_assistant import TutorialAssistant
from metagpt.utils.common import aread
@pytest.mark.asyncio
@ -20,9 +20,8 @@ async def test_tutorial_assistant(language: str, topic: str, context):
msg = await role.run(topic)
assert TUTORIAL_PATH.exists()
filename = msg.content
async with aiofiles.open(filename, mode="r", encoding="utf-8") as reader:
content = await reader.read()
assert "pip" in content
content = await aread(filename=filename)
assert "pip" in content
if __name__ == "__main__":

View file

@ -11,7 +11,6 @@ from typing import Callable
import pytest
from metagpt.config2 import config
from metagpt.configs.search_config import SearchConfig
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
@ -53,14 +52,11 @@ async def test_search_engine(
search_engine_config = {"engine": search_engine_type, "run_func": run_func}
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
assert config.search
search_engine_config["api_key"] = "mock-serpapi-key"
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
assert config.search
search_engine_config["api_key"] = "mock-google-key"
search_engine_config["cse_id"] = "mock-google-cse"
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
assert config.search
search_engine_config["api_key"] = "mock-serper-key"
async def test(search_engine):

View file

@ -13,7 +13,6 @@ import uuid
from pathlib import Path
from typing import Any, Set
import aiofiles
import pytest
from pydantic import BaseModel
@ -125,9 +124,7 @@ class TestGetProjectRoot:
async def test_parse_data_exception(self, filename, want):
pathname = Path(__file__).parent.parent.parent / "data/output_parser" / filename
assert pathname.exists()
async with aiofiles.open(str(pathname), mode="r") as reader:
data = await reader.read()
data = await aread(filename=pathname)
result = OutputParser.parse_data(data=data)
assert want in result
@ -198,12 +195,25 @@ class TestGetProjectRoot:
@pytest.mark.asyncio
async def test_read_write(self):
pathname = Path(__file__).parent / uuid.uuid4().hex / "test.tmp"
pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.tmp"
await awrite(pathname, "ABC")
data = await aread(pathname)
assert data == "ABC"
pathname.unlink(missing_ok=True)
@pytest.mark.asyncio
async def test_read_write_error_charset(self):
    """Check that `aread` returns the written text for UTF-8 and for a non-UTF-8 encoded file."""
    pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.txt"
    try:
        content = "中国abc123\u27f6"
        await awrite(filename=pathname, data=content)
        data = await aread(filename=pathname)
        assert data == content

        # Written as gb2312 but read back asking for utf-8.
        content = "GB18030 是中国国家标准局发布的新一代中文字符集标准,是 GBK 的升级版,支持更广泛的字符范围。"
        await awrite(filename=pathname, data=content, encoding="gb2312")
        data = await aread(filename=pathname, encoding="utf-8")
        assert data == content
    finally:
        # Remove the temporary file, as test_read_write does.
        pathname.unlink(missing_ok=True)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -10,15 +10,14 @@
import shutil
from pathlib import Path
import aiofiles
import pytest
from metagpt.utils.common import awrite
from metagpt.utils.git_repository import GitRepository
async def mock_file(filename, content=""):
async with aiofiles.open(str(filename), mode="w") as file:
await file.write(content)
await awrite(filename=filename, data=content)
async def mock_repo(local_path) -> (GitRepository, Path):

View file

@ -9,7 +9,6 @@ import uuid
from pathlib import Path
import aioboto3
import aiofiles
import pytest
from metagpt.config2 import Config
@ -46,7 +45,7 @@ async def test_s3(mocker):
conn = S3(s3)
object_name = "unittest.bak"
await conn.upload_file(bucket=s3.bucket, local_path=__file__, object_name=object_name)
pathname = (Path(__file__).parent / uuid.uuid4().hex).with_suffix(".bak")
pathname = (Path(__file__).parent / "../../../workspace/unittest" / uuid.uuid4().hex).with_suffix(".bak")
pathname.unlink(missing_ok=True)
await conn.download_file(bucket=s3.bucket, object_name=object_name, local_path=str(pathname))
assert pathname.exists()
@ -54,8 +53,7 @@ async def test_s3(mocker):
assert url
bin_data = await conn.get_object(bucket=s3.bucket, object_name=object_name)
assert bin_data
async with aiofiles.open(__file__, mode="r", encoding="utf-8") as reader:
data = await reader.read()
data = await aread(filename=__file__)
res = await conn.cache(data, ".bak", "script")
assert "http" in res
@ -69,8 +67,6 @@ async def test_s3(mocker):
except Exception:
pass
await reader.close()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,64 @@
from pathlib import Path
from typing import List
import pytest
from metagpt.utils.tree import _print_tree, tree
@pytest.mark.parametrize(
    ("root", "rules"),
    [
        (str(Path(__file__).parent / "../.."), None),
        (str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
    ],
)
def test_tree(root: str, rules: str):
    """The python implementation yields a non-empty listing, with and without gitignore rules."""
    listing = tree(root=root, gitignore=rules)
    assert listing
@pytest.mark.parametrize(
    ("root", "rules"),
    [
        (str(Path(__file__).parent / "../.."), None),
        (str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
    ],
)
def test_tree_command(root: str, rules: str):
    """The external `tree` command yields a non-empty listing; skipped when the binary is absent."""
    import shutil  # local import: only needed for the availability check below

    if shutil.which("tree") is None:
        pytest.skip("`tree` executable is not installed")
    v = tree(root=root, gitignore=rules, run_command=True)
    assert v
@pytest.mark.parametrize(
    ("tree", "want"),
    [
        ({"a": {"b": {}, "c": {}}}, ["a", "+-- b", "+-- c"]),
        ({"a": {"b": {}, "c": {"d": {}}}}, ["a", "+-- b", "+-- c", " +-- d"]),
        (
            {"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}},
            ["a", "+-- b", "| +-- e", "| +-- f", "| +-- g", "+-- c", " +-- d"],
        ),
        (
            {"h": {"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}, "i": {}}},
            [
                "h",
                "+-- a",
                "| +-- b",
                "| | +-- e",
                "| | +-- f",
                "| | +-- g",
                "| +-- c",
                "| +-- d",
                "+-- i",
            ],
        ),
    ],
)
def test__print_tree(tree: dict, want: List[str]):
    """Rendering a nested mapping produces exactly the expected rows."""
    assert _print_tree(tree) == want
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -8,7 +8,6 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.openai_api import OpenAILLM
from metagpt.schema import Message
from metagpt.utils.common import process_message
OriginalLLM = OpenAILLM if config.llm.api_type == LLMType.OPENAI else AzureOpenAILLM
@ -105,7 +104,7 @@ class MockLLM(OriginalLLM):
return rsp
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
msg_key = json.dumps(process_message(messages), ensure_ascii=False)
msg_key = json.dumps(self.format_msg(messages), ensure_ascii=False)
rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs)
return rsp