+添加运营小姐姐,拉你入群
-如果群已满,请添加负责人微信,会邀请进群
-
-
\ No newline at end of file
+
diff --git a/docs/README_JA.md b/docs/README_JA.md
index a5e5f6552..57f6487a7 100644
--- a/docs/README_JA.md
+++ b/docs/README_JA.md
@@ -75,25 +75,25 @@ ### Docker によるインストール
```bash
# ステップ 1: metagpt 公式イメージをダウンロードし、config.yaml を準備する
-docker pull metagpt/metagpt:v0.3
+docker pull metagpt/metagpt:v0.3.1
mkdir -p /opt/metagpt/{config,workspace}
-docker run --rm metagpt/metagpt:v0.3 cat /app/metagpt/config/config.yaml > /opt/metagpt/config/config.yaml
-vim /opt/metagpt/config/config.yaml # 設定を変更する
+docker run --rm metagpt/metagpt:v0.3.1 cat /app/metagpt/config/config.yaml > /opt/metagpt/config/key.yaml
+vim /opt/metagpt/config/key.yaml # 設定を変更する
# ステップ 2: コンテナで metagpt デモを実行する
docker run --rm \
--privileged \
- -v /opt/metagpt/config:/app/metagpt/config \
+ -v /opt/metagpt/config/key.yaml:/app/metagpt/config/key.yaml \
-v /opt/metagpt/workspace:/app/metagpt/workspace \
- metagpt/metagpt:v0.3 \
+ metagpt/metagpt:v0.3.1 \
python startup.py "Write a cli snake game"
# コンテナを起動し、その中でコマンドを実行することもできます
docker run --name metagpt -d \
--privileged \
- -v /opt/metagpt/config:/app/metagpt/config \
+ -v /opt/metagpt/config/key.yaml:/app/metagpt/config/key.yaml \
-v /opt/metagpt/workspace:/app/metagpt/workspace \
- metagpt/metagpt:v0.3
+ metagpt/metagpt:v0.3.1
docker exec -it metagpt /bin/bash
$ python startup.py "Write a cli snake game"
@@ -111,7 +111,7 @@ ### 自分でイメージをビルドする
```bash
# また、自分で metagpt イメージを構築することもできます。
git clone https://github.com/geekan/MetaGPT.git
-cd MetaGPT && docker build -t metagpt:v0.3 .
+cd MetaGPT && docker build -t metagpt:custom .
```
## 設定
@@ -142,37 +142,36 @@ ### プラットフォームまたはツールの設定
要件を述べるときに、どのプラットフォームまたはツールを使用するかを指定できます。
```shell
-python startup.py "Write a cli snake game based on pygame"
+python startup.py "pygame をベースとした cli ヘビゲームを書く"
```
-
### 使用方法
```
-NAME
- startup.py - We are a software startup comprised of AI. By investing in us, you are empowering a future filled with limitless possibilities.
+会社名
+ startup.py - 私たちは AI で構成されたソフトウェア・スタートアップです。私たちに投資することは、無限の可能性に満ちた未来に力を与えることです。
-SYNOPSIS
+シノプシス
startup.py IDEA .*)(```.*?)',
+ r'(.*?```python.*?\s+)?(?P.*)',
+ ):
+ match = re.search(pattern, text, re.DOTALL)
+ if not match:
+ continue
+ code = match.group("code")
+ if not code:
+ continue
+ with contextlib.suppress(Exception):
+ ast.parse(code)
+ return code
+ raise ValueError("Invalid python code")
@classmethod
def parse_data(cls, data):
@@ -183,7 +201,7 @@ class CodeParser:
def parse_file_list(cls, block: str, text: str, lang: str = "") -> list[str]:
# Regular expression pattern to find the tasks list.
code = cls.parse_code(block, text, lang)
- print(code)
+ # print(code)
pattern = r'\s*(.*=.*)?(\[.*\])'
# Extract tasks list string using regex.
@@ -230,3 +248,9 @@ def print_members(module, indent=0):
print(f'{prefix}Function: {name}')
elif inspect.ismethod(obj):
print(f'{prefix}Method: {name}')
+
+
+def parse_recipient(text):
+ pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now
+ recipient = re.search(pattern, text)
+ return recipient.group(1) if recipient else ""
diff --git a/metagpt/utils/mermaid.py b/metagpt/utils/mermaid.py
index 3788b4743..24aabe8ae 100644
--- a/metagpt/utils/mermaid.py
+++ b/metagpt/utils/mermaid.py
@@ -5,9 +5,9 @@
@Author : alexanderwu
@File : mermaid.py
"""
-import os
import subprocess
from pathlib import Path
+
from metagpt.config import CONFIG
from metagpt.const import PROJECT_ROOT
from metagpt.logs import logger
@@ -24,25 +24,36 @@ def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height
:return: 0 if succed, -1 if failed
"""
# Write the Mermaid code to a temporary file
- tmp = Path(f'{output_file_without_suffix}.mmd')
- tmp.write_text(mermaid_code, encoding='utf-8')
+ tmp = Path(f"{output_file_without_suffix}.mmd")
+ tmp.write_text(mermaid_code, encoding="utf-8")
- if check_cmd_exists('mmdc') != 0:
- logger.warning(
- "RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc")
+ if check_cmd_exists("mmdc") != 0:
+ logger.warning("RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc")
return -1
- for suffix in ['pdf', 'svg', 'png']:
- output_file = f'{output_file_without_suffix}.{suffix}'
+ for suffix in ["pdf", "svg", "png"]:
+ output_file = f"{output_file_without_suffix}.{suffix}"
# Call the `mmdc` command to convert the Mermaid code to a PNG
logger.info(f"Generating {output_file}..")
if CONFIG.puppeteer_config:
- subprocess.run([CONFIG.mmdc, '-p', CONFIG.puppeteer_config, '-i', str(tmp), '-o',
- output_file, '-w', str(width), '-H', str(height)])
+ subprocess.run(
+ [
+ CONFIG.mmdc,
+ "-p",
+ CONFIG.puppeteer_config,
+ "-i",
+ str(tmp),
+ "-o",
+ output_file,
+ "-w",
+ str(width),
+ "-H",
+ str(height),
+ ]
+ )
else:
- subprocess.run([CONFIG.mmdc, '-i', str(tmp), '-o',
- output_file, '-w', str(width), '-H', str(height)])
+ subprocess.run([CONFIG.mmdc, "-i", str(tmp), "-o", output_file, "-w", str(width), "-H", str(height)])
return 0
@@ -97,7 +108,7 @@ MMC2 = """sequenceDiagram
SE-->>M: return summary"""
-if __name__ == '__main__':
+if __name__ == "__main__":
# logger.info(print_members(print_members))
- mermaid_to_file(MMC1, PROJECT_ROOT / 'tmp/1.png')
- mermaid_to_file(MMC2, PROJECT_ROOT / 'tmp/2.png')
+ mermaid_to_file(MMC1, PROJECT_ROOT / "tmp/1.png")
+ mermaid_to_file(MMC2, PROJECT_ROOT / "tmp/2.png")
diff --git a/metagpt/utils/parse_html.py b/metagpt/utils/parse_html.py
new file mode 100644
index 000000000..62de26541
--- /dev/null
+++ b/metagpt/utils/parse_html.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python
+from __future__ import annotations
+
+from typing import Generator, Optional
+from urllib.parse import urljoin, urlparse
+
+from bs4 import BeautifulSoup
+from pydantic import BaseModel
+
+
+class WebPage(BaseModel):
+ inner_text: str
+ html: str
+ url: str
+
+ class Config:
+ underscore_attrs_are_private = True
+
+ _soup : Optional[BeautifulSoup] = None
+ _title: Optional[str] = None
+
+ @property
+ def soup(self) -> BeautifulSoup:
+ if self._soup is None:
+ self._soup = BeautifulSoup(self.html, "html.parser")
+ return self._soup
+
+ @property
+ def title(self):
+ if self._title is None:
+ title_tag = self.soup.find("title")
+ self._title = title_tag.text.strip() if title_tag is not None else ""
+ return self._title
+
+ def get_links(self) -> Generator[str, None, None]:
+ for i in self.soup.find_all("a", href=True):
+ url = i["href"]
+ result = urlparse(url)
+ if not result.scheme and result.path:
+ yield urljoin(self.url, url)
+ elif url.startswith(("http://", "https://")):
+ yield urljoin(self.url, url)
+
+
+def get_html_content(page: str, base: str):
+ soup = _get_soup(page)
+
+ return soup.get_text(strip=True)
+
+
+def _get_soup(page: str):
+ soup = BeautifulSoup(page, "html.parser")
+ # https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup
+ for s in soup(["style", "script", "[document]", "head", "title"]):
+ s.extract()
+
+ return soup
diff --git a/metagpt/utils/pycst.py b/metagpt/utils/pycst.py
new file mode 100644
index 000000000..afd85a547
--- /dev/null
+++ b/metagpt/utils/pycst.py
@@ -0,0 +1,166 @@
+from __future__ import annotations
+
+from typing import Union
+
+import libcst as cst
+from libcst._nodes.module import Module
+
+DocstringNode = Union[cst.Module, cst.ClassDef, cst.FunctionDef]
+
+
+def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine:
+ """Extracts the docstring from the body of a node.
+
+ Args:
+        body: The module, class, or function node whose body is inspected.
+
+ Returns:
+ The docstring statement if it exists, None otherwise.
+ """
+ if isinstance(body, cst.Module):
+ body = body.body
+ else:
+ body = body.body.body
+
+ if not body:
+ return
+
+ statement = body[0]
+ if not isinstance(statement, cst.SimpleStatementLine):
+ return
+
+ expr = statement
+ while isinstance(expr, (cst.BaseSuite, cst.SimpleStatementLine)):
+ if len(expr.body) == 0:
+ return None
+ expr = expr.body[0]
+
+ if not isinstance(expr, cst.Expr):
+ return None
+
+ val = expr.value
+ if not isinstance(val, (cst.SimpleString, cst.ConcatenatedString)):
+ return None
+
+ evaluated_value = val.evaluated_value
+ if isinstance(evaluated_value, bytes):
+ return None
+
+ return statement
+
+
+class DocstringCollector(cst.CSTVisitor):
+ """A visitor class for collecting docstrings from a CST.
+
+ Attributes:
+ stack: A list to keep track of the current path in the CST.
+ docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
+ """
+ def __init__(self):
+ self.stack: list[str] = []
+ self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}
+
+ def visit_Module(self, node: cst.Module) -> bool | None:
+ self.stack.append("")
+
+ def leave_Module(self, node: cst.Module) -> None:
+ return self._leave(node)
+
+ def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
+ self.stack.append(node.name.value)
+
+ def leave_ClassDef(self, node: cst.ClassDef) -> None:
+ return self._leave(node)
+
+ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
+ self.stack.append(node.name.value)
+
+ def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
+ return self._leave(node)
+
+ def _leave(self, node: DocstringNode) -> None:
+ key = tuple(self.stack)
+ self.stack.pop()
+ if hasattr(node, "decorators") and any(i.decorator.value == "overload" for i in node.decorators):
+ return
+
+ statement = get_docstring_statement(node)
+ if statement:
+ self.docstrings[key] = statement
+
+
+class DocstringTransformer(cst.CSTTransformer):
+ """A transformer class for replacing docstrings in a CST.
+
+ Attributes:
+ stack: A list to keep track of the current path in the CST.
+ docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
+ """
+ def __init__(
+ self,
+ docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
+ ):
+ self.stack: list[str] = []
+ self.docstrings = docstrings
+
+ def visit_Module(self, node: cst.Module) -> bool | None:
+ self.stack.append("")
+
+ def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
+ return self._leave(original_node, updated_node)
+
+ def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
+ self.stack.append(node.name.value)
+
+ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode:
+ return self._leave(original_node, updated_node)
+
+ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
+ self.stack.append(node.name.value)
+
+ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
+ return self._leave(original_node, updated_node)
+
+ def _leave(self, original_node: DocstringNode, updated_node: DocstringNode) -> DocstringNode:
+ key = tuple(self.stack)
+ self.stack.pop()
+
+ if hasattr(updated_node, "decorators") and any((i.decorator.value == "overload") for i in updated_node.decorators):
+ return updated_node
+
+ statement = self.docstrings.get(key)
+ if not statement:
+ return updated_node
+
+ original_statement = get_docstring_statement(original_node)
+
+ if isinstance(updated_node, cst.Module):
+ body = updated_node.body
+ if original_statement:
+ return updated_node.with_changes(body=(statement, *body[1:]))
+ else:
+ updated_node = updated_node.with_changes(body=(statement, cst.EmptyLine(), *body))
+ return updated_node
+
+ body = updated_node.body.body[1:] if original_statement else updated_node.body.body
+ return updated_node.with_changes(body=updated_node.body.with_changes(body=(statement, *body)))
+
+
+def merge_docstring(code: str, documented_code: str) -> str:
+ """Merges the docstrings from the documented code into the original code.
+
+ Args:
+ code: The original code.
+ documented_code: The documented code.
+
+ Returns:
+ The original code with the docstrings from the documented code.
+ """
+ code_tree = cst.parse_module(code)
+ documented_code_tree = cst.parse_module(documented_code)
+
+ visitor = DocstringCollector()
+ documented_code_tree.visit(visitor)
+ transformer = DocstringTransformer(visitor.docstrings)
+ modified_tree = code_tree.visit(transformer)
+ return modified_tree.code
diff --git a/metagpt/utils/serialize.py b/metagpt/utils/serialize.py
index 34dee7098..ffafca8cd 100644
--- a/metagpt/utils/serialize.py
+++ b/metagpt/utils/serialize.py
@@ -3,14 +3,11 @@
# @Desc : the implement of serialization and deserialization
import copy
-from typing import Tuple, List, Type, Union, Dict
import pickle
-from collections import defaultdict
-from pydantic import create_model
+from typing import Dict, List, Tuple
-from metagpt.schema import Message
-from metagpt.actions.action import Action
from metagpt.actions.action_output import ActionOutput
+from metagpt.schema import Message
def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
@@ -34,12 +31,12 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
```
"""
mapping = dict()
- for field, property in schema['properties'].items():
- if property['type'] == 'string':
+ for field, property in schema["properties"].items():
+ if property["type"] == "string":
mapping[field] = (str, ...)
- elif property['type'] == 'array' and property['items']['type'] == 'string':
+ elif property["type"] == "array" and property["items"]["type"] == "string":
mapping[field] = (List[str], ...)
- elif property['type'] == 'array' and property['items']['type'] == 'array':
+ elif property["type"] == "array" and property["items"]["type"] == "array":
# here only consider the `Tuple[str, str]` situation
mapping[field] = (List[Tuple[str, str]], ...)
return mapping
@@ -53,11 +50,7 @@ def serialize_message(message: Message):
schema = ic.schema()
mapping = actionoutout_schema_to_mapping(schema)
- message_cp.instruct_content = {
- 'class': schema['title'],
- 'mapping': mapping,
- 'value': ic.dict()
- }
+ message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
msg_ser = pickle.dumps(message_cp)
return msg_ser
@@ -67,9 +60,8 @@ def deserialize_message(message_ser: str) -> Message:
message = pickle.loads(message_ser)
if message.instruct_content:
ic = message.instruct_content
- ic_obj = ActionOutput.create_model_class(class_name=ic['class'],
- mapping=ic['mapping'])
- ic_new = ic_obj(**ic['value'])
+ ic_obj = ActionOutput.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
+ ic_new = ic_obj(**ic["value"])
message.instruct_content = ic_new
return message
diff --git a/metagpt/utils/special_tokens.py b/metagpt/utils/special_tokens.py
new file mode 100644
index 000000000..2adb93c77
--- /dev/null
+++ b/metagpt/utils/special_tokens.py
@@ -0,0 +1,4 @@
+# token to separate different code messages in a WriteCode Message content
+MSG_SEP = "#*000*#"
+# token to separate file name and the actual code text in a code message
+FILENAME_CODE_SEP = "#*001*#"
diff --git a/metagpt/utils/text.py b/metagpt/utils/text.py
new file mode 100644
index 000000000..be3c52edd
--- /dev/null
+++ b/metagpt/utils/text.py
@@ -0,0 +1,124 @@
+from typing import Generator, Sequence
+
+from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
+
+
+def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str:
+ """Reduce the length of concatenated message segments to fit within the maximum token size.
+
+ Args:
+ msgs: A generator of strings representing progressively shorter valid prompts.
+ model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
+ system_text: The system prompts.
+ reserved: The number of reserved tokens.
+
+ Returns:
+ The concatenated message segments reduced to fit within the maximum token size.
+
+ Raises:
+ RuntimeError: If it fails to reduce the concatenated message length.
+ """
+ max_token = TOKEN_MAX.get(model_name, 2048) - count_string_tokens(system_text, model_name) - reserved
+ for msg in msgs:
+ if count_string_tokens(msg, model_name) < max_token:
+ return msg
+
+ raise RuntimeError("fail to reduce message length")
+
+
+def generate_prompt_chunk(
+ text: str,
+ prompt_template: str,
+ model_name: str,
+ system_text: str,
+ reserved: int = 0,
+) -> Generator[str, None, None]:
+ """Split the text into chunks of a maximum token size.
+
+ Args:
+ text: The text to split.
+ prompt_template: The template for the prompt, containing a single `{}` placeholder. For example, "### Reference\n{}".
+ model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
+ system_text: The system prompts.
+ reserved: The number of reserved tokens.
+
+ Yields:
+ The chunk of text.
+ """
+ paragraphs = text.splitlines(keepends=True)
+ current_token = 0
+ current_lines = []
+
+ reserved = reserved + count_string_tokens(prompt_template+system_text, model_name)
+ # 100 is a magic number to ensure the maximum context length is not exceeded
+ max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
+
+ while paragraphs:
+ paragraph = paragraphs.pop(0)
+ token = count_string_tokens(paragraph, model_name)
+ if current_token + token <= max_token:
+ current_lines.append(paragraph)
+ current_token += token
+ elif token > max_token:
+ paragraphs = split_paragraph(paragraph) + paragraphs
+ continue
+ else:
+ yield prompt_template.format("".join(current_lines))
+ current_lines = [paragraph]
+ current_token = token
+
+ if current_lines:
+ yield prompt_template.format("".join(current_lines))
+
+
+def split_paragraph(paragraph: str, sep: str = ".,", count: int = 2) -> list[str]:
+ """Split a paragraph into multiple parts.
+
+ Args:
+ paragraph: The paragraph to split.
+        sep: The separator characters, tried in order; the first that splits the paragraph is used.
+ count: The number of parts to split the paragraph into.
+
+ Returns:
+ A list of split parts of the paragraph.
+ """
+ for i in sep:
+ sentences = list(_split_text_with_ends(paragraph, i))
+ if len(sentences) <= 1:
+ continue
+ ret = ["".join(j) for j in _split_by_count(sentences, count)]
+ return ret
+ return _split_by_count(paragraph, count)
+
+
+def decode_unicode_escape(text: str) -> str:
+ """Decode a text with unicode escape sequences.
+
+ Args:
+ text: The text to decode.
+
+ Returns:
+ The decoded text.
+ """
+ return text.encode("utf-8").decode("unicode_escape", "ignore")
+
+
+def _split_by_count(lst: Sequence , count: int):
+ avg = len(lst) // count
+ remainder = len(lst) % count
+ start = 0
+ for i in range(count):
+ end = start + avg + (1 if i < remainder else 0)
+ yield lst[start:end]
+ start = end
+
+
+def _split_text_with_ends(text: str, sep: str = "."):
+ parts = []
+ for i in text:
+ parts.append(i)
+ if i == sep:
+ yield "".join(parts)
+ parts = []
+ if parts:
+ yield "".join(parts)
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index 99ae5e176..591bb60f0 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -25,6 +25,21 @@ TOKEN_COSTS = {
}
+TOKEN_MAX = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0301": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-3.5-turbo-16k-0613": 16384,
+ "gpt-4-0314": 8192,
+ "gpt-4": 8192,
+ "gpt-4-32k": 32768,
+ "gpt-4-32k-0314": 32768,
+ "gpt-4-0613": 8192,
+ "text-embedding-ada-002": 8192,
+}
+
+
def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
"""Return the number of tokens used by a list of messages."""
try:
@@ -39,7 +54,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
- }:
+ }:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
@@ -79,3 +94,18 @@ def count_string_tokens(string: str, model_name: str) -> int:
"""
encoding = tiktoken.encoding_for_model(model_name)
return len(encoding.encode(string))
+
+
+def get_max_completion_tokens(messages: list[dict], model: str, default: int) -> int:
+ """Calculate the maximum number of completion tokens for a given model and list of messages.
+
+ Args:
+ messages: A list of messages.
+ model: The model name.
+
+ Returns:
+ The maximum number of completion tokens.
+ """
+ if model not in TOKEN_MAX:
+ return default
+ return TOKEN_MAX[model] - count_message_tokens(messages)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 000000000..ed7c2769e
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,69 @@
+[project]
+name = "chatgit"
+version = "0.1.0"
+description = "chatgit is an LLM model-based open source project competition analysis research project, it can help you find the most suitable open source project for your needs"
+authors = [
+ {name = "hezz", email = "hezhaozhaog@gmail.com"},
+]
+dependencies = [
+ "requests>=2.31.0",
+]
+requires-python = ">=3.11"
+readme = "README.md"
+license = {text = "Apache"}
+
+[build-system]
+requires = ["setuptools>=61", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.black]
+line-length = 119
+target-version = ['py39']
+
+
+[tool.ruff]
+# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
+select = ["E", "F"]
+ignore = ["E501", "E712", "E722", "F821", "E731"]
+
+# Allow autofix for all enabled rules (when `--fix`) is provided.
+fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
+unfixable = []
+
+# Exclude a variety of commonly ignored directories.
+exclude = [
+ ".bzr",
+ ".direnv",
+ ".eggs",
+ ".git",
+ ".git-rewrite",
+ ".hg",
+ ".mypy_cache",
+ ".nox",
+ ".pants.d",
+ ".pytype",
+ ".ruff_cache",
+ ".svn",
+ ".tox",
+ ".venv",
+ "__pypackages__",
+ "_build",
+ "buck-out",
+ "build",
+ "dist",
+ "node_modules",
+ "venv",
+]
+
+# Same as Black.
+line-length = 119
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+# Assume Python 3.9
+target-version = "py39"
+
+[tool.ruff.mccabe]
+# Unlike Flake8, default to a complexity level of 10.
+max-complexity = 10
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 32a436962..452e2d092 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -35,3 +35,4 @@ tqdm==4.64.0
anthropic==0.3.6
typing-inspect==0.8.0
typing_extensions==4.5.0
+libcst==1.0.1
diff --git a/setup.py b/setup.py
index e65696901..2a8edaae7 100644
--- a/setup.py
+++ b/setup.py
@@ -44,7 +44,7 @@ setup(
install_requires=requirements,
extras_require={
"playwright": ["playwright>=1.26", "beautifulsoup4"],
- "selenium": ["selenium>4", "webdriver_manager<3.9", "beautifulsoup4"],
+ "selenium": ["selenium>4", "webdriver_manager", "beautifulsoup4"],
},
cmdclass={
"install_mermaid": InstallMermaidCLI,
diff --git a/startup.py b/startup.py
index e062babb5..f37b5286c 100644
--- a/startup.py
+++ b/startup.py
@@ -4,23 +4,27 @@ import asyncio
import fire
-from metagpt.roles import Architect, Engineer, ProductManager, ProjectManager
+from metagpt.roles import Architect, Engineer, ProductManager, ProjectManager, QaEngineer
from metagpt.software_company import SoftwareCompany
-async def startup(idea: str, investment: float = 3.0, n_round: int = 5, code_review: bool = False):
+async def startup(idea: str, investment: float = 3.0, n_round: int = 5,
+ code_review: bool = False, run_tests: bool = False):
"""Run a startup. Be a boss."""
company = SoftwareCompany()
company.hire([ProductManager(),
Architect(),
ProjectManager(),
Engineer(n_borg=5, use_code_review=code_review)])
+ if run_tests:
+ # developing features: run tests on the spot and identify bugs (bug fixing capability comes soon!)
+ company.hire([QaEngineer()])
company.invest(investment)
company.start_project(idea)
await company.run(n_round=n_round)
-def main(idea: str, investment: float = 3.0, n_round: int = 5, code_review: bool = False):
+def main(idea: str, investment: float = 3.0, n_round: int = 5, code_review: bool = False, run_tests: bool = False):
"""
We are a software startup comprised of AI. By investing in us, you are empowering a future filled with limitless possibilities.
:param idea: Your innovative idea, such as "Creating a snake game."
@@ -29,7 +33,7 @@ def main(idea: str, investment: float = 3.0, n_round: int = 5, code_review: bool
:param code_review: Whether to use code review.
:return:
"""
- asyncio.run(startup(idea, investment, n_round, code_review))
+ asyncio.run(startup(idea, investment, n_round, code_review, run_tests))
if __name__ == '__main__':
diff --git a/tests/metagpt/actions/test_debug_error.py b/tests/metagpt/actions/test_debug_error.py
index 526fd548f..555c84e4e 100644
--- a/tests/metagpt/actions/test_debug_error.py
+++ b/tests/metagpt/actions/test_debug_error.py
@@ -9,15 +9,147 @@ import pytest
from metagpt.actions.debug_error import DebugError
+EXAMPLE_MSG_CONTENT = '''
+---
+## Development Code File Name
+player.py
+## Development Code
+```python
+from typing import List
+from deck import Deck
+from card import Card
+
+class Player:
+ """
+ A class representing a player in the Black Jack game.
+ """
+
+ def __init__(self, name: str):
+ """
+ Initialize a Player object.
+
+ Args:
+ name (str): The name of the player.
+ """
+ self.name = name
+ self.hand: List[Card] = []
+ self.score = 0
+
+ def draw(self, deck: Deck):
+ """
+ Draw a card from the deck and add it to the player's hand.
+
+ Args:
+ deck (Deck): The deck of cards.
+ """
+ card = deck.draw_card()
+ self.hand.append(card)
+ self.calculate_score()
+
+ def calculate_score(self) -> int:
+ """
+ Calculate the score of the player's hand.
+
+ Returns:
+ int: The score of the player's hand.
+ """
+ self.score = sum(card.value for card in self.hand)
+ # Handle the case where Ace is counted as 11 and causes the score to exceed 21
+ if self.score > 21 and any(card.rank == 'A' for card in self.hand):
+ self.score -= 10
+ return self.score
+
+```
+## Test File Name
+test_player.py
+## Test Code
+```python
+import unittest
+from blackjack_game.player import Player
+from blackjack_game.deck import Deck
+from blackjack_game.card import Card
+
+class TestPlayer(unittest.TestCase):
+ ## Test the Player's initialization
+ def test_player_initialization(self):
+ player = Player("Test Player")
+ self.assertEqual(player.name, "Test Player")
+ self.assertEqual(player.hand, [])
+ self.assertEqual(player.score, 0)
+
+ ## Test the Player's draw method
+ def test_player_draw(self):
+ deck = Deck()
+ player = Player("Test Player")
+ player.draw(deck)
+ self.assertEqual(len(player.hand), 1)
+ self.assertEqual(player.score, player.hand[0].value)
+
+ ## Test the Player's calculate_score method
+ def test_player_calculate_score(self):
+ deck = Deck()
+ player = Player("Test Player")
+ player.draw(deck)
+ player.draw(deck)
+ self.assertEqual(player.score, sum(card.value for card in player.hand))
+
+ ## Test the Player's calculate_score method with Ace card
+ def test_player_calculate_score_with_ace(self):
+ deck = Deck()
+ player = Player("Test Player")
+ player.hand.append(Card('A', 'Hearts', 11))
+ player.hand.append(Card('K', 'Hearts', 10))
+ player.calculate_score()
+ self.assertEqual(player.score, 21)
+
+ ## Test the Player's calculate_score method with multiple Aces
+ def test_player_calculate_score_with_multiple_aces(self):
+ deck = Deck()
+ player = Player("Test Player")
+ player.hand.append(Card('A', 'Hearts', 11))
+ player.hand.append(Card('A', 'Diamonds', 11))
+ player.calculate_score()
+ self.assertEqual(player.score, 12)
+
+if __name__ == '__main__':
+ unittest.main()
+
+```
+## Running Command
+python tests/test_player.py
+## Running Output
+standard output: ;
+standard errors: ..F..
+======================================================================
+FAIL: test_player_calculate_score_with_multiple_aces (__main__.TestPlayer)
+----------------------------------------------------------------------
+Traceback (most recent call last):
+ File "tests/test_player.py", line 46, in test_player_calculate_score_with_multiple_aces
+ self.assertEqual(player.score, 12)
+AssertionError: 22 != 12
+
+----------------------------------------------------------------------
+Ran 5 tests in 0.007s
+
+FAILED (failures=1)
+;
+## instruction:
+The error is in the development code, specifically in the calculate_score method of the Player class. The method is not correctly handling the case where there are multiple Aces in the player's hand. The current implementation only subtracts 10 from the score once if the score is over 21 and there's an Ace in the hand. However, in the case of multiple Aces, it should subtract 10 for each Ace until the score is 21 or less.
+## File To Rewrite:
+player.py
+## Status:
+FAIL
+## Send To:
+Engineer
+---
+'''
@pytest.mark.asyncio
async def test_debug_error():
- code = "def add(a, b):\n return a - b"
- error = "AssertionError: Expected add(1, 1) to equal 2 but got 0"
debug_error = DebugError("debug_error")
- result = await debug_error.run(code, error)
+ file_name, rewritten_code = await debug_error.run(context=EXAMPLE_MSG_CONTENT)
- # mock_llm.ask.assert_called_once_with(prompt)
- assert len(result) > 0
+ assert "class Player" in rewritten_code # rewrite the same class
+    assert "while self.score > 21" in rewritten_code # a key logic to rewrite to (original one is "if self.score > 21")
diff --git a/tests/metagpt/actions/test_run_code.py b/tests/metagpt/actions/test_run_code.py
index af7d914b8..1e451cb14 100644
--- a/tests/metagpt/actions/test_run_code.py
+++ b/tests/metagpt/actions/test_run_code.py
@@ -11,28 +11,61 @@ from metagpt.actions.run_code import RunCode
@pytest.mark.asyncio
-async def test_run_code():
- code = """
-def add(a, b):
- return a + b
-result = add(1, 2)
-"""
- run_code = RunCode("run_code")
+async def test_run_text():
+ result, errs = await RunCode.run_text("result = 1 + 1")
+ assert result == 2
+ assert errs == ""
- result = await run_code.run(code)
-
- assert result == 3
+ result, errs = await RunCode.run_text("result = 1 / 0")
+ assert result == ""
+ assert "ZeroDivisionError" in errs
@pytest.mark.asyncio
-async def test_run_code_with_error():
- code = """
-def add(a, b):
- return a + b
-result = add(1, '2')
-"""
- run_code = RunCode("run_code")
+async def test_run_script():
+ # Successful command
+ out, err = await RunCode.run_script(".", command=["echo", "Hello World"])
+ assert out.strip() == "Hello World"
+ assert err == ""
- result = await run_code.run(code)
+ # Unsuccessful command
+ out, err = await RunCode.run_script(".", command=["python", "-c", "print(1/0)"])
+ assert "ZeroDivisionError" in err
- assert "TypeError: unsupported operand type(s) for +" in result
+
+@pytest.mark.asyncio
+async def test_run():
+ action = RunCode()
+ result = await action.run(mode="text", code="print('Hello, World')")
+ assert "PASS" in result
+
+ result = await action.run(
+ mode="script",
+ code="echo 'Hello World'",
+ code_file_name="",
+ test_code="",
+ test_file_name="",
+ command=["echo", "Hello World"],
+ working_directory=".",
+ additional_python_paths=[],
+ )
+ assert "PASS" in result
+
+
+@pytest.mark.asyncio
+async def test_run_failure():
+ action = RunCode()
+ result = await action.run(mode="text", code="result = 1 / 0")
+ assert "FAIL" in result
+
+ result = await action.run(
+ mode="script",
+ code='python -c "print(1/0)"',
+ code_file_name="",
+ test_code="",
+ test_file_name="",
+ command=["python", "-c", "print(1/0)"],
+ working_directory=".",
+ additional_python_paths=[],
+ )
+ assert "FAIL" in result
diff --git a/tests/metagpt/actions/test_write_code_review.py b/tests/metagpt/actions/test_write_code_review.py
index cee7eb941..21bc563ec 100644
--- a/tests/metagpt/actions/test_write_code_review.py
+++ b/tests/metagpt/actions/test_write_code_review.py
@@ -8,8 +8,6 @@
import pytest
from metagpt.actions.write_code_review import WriteCodeReview
-from metagpt.logs import logger
-from tests.metagpt.actions.mock import SEARCH_CODE_SAMPLE
@pytest.mark.asyncio
@@ -20,11 +18,7 @@ def add(a, b):
"""
# write_code_review = WriteCodeReview("write_code_review")
- code = await WriteCodeReview().run(
- context="编写一个从a加b的函数,返回a+b",
- code=code,
- filename="math.py"
- )
+ code = await WriteCodeReview().run(context="编写一个从a加b的函数,返回a+b", code=code, filename="math.py")
# 我们不能精确地预测生成的代码评审,但我们可以检查返回的是否为字符串
assert isinstance(code, str)
@@ -33,6 +27,7 @@ def add(a, b):
captured = capfd.readouterr()
print(f"输出内容: {captured.out}")
+
# @pytest.mark.asyncio
# async def test_write_code_review_directly():
# code = SEARCH_CODE_SAMPLE
diff --git a/tests/metagpt/actions/test_write_docstring.py b/tests/metagpt/actions/test_write_docstring.py
new file mode 100644
index 000000000..82d96e1a6
--- /dev/null
+++ b/tests/metagpt/actions/test_write_docstring.py
@@ -0,0 +1,32 @@
+import pytest
+
+from metagpt.actions.write_docstring import WriteDocstring
+
+code = '''
+def add_numbers(a: int, b: int):
+ return a + b
+
+
+class Person:
+ def __init__(self, name: str, age: int):
+ self.name = name
+ self.age = age
+
+ def greet(self):
+ return f"Hello, my name is {self.name} and I am {self.age} years old."
+'''
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ ("style", "part"),
+ [
+ ("google", "Args:"),
+ ("numpy", "Parameters"),
+ ("sphinx", ":param name:"),
+ ],
+ ids=["google", "numpy", "sphinx"]
+)
+async def test_write_docstring(style: str, part: str):
+ ret = await WriteDocstring().run(code, style=style)
+ assert part in ret
diff --git a/tests/metagpt/actions/test_write_test.py b/tests/metagpt/actions/test_write_test.py
index 7f382e6c2..87a22b139 100644
--- a/tests/metagpt/actions/test_write_test.py
+++ b/tests/metagpt/actions/test_write_test.py
@@ -8,19 +8,35 @@
import pytest
from metagpt.actions.write_test import WriteTest
+from metagpt.logs import logger
@pytest.mark.asyncio
async def test_write_test():
code = """
- def add(a, b):
- return a + b
+ import random
+ from typing import Tuple
+
+ class Food:
+ def __init__(self, position: Tuple[int, int]):
+ self.position = position
+
+ def generate(self, max_y: int, max_x: int):
+ self.position = (random.randint(1, max_y - 1), random.randint(1, max_x - 1))
"""
- write_test = WriteTest("write_test")
+ write_test = WriteTest()
- test_cases = await write_test.run(code)
+ test_code = await write_test.run(
+ code_to_test=code,
+ test_file_name="test_food.py",
+ source_file_path="/some/dummy/path/cli_snake_game/cli_snake_game/food.py",
+ workspace="/some/dummy/path/cli_snake_game"
+ )
+ logger.info(test_code)
# We cannot exactly predict the generated test cases, but we can check if it is a string and if it is not empty
- assert isinstance(test_cases, str)
- assert len(test_cases) > 0
+ assert isinstance(test_code, str)
+ assert "from cli_snake_game.food import Food" in test_code
+ assert "class TestFood(unittest.TestCase)" in test_code
+ assert "def test_generate" in test_code
diff --git a/tests/metagpt/roles/test_researcher.py b/tests/metagpt/roles/test_researcher.py
new file mode 100644
index 000000000..01b5dae3b
--- /dev/null
+++ b/tests/metagpt/roles/test_researcher.py
@@ -0,0 +1,32 @@
+from pathlib import Path
+from random import random
+from tempfile import TemporaryDirectory
+
+import pytest
+
+from metagpt.roles import researcher
+
+
+async def mock_llm_ask(self, prompt: str, system_msgs):
+ if "Please provide up to 2 necessary keywords" in prompt:
+ return '["dataiku", "datarobot"]'
+ elif "Provide up to 4 queries related to your research topic" in prompt:
+ return '["Dataiku machine learning platform", "DataRobot AI platform comparison", ' \
+ '"Dataiku vs DataRobot features", "Dataiku and DataRobot use cases"]'
+ elif "sort the remaining search results" in prompt:
+ return '[1,2]'
+ elif "Not relevant." in prompt:
+ return "Not relevant" if random() > 0.5 else prompt[-100:]
+ elif "provide a detailed research report" in prompt:
+ return f"# Research Report\n## Introduction\n{prompt}"
+ return ""
+
+
+@pytest.mark.asyncio
+async def test_researcher(mocker):
+ with TemporaryDirectory() as dirname:
+ topic = "dataiku vs. datarobot"
+ mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
+ researcher.RESEARCH_PATH = Path(dirname)
+ await researcher.Researcher().run(topic)
+ assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
diff --git a/tests/metagpt/roles/ui_role.py b/tests/metagpt/roles/ui_role.py
index 101be9c69..a45a89cde 100644
--- a/tests/metagpt/roles/ui_role.py
+++ b/tests/metagpt/roles/ui_role.py
@@ -2,22 +2,19 @@
# @Date : 2023/7/15 16:40
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
-import re
import os
-from importlib import import_module
+import re
from functools import wraps
+from importlib import import_module
-from metagpt.logs import logger
-from metagpt.actions import Action, ActionOutput
-from metagpt.roles import ProductManager, Role
-from metagpt.schema import Message
+from metagpt.actions import Action, ActionOutput, WritePRD
from metagpt.const import WORKSPACE_ROOT
-
-from metagpt.actions import WritePRD
-from metagpt.software_company import SoftwareCompany
+from metagpt.logs import logger
+from metagpt.roles import Role
+from metagpt.schema import Message
from metagpt.tools.sd_engine import SDEngine
-PROMPT_TEMPLATE = '''
+PROMPT_TEMPLATE = """
# Context
{context}
@@ -34,9 +31,9 @@ Attention: Use '##' to split sections, not '#', and '## ' SHOULD W
## CSS Styles (styles.css):Provide as Plain text,use standard css code
## Anything UNCLEAR:Provide as Plain text. Make clear here.
-'''
+"""
-FORMAT_EXAMPLE = '''
+FORMAT_EXAMPLE = """
## UI Design Description
```Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements commonly found in snake games ```
@@ -126,7 +123,7 @@ body {
## Anything UNCLEAR
There are no unclear points.
-'''
+"""
OUTPUT_MAPPING = {
"UI Design Description": (str, ...),
@@ -139,25 +136,25 @@ OUTPUT_MAPPING = {
def load_engine(func):
"""Decorator to load an engine by file name and engine name."""
-
+
@wraps(func)
def wrapper(*args, **kwargs):
file_name, engine_name = func(*args, **kwargs)
- engine_file = import_module(file_name, package='metagpt')
+ engine_file = import_module(file_name, package="metagpt")
ip_module_cls = getattr(engine_file, engine_name)
try:
engine = ip_module_cls()
except:
engine = None
-
+
return engine
-
+
return wrapper
def parse(func):
"""Decorator to parse information using regex pattern."""
-
+
@wraps(func)
def wrapper(*args, **kwargs):
context, pattern = func(*args, **kwargs)
@@ -168,30 +165,30 @@ def parse(func):
else:
text_info = context
logger.info("未找到匹配的内容")
-
+
return text_info
-
+
return wrapper
class UIDesign(Action):
"""Class representing the UI Design action."""
-
+
def __init__(self, name, context=None, llm=None):
super().__init__(name, context, llm) # 需要调用LLM进一步丰富UI设计的prompt
-
+
@parse
def parse_requirement(self, context: str):
"""Parse UI Design draft from the context using regex."""
pattern = r"## UI Design draft.*?\n(.*?)## Anything UNCLEAR"
return context, pattern
-
+
@parse
def parse_ui_elements(self, context: str):
"""Parse Selected Elements from the context using regex."""
pattern = r"## Selected Elements.*?\n(.*?)## HTML Layout"
return context, pattern
-
+
@parse
def parse_css_code(self, context: str):
pattern = r"```css.*?\n(.*?)## Anything UNCLEAR"
@@ -201,7 +198,7 @@ class UIDesign(Action):
def parse_html_code(self, context: str):
pattern = r"```html.*?\n(.*?)```"
return context, pattern
-
+
async def draw_icons(self, context, *args, **kwargs):
"""Draw icons using SDEngine."""
engine = SDEngine()
@@ -215,20 +212,20 @@ class UIDesign(Action):
prompts_batch.append(prompt)
await engine.run_t2i(prompts_batch)
logger.info("Finish icon design using StableDiffusion API")
-
+
async def _save(self, css_content, html_content):
- save_dir = WORKSPACE_ROOT / "resources" / 'codes'
+ save_dir = WORKSPACE_ROOT / "resources" / "codes"
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
# Save CSS and HTML content to files
- css_file_path = save_dir / f"ui_design.css"
- html_file_path = save_dir / f"ui_design.html"
-
- with open(css_file_path, 'w') as css_file:
+ css_file_path = save_dir / "ui_design.css"
+ html_file_path = save_dir / "ui_design.html"
+
+ with open(css_file_path, "w") as css_file:
css_file.write(css_content)
- with open(html_file_path, 'w') as html_file:
+ with open(html_file_path, "w") as html_file:
html_file.write(html_content)
-
+
async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput:
"""Run the UI Design action."""
# fixme: update prompt (根据需求细化prompt)
@@ -249,23 +246,27 @@ class UIDesign(Action):
class UI(Role):
"""Class representing the UI Role."""
-
- def __init__(self, name="Catherine", profile="UI Design",
- goal="Finish a workable and good User Interface design based on a product design",
- constraints="Give clear layout description and use standard icons to finish the design",
- skills=["SD"]):
+
+ def __init__(
+ self,
+ name="Catherine",
+ profile="UI Design",
+ goal="Finish a workable and good User Interface design based on a product design",
+ constraints="Give clear layout description and use standard icons to finish the design",
+ skills=["SD"],
+ ):
super().__init__(name, profile, goal, constraints)
self.load_skills(skills)
self._init_actions([UIDesign])
self._watch([WritePRD])
-
+
@load_engine
def load_sd_engine(self):
"""Load the SDEngine."""
file_name = ".tools.sd_engine"
engine_name = "SDEngine"
return file_name, engine_name
-
+
def load_skills(self, skills):
"""Load skills for the UI Role."""
# todo: 添加其他出图engine
@@ -273,4 +274,3 @@ class UI(Role):
if skill == "SD":
self.sd_engine = self.load_sd_engine()
logger.info(f"load skill engine {self.sd_engine}")
-
diff --git a/tests/metagpt/tools/test_search_engine.py b/tests/metagpt/tools/test_search_engine.py
index 2418c7b26..a7fe063a6 100644
--- a/tests/metagpt/tools/test_search_engine.py
+++ b/tests/metagpt/tools/test_search_engine.py
@@ -5,24 +5,44 @@
@Author : alexanderwu
@File : test_search_engine.py
"""
+from __future__ import annotations
import pytest
from metagpt.logs import logger
+from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
+class MockSearchEnine:  # Offline stand-in for a custom search backend; builds synthetic hits instead of querying anything.
+    async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
+        rets = [{"url": f"https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)]
+        return "\n".join(str(ret) for ret in rets) if as_string else rets  # str.join() rejects dicts, so stringify each hit
+
+
@pytest.mark.asyncio
-@pytest.mark.usefixtures("llm_api")
-async def test_search_engine(llm_api):
- search_engine = SearchEngine()
- poetries = [
- # ("北京美食", "北京"),
- ("屈臣氏", "屈臣氏")
- ]
- for i, j in poetries:
- rsp = await search_engine.run(i)
- # rsp = context.llm.ask_batch([prompt])
- logger.info(rsp)
- # assert any(j in k['body'] for k in rsp)
- assert len(rsp) > 0
+@pytest.mark.parametrize(
+    ("search_engine_type", "run_func", "max_results", "as_string"),
+    [
+        (SearchEngineType.SERPAPI_GOOGLE, None, 8, True),
+        (SearchEngineType.SERPAPI_GOOGLE, None, 4, False),
+        (SearchEngineType.DIRECT_GOOGLE, None, 8, True),
+        (SearchEngineType.DIRECT_GOOGLE, None, 6, False),
+        (SearchEngineType.SERPER_GOOGLE, None, 8, True),
+        (SearchEngineType.SERPER_GOOGLE, None, 6, False),
+        (SearchEngineType.DUCK_DUCK_GO, None, 8, True),
+        (SearchEngineType.DUCK_DUCK_GO, None, 6, False),
+        # CUSTOM_ENGINE rows pass the mock's run coroutine as run_func; the other rows pass None.
+        (SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
+        (SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
+    ],
+)
+async def test_search_engine(search_engine_type, run_func, max_results, as_string):
+    search_engine = SearchEngine(search_engine_type, run_func)
+    rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
+    logger.info(rsp)
+    if as_string:
+        assert isinstance(rsp, str)
+    else:
+        assert isinstance(rsp, list)
+        assert len(rsp) == max_results
diff --git a/tests/metagpt/tools/test_web_browser_engine.py b/tests/metagpt/tools/test_web_browser_engine.py
index 57335de9c..b08d0ca10 100644
--- a/tests/metagpt/tools/test_web_browser_engine.py
+++ b/tests/metagpt/tools/test_web_browser_engine.py
@@ -1,6 +1,6 @@
import pytest
-from metagpt.config import Config
-from metagpt.tools import web_browser_engine, WebBrowserEngineType
+
+from metagpt.tools import WebBrowserEngineType, web_browser_engine
@pytest.mark.asyncio
diff --git a/tests/metagpt/tools/test_web_browser_engine_playwright.py b/tests/metagpt/tools/test_web_browser_engine_playwright.py
index 908f92112..69e1339e7 100644
--- a/tests/metagpt/tools/test_web_browser_engine_playwright.py
+++ b/tests/metagpt/tools/test_web_browser_engine_playwright.py
@@ -1,4 +1,5 @@
import pytest
+
from metagpt.config import CONFIG
from metagpt.tools import web_browser_engine_playwright
@@ -20,6 +21,7 @@ async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy
CONFIG.global_proxy = proxy
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type, **kwagrs)
result = await browser.run(url)
+ result = result.inner_text
assert isinstance(result, str)
assert "Deepwisdom" in result
diff --git a/tests/metagpt/tools/test_web_browser_engine_selenium.py b/tests/metagpt/tools/test_web_browser_engine_selenium.py
index 5ea1e3083..ce322f7bd 100644
--- a/tests/metagpt/tools/test_web_browser_engine_selenium.py
+++ b/tests/metagpt/tools/test_web_browser_engine_selenium.py
@@ -1,4 +1,5 @@
import pytest
+
from metagpt.config import CONFIG
from metagpt.tools import web_browser_engine_selenium
@@ -20,6 +21,7 @@ async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd)
CONFIG.global_proxy = proxy
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type)
result = await browser.run(url)
+ result = result.inner_text
assert isinstance(result, str)
assert "Deepwisdom" in result
@@ -27,7 +29,7 @@ async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd)
results = await browser.run(url, *urls)
assert isinstance(results, list)
assert len(results) == len(urls) + 1
- assert all(("Deepwisdom" in i) for i in results)
+ assert all(("Deepwisdom" in i.inner_text) for i in results)
if use_proxy:
assert "Proxy:" in capfd.readouterr().out
finally:
diff --git a/tests/metagpt/utils/test_output_parser.py b/tests/metagpt/utils/test_output_parser.py
index 155297860..c56cff6fa 100644
--- a/tests/metagpt/utils/test_output_parser.py
+++ b/tests/metagpt/utils/test_output_parser.py
@@ -19,7 +19,7 @@ def test_parse_blocks():
def test_parse_code():
- test_text = "```python\nprint('Hello, world!')\n```"
+ test_text = "```python\nprint('Hello, world!')```"
expected_result = "print('Hello, world!')"
assert OutputParser.parse_code(test_text, 'python') == expected_result
@@ -27,6 +27,22 @@ def test_parse_code():
OutputParser.parse_code(test_text, 'java')
+def test_parse_python_code():
+ expected_result = "print('Hello, world!')"
+ assert OutputParser.parse_python_code("```python\nprint('Hello, world!')```") == expected_result
+ assert OutputParser.parse_python_code("```python\nprint('Hello, world!')") == expected_result
+ assert OutputParser.parse_python_code("print('Hello, world!')") == expected_result
+ assert OutputParser.parse_python_code("print('Hello, world!')```") == expected_result
+ assert OutputParser.parse_python_code("print('Hello, world!')```") == expected_result
+ expected_result = "print('```Hello, world!```')"
+ assert OutputParser.parse_python_code("```python\nprint('```Hello, world!```')```") == expected_result
+ assert OutputParser.parse_python_code("The code is: ```python\nprint('```Hello, world!```')```") == expected_result
+ assert OutputParser.parse_python_code("xxx.\n```python\nprint('```Hello, world!```')```\nxxx") == expected_result
+
+ with pytest.raises(ValueError):
+ OutputParser.parse_python_code("xxx =")
+
+
def test_parse_str():
test_text = "name = 'Alice'"
expected_result = 'Alice'
diff --git a/tests/metagpt/utils/test_parse_html.py b/tests/metagpt/utils/test_parse_html.py
new file mode 100644
index 000000000..42be416a6
--- /dev/null
+++ b/tests/metagpt/utils/test_parse_html.py
@@ -0,0 +1,68 @@
+from metagpt.utils import parse_html
+
+PAGE = """
+
+
+
+ Random HTML Example
+
+
+ This is a Heading
+ This is a paragraph with a link and some emphasized text.
+
+ - Item 1
+ - Item 2
+ - Item 3
+
+
+ - Numbered Item 1
+ - Numbered Item 2
+ - Numbered Item 3
+
+
+
+ Header 1
+ Header 2
+
+
+ Row 1, Cell 1
+ Row 1, Cell 2
+
+
+ Row 2, Cell 1
+ Row 2, Cell 2
+
+
+
+
+
+
+
+"""
+
+CONTENT = 'This is a HeadingThis is a paragraph witha linkand someemphasizedtext.Item 1Item 2Item 3Numbered Item 1Numbered '\
+'Item 2Numbered Item 3Header 1Header 2Row 1, Cell 1Row 1, Cell 2Row 2, Cell 1Row 2, Cell 2Name:Email:SubmitThis is a div '\
+'with a class "box".a link'
+
+
+def test_web_page():
+ page = parse_html.WebPage(inner_text=CONTENT, html=PAGE, url="http://example.com")
+ assert page.title == "Random HTML Example"
+ assert list(page.get_links()) == ["http://example.com/test", "https://metagpt.com"]
+
+
+def test_get_page_content():
+ ret = parse_html.get_html_content(PAGE, "http://example.com")
+ assert ret == CONTENT
diff --git a/tests/metagpt/utils/test_pycst.py b/tests/metagpt/utils/test_pycst.py
new file mode 100644
index 000000000..07352eac2
--- /dev/null
+++ b/tests/metagpt/utils/test_pycst.py
@@ -0,0 +1,136 @@
+from metagpt.utils import pycst
+
+code = '''
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from typing import overload
+
+@overload
+def add_numbers(a: int, b: int):
+ ...
+
+@overload
+def add_numbers(a: float, b: float):
+ ...
+
+def add_numbers(a: int, b: int):
+ return a + b
+
+
+class Person:
+ def __init__(self, name: str, age: int):
+ self.name = name
+ self.age = age
+
+ def greet(self):
+ return f"Hello, my name is {self.name} and I am {self.age} years old."
+'''
+
+documented_code = '''
+"""
+This is an example module containing a function and a class definition.
+"""
+
+
+def add_numbers(a: int, b: int):
+ """This function is used to add two numbers and return the result.
+
+ Parameters:
+ a: The first integer.
+ b: The second integer.
+
+ Returns:
+ int: The sum of the two numbers.
+ """
+ return a + b
+
+class Person:
+ """This class represents a person's information, including name and age.
+
+ Attributes:
+ name: The person's name.
+ age: The person's age.
+ """
+
+ def __init__(self, name: str, age: int):
+ """Creates a new instance of the Person class.
+
+ Parameters:
+ name: The person's name.
+ age: The person's age.
+ """
+ ...
+
+ def greet(self):
+ """
+ Returns a greeting message including the name and age.
+
+ Returns:
+ str: The greeting message.
+ """
+ ...
+'''
+
+
+merged_code = '''
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+This is an example module containing a function and a class definition.
+"""
+
+from typing import overload
+
+@overload
+def add_numbers(a: int, b: int):
+ ...
+
+@overload
+def add_numbers(a: float, b: float):
+ ...
+
+def add_numbers(a: int, b: int):
+ """This function is used to add two numbers and return the result.
+
+ Parameters:
+ a: The first integer.
+ b: The second integer.
+
+ Returns:
+ int: The sum of the two numbers.
+ """
+ return a + b
+
+
+class Person:
+ """This class represents a person's information, including name and age.
+
+ Attributes:
+ name: The person's name.
+ age: The person's age.
+ """
+ def __init__(self, name: str, age: int):
+ """Creates a new instance of the Person class.
+
+ Parameters:
+ name: The person's name.
+ age: The person's age.
+ """
+ self.name = name
+ self.age = age
+
+ def greet(self):
+ """
+ Returns a greeting message including the name and age.
+
+ Returns:
+ str: The greeting message.
+ """
+ return f"Hello, my name is {self.name} and I am {self.age} years old."
+'''
+
+
+def test_merge_docstring():
+ data = pycst.merge_docstring(code, documented_code)
+ print(data)
+ assert data == merged_code
diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py
index de8ccba4c..69f317f79 100644
--- a/tests/metagpt/utils/test_serialize.py
+++ b/tests/metagpt/utils/test_serialize.py
@@ -3,94 +3,64 @@
# @Desc : the unittest of serialize
from typing import List, Tuple
-import pytest
-from pydantic import create_model
-
-from metagpt.actions.action_output import ActionOutput
from metagpt.actions import WritePRD
+from metagpt.actions.action_output import ActionOutput
from metagpt.schema import Message
-from metagpt.utils.serialize import actionoutout_schema_to_mapping, serialize_message, deserialize_message
+from metagpt.utils.serialize import (
+ actionoutout_schema_to_mapping,
+ deserialize_message,
+ serialize_message,
+)
def test_actionoutout_schema_to_mapping():
- schema = {
- 'title': 'test',
- 'type': 'object',
- 'properties': {
- 'field': {
- 'title': 'field',
- 'type': 'string'
- }
- }
- }
+ schema = {"title": "test", "type": "object", "properties": {"field": {"title": "field", "type": "string"}}}
mapping = actionoutout_schema_to_mapping(schema)
- assert mapping['field'] == (str, ...)
+ assert mapping["field"] == (str, ...)
schema = {
- 'title': 'test',
- 'type': 'object',
- 'properties': {
- 'field': {
- 'title': 'field',
- 'type': 'array',
- 'items': {
- 'type': 'string'
- }
- }
- }
+ "title": "test",
+ "type": "object",
+ "properties": {"field": {"title": "field", "type": "array", "items": {"type": "string"}}},
}
mapping = actionoutout_schema_to_mapping(schema)
- assert mapping['field'] == (List[str], ...)
+ assert mapping["field"] == (List[str], ...)
schema = {
- 'title': 'test',
- 'type': 'object',
- 'properties': {
- 'field': {
- 'title': 'field',
- 'type': 'array',
- 'items': {
- 'type': 'array',
- 'minItems': 2,
- 'maxItems': 2,
- 'items': [
- {
- 'type': 'string'
- },
- {
- 'type': 'string'
- }
- ]
- }
+ "title": "test",
+ "type": "object",
+ "properties": {
+ "field": {
+ "title": "field",
+ "type": "array",
+ "items": {
+ "type": "array",
+ "minItems": 2,
+ "maxItems": 2,
+ "items": [{"type": "string"}, {"type": "string"}],
+ },
}
- }
+ },
}
mapping = actionoutout_schema_to_mapping(schema)
- assert mapping['field'] == (List[Tuple[str, str]], ...)
+ assert mapping["field"] == (List[Tuple[str, str]], ...)
assert True, True
def test_serialize_and_deserialize_message():
- out_mapping = {
- 'field1': (str, ...),
- 'field2': (List[str], ...)
- }
- out_data = {
- 'field1': 'field1 value',
- 'field2': ['field2 value1', 'field2 value2']
- }
- ic_obj = ActionOutput.create_model_class('prd', out_mapping)
+ out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
+ out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
+ ic_obj = ActionOutput.create_model_class("prd", out_mapping)
- message = Message(content='prd demand',
- instruct_content=ic_obj(**out_data),
- role='user',
- cause_by=WritePRD) # WritePRD as test action
+ message = Message(
+ content="prd demand", instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
+ ) # WritePRD as test action
message_ser = serialize_message(message)
new_message = deserialize_message(message_ser)
assert new_message.content == message.content
assert new_message.cause_by == message.cause_by
- assert new_message.instruct_content.field1 == out_data['field1']
+ assert new_message.instruct_content.field1 == out_data["field1"]
diff --git a/tests/metagpt/utils/test_text.py b/tests/metagpt/utils/test_text.py
new file mode 100644
index 000000000..0caf8abaa
--- /dev/null
+++ b/tests/metagpt/utils/test_text.py
@@ -0,0 +1,77 @@
+import pytest
+
+from metagpt.utils.text import (
+ decode_unicode_escape,
+ generate_prompt_chunk,
+ reduce_message_length,
+ split_paragraph,
+)
+
+
+def _msgs():
+ length = 20
+ while length:
+ yield "Hello," * 1000 * length
+ length -= 1
+
+
+def _paragraphs(n):
+ return " ".join("Hello World." for _ in range(n))
+
+
+@pytest.mark.parametrize(
+ "msgs, model_name, system_text, reserved, expected",
+ [
+ (_msgs(), "gpt-3.5-turbo", "System", 1500, 1),
+ (_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
+ (_msgs(), "gpt-3.5-turbo-16k", "Hello," * 1000, 3000, 5),
+ (_msgs(), "gpt-4", "System", 2000, 3),
+ (_msgs(), "gpt-4", "Hello," * 1000, 2000, 2),
+ (_msgs(), "gpt-4-32k", "System", 4000, 14),
+ (_msgs(), "gpt-4-32k", "Hello," * 2000, 4000, 12),
+ ]
+)
+def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
+ assert len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000 == expected
+
+
+@pytest.mark.parametrize(
+ "text, prompt_template, model_name, system_text, reserved, expected",
+ [
+ (" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1500, 2),
+ (" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
+ (" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
+ (" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
+ ]
+)
+def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
+ ret = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved))
+ assert len(ret) == expected
+
+
+@pytest.mark.parametrize(
+ "paragraph, sep, count, expected",
+ [
+ (_paragraphs(10), ".", 2, [_paragraphs(5), f" {_paragraphs(5)}"]),
+ (_paragraphs(10), ".", 3, [_paragraphs(4), f" {_paragraphs(3)}", f" {_paragraphs(3)}"]),
+ (f"{_paragraphs(5)}\n{_paragraphs(3)}", "\n.", 2, [f"{_paragraphs(5)}\n", _paragraphs(3)]),
+ ("......", ".", 2, ["...", "..."]),
+ ("......", ".", 3, ["..", "..", ".."]),
+ (".......", ".", 2, ["....", "..."]),
+ ]
+)
+def test_split_paragraph(paragraph, sep, count, expected):
+ ret = split_paragraph(paragraph, sep, count)
+ assert ret == expected
+
+
+@pytest.mark.parametrize(
+ "text, expected",
+ [
+ ("Hello\\nWorld", "Hello\nWorld"),
+ ("Hello\\tWorld", "Hello\tWorld"),
+ ("Hello\\u0020World", "Hello World"),
+ ]
+)
+def test_decode_unicode_escape(text, expected):
+ assert decode_unicode_escape(text) == expected