refactor: pre-commit run --all-files

This commit is contained in:
莘权 马 2023-11-22 16:26:48 +08:00
parent d8adba99d4
commit cda032948f
129 changed files with 812 additions and 831 deletions

View file

@ -6,14 +6,14 @@
@File : conftest.py
"""
import asyncio
import re
from unittest.mock import Mock
import pytest
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
import asyncio
import re
class Context:

View file

@ -311,12 +311,10 @@ TASKS = [
"添加数据API接受用户输入的文档库对文档库进行索引\n- 使用MeiliSearch连接并添加文档库",
"搜索API接收用户输入的关键词返回相关的搜索结果\n- 使用MeiliSearch连接并使用接口获得对应数据",
"多条件筛选API接收用户选择的筛选条件返回符合条件的搜索结果。\n- 使用MeiliSearch进行筛选并返回符合条件的搜索结果",
"智能推荐API根据用户的搜索历史记录和搜索行为推荐相关的搜索结果。"
"智能推荐API根据用户的搜索历史记录和搜索行为推荐相关的搜索结果。",
]
TASKS_2 = [
"完成main.py的功能"
]
TASKS_2 = ["完成main.py的功能"]
SEARCH_CODE_SAMPLE = """
import requests
@ -460,7 +458,7 @@ if __name__ == '__main__':
print('No results found.')
'''
MEILI_CODE = '''import meilisearch
MEILI_CODE = """import meilisearch
from typing import List
@ -496,9 +494,9 @@ if __name__ == '__main__':
# 添加文档库到搜索引擎
search_engine.add_documents(books_data_source, documents)
'''
"""
MEILI_ERROR = '''/usr/local/bin/python3.9 /Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py
MEILI_ERROR = """/usr/local/bin/python3.9 /Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py
Traceback (most recent call last):
File "/Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py", line 44, in <module>
search_engine.add_documents(books_data_source, documents)
@ -506,7 +504,7 @@ Traceback (most recent call last):
index = self.client.get_or_create_index(index_name)
AttributeError: 'Client' object has no attribute 'get_or_create_index'
Process finished with exit code 1'''
Process finished with exit code 1"""
MEILI_CODE_REFINED = """
"""

View file

@ -9,18 +9,21 @@ from typing import List, Tuple
from metagpt.actions import ActionOutput
t_dict = {"Required Python third-party packages": "\"\"\"\nflask==1.1.2\npygame==2.0.1\n\"\"\"\n",
"Required Other language third-party packages": "\"\"\"\nNo third-party packages required for other languages.\n\"\"\"\n",
"Full API spec": "\"\"\"\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n '200':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n '200':\n description: A JSON object of the updated game state\n\"\"\"\n",
"Logic Analysis": [
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
["static/js/script.js", "Handles user interactions and updates the game UI."],
["static/css/styles.css", "Defines the styles for the game UI."],
["templates/index.html", "The main page of the web application. Displays the game UI."]],
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?"}
t_dict = {
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
"Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n',
"Logic Analysis": [
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
["static/js/script.js", "Handles user interactions and updates the game UI."],
["static/css/styles.css", "Defines the styles for the game UI."],
["templates/index.html", "The main page of the web application. Displays the game UI."],
],
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
}
WRITE_TASKS_OUTPUT_MAPPING = {
"Required Python third-party packages": (str, ...),
@ -45,6 +48,6 @@ def test_create_model_class_with_mapping():
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
if __name__ == '__main__':
if __name__ == "__main__":
test_create_model_class()
test_create_model_class_with_mapping()

View file

@ -10,12 +10,7 @@ from metagpt.actions.azure_tts import AzureTTS
def test_azure_tts():
azure_tts = AzureTTS("azure_tts")
azure_tts.synthesize_speech(
"zh-CN",
"zh-CN-YunxiNeural",
"Boy",
"你好,我是卡卡",
"output.wav")
azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "你好,我是卡卡", "output.wav")
# 运行需要先配置 SUBSCRIPTION_KEY
# TODO: 这里如果要检验还要额外加上对应的asr才能确保前后生成是接近一致的但现在还没有

View file

@ -2,7 +2,6 @@ import pytest
from metagpt.actions.clone_function import CloneFunction, run_function_code
source_code = """
import pandas as pd
import ta
@ -31,14 +30,18 @@ def get_expected_res():
import ta
# 读取股票数据
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data = pd.read_csv("./tests/data/baba_stock.csv")
stock_data.head()
# 计算简单移动平均线
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
stock_data["SMA"] = ta.trend.sma_indicator(stock_data["Close"], window=6)
stock_data[["Date", "Close", "SMA"]].head()
# 计算布林带
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
stock_data["bb_upper"], stock_data["bb_middle"], stock_data["bb_lower"] = (
ta.volatility.bollinger_hband_indicator(stock_data["Close"], window=20),
ta.volatility.bollinger_mavg(stock_data["Close"], window=20),
ta.volatility.bollinger_lband_indicator(stock_data["Close"], window=20),
)
stock_data[["Date", "Close", "bb_upper", "bb_middle", "bb_lower"]].head()
return stock_data
@ -46,9 +49,9 @@ def get_expected_res():
async def test_clone_function():
clone = CloneFunction()
code = await clone.run(template_code, source_code)
assert 'def ' in code
stock_path = './tests/data/baba_stock.csv'
df, msg = run_function_code(code, 'stock_indicator', stock_path)
assert "def " in code
stock_path = "./tests/data/baba_stock.csv"
df, msg = run_function_code(code, "stock_indicator", stock_path)
assert not msg
expected_df = get_expected_res()
assert df.equals(expected_df)

View file

@ -144,12 +144,12 @@ Engineer
---
'''
@pytest.mark.asyncio
async def test_debug_error():
debug_error = DebugError("debug_error")
file_name, rewritten_code = await debug_error.run(context=EXAMPLE_MSG_CONTENT)
assert "class Player" in rewritten_code # rewrite the same class
assert "while self.score > 21" in rewritten_code # a key logic to rewrite to (original one is "if self.score > 12")
assert "class Player" in rewritten_code # rewrite the same class
assert "while self.score > 21" in rewritten_code # a key logic to rewrite to (original one is "if self.score > 12")

View file

@ -10,6 +10,7 @@ import pytest
from metagpt.actions.detail_mining import DetailMining
from metagpt.logs import logger
@pytest.mark.asyncio
async def test_detail_mining():
topic = "如何做一个生日蛋糕"
@ -17,7 +18,6 @@ async def test_detail_mining():
detail_mining = DetailMining("detail_mining")
rsp = await detail_mining.run(topic=topic, record=record)
logger.info(f"{rsp.content=}")
assert '##OUTPUT' in rsp.content
assert '蛋糕' in rsp.content
assert "##OUTPUT" in rsp.content
assert "蛋糕" in rsp.content

View file

@ -8,12 +8,11 @@
"""
import os
from typing import List
import pytest
from pathlib import Path
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
import pytest
from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
@pytest.mark.asyncio
@ -22,7 +21,7 @@ from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
[
"../../data/invoices/invoice-3.jpg",
"../../data/invoices/invoice-4.zip",
]
],
)
async def test_invoice_ocr(invoice_path: str):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
@ -35,18 +34,8 @@ async def test_invoice_ocr(invoice_path: str):
@pytest.mark.parametrize(
("invoice_path", "expected_result"),
[
(
"../../data/invoices/invoice-1.pdf",
[
{
"收款人": "小明",
"城市": "深圳市",
"总费用/元": "412.00",
"开票日期": "2023年02月03日"
}
]
),
]
("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
],
)
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
@ -59,9 +48,7 @@ async def test_generate_table(invoice_path: str, expected_result: list[dict]):
@pytest.mark.asyncio
@pytest.mark.parametrize(
("invoice_path", "query", "expected_result"),
[
("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")
]
[("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")],
)
async def test_reply_question(invoice_path: str, query: dict, expected_result: str):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
@ -69,4 +56,3 @@ async def test_reply_question(invoice_path: str, query: dict, expected_result: s
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
assert expected_result in result

View file

@ -4,7 +4,7 @@
#
from tests.metagpt.roles.ui_role import UIDesign
llm_resp= '''
llm_resp = """
# UI Design Description
```The user interface for the snake game will be designed in a way that is simple, clean, and intuitive. The main elements of the game such as the game grid, snake, food, score, and game over message will be clearly defined and easy to understand. The game grid will be centered on the screen with the score displayed at the top. The game controls will be intuitive and easy to use. The design will be modern and minimalist with a pleasing color scheme.```
@ -98,12 +98,13 @@ body {
left: 50%;
transform: translate(-50%, -50%);
font-size: 3em;
'''
"""
def test_ui_design_parse_css():
ui_design_work = UIDesign(name="UI design action")
css = '''
css = """
body {
display: flex;
flex-direction: column;
@ -160,14 +161,14 @@ def test_ui_design_parse_css():
left: 50%;
transform: translate(-50%, -50%);
font-size: 3em;
'''
assert ui_design_work.parse_css_code(context=llm_resp)==css
"""
assert ui_design_work.parse_css_code(context=llm_resp) == css
def test_ui_design_parse_html():
ui_design_work = UIDesign(name="UI design action")
html = '''
html = """
<!DOCTYPE html>
<html lang="en">
<head>
@ -184,8 +185,5 @@ def test_ui_design_parse_html():
<div class="game-over">Game Over</div>
</body>
</html>
'''
assert ui_design_work.parse_css_code(context=llm_resp)==html
"""
assert ui_design_work.parse_css_code(context=llm_resp) == html

View file

@ -22,13 +22,13 @@ async def test_write_code():
logger.info(code)
# 我们不能精确地预测生成的代码,但我们可以检查某些关键字
assert 'def add' in code
assert 'return' in code
assert "def add" in code
assert "return" in code
@pytest.mark.asyncio
async def test_write_code_directly():
prompt = WRITE_CODE_PROMPT_SAMPLE + '\n' + TASKS_2[0]
prompt = WRITE_CODE_PROMPT_SAMPLE + "\n" + TASKS_2[0]
llm = LLM()
rsp = await llm.aask(prompt)
logger.info(rsp)

View file

@ -2,7 +2,7 @@ import pytest
from metagpt.actions.write_docstring import WriteDocstring
code = '''
code = """
def add_numbers(a: int, b: int):
return a + b
@ -14,7 +14,7 @@ class Person:
def greet(self):
return f"Hello, my name is {self.name} and I am {self.age} years old."
'''
"""
@pytest.mark.asyncio
@ -25,7 +25,7 @@ class Person:
("numpy", "Parameters"),
("sphinx", ":param name:"),
],
ids=["google", "numpy", "sphinx"]
ids=["google", "numpy", "sphinx"],
)
async def test_write_docstring(style: str, part: str):
ret = await WriteDocstring().run(code, style=style)

View file

@ -9,14 +9,11 @@ from typing import Dict
import pytest
from metagpt.actions.write_tutorial import WriteDirectory, WriteContent
from metagpt.actions.write_tutorial import WriteContent, WriteDirectory
@pytest.mark.asyncio
@pytest.mark.parametrize(
("language", "topic"),
[("English", "Write a tutorial about Python")]
)
@pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")])
async def test_write_directory(language: str, topic: str):
ret = await WriteDirectory(language=language).run(topic=topic)
assert isinstance(ret, dict)
@ -30,7 +27,7 @@ async def test_write_directory(language: str, topic: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
("language", "topic", "directory"),
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})]
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})],
)
async def test_write_content(language: str, topic: str, directory: Dict):
ret = await WriteContent(language=language, directory=directory).run(topic=topic)

View file

@ -12,12 +12,12 @@ from metagpt.document_store.chromadb_store import ChromaStore
def test_chroma_store():
"""FIXMEchroma使用感觉很诡异一用Python就挂测试用例里也是"""
# 创建 ChromaStore 实例,使用 'sample_collection' 集合
document_store = ChromaStore('sample_collection_1')
document_store = ChromaStore("sample_collection_1")
# 使用 write 方法添加多个文档
document_store.write(["This is document1", "This is document2"],
[{"source": "google-docs"}, {"source": "notion"}],
["doc1", "doc2"])
document_store.write(
["This is document1", "This is document2"], [{"source": "google-docs"}, {"source": "notion"}], ["doc1", "doc2"]
)
# 使用 add 方法添加一个文档
document_store.add("This is document3", {"source": "notion"}, "doc3")

View file

@ -39,11 +39,11 @@ user: 没有了
@pytest.mark.asyncio
async def test_faiss_store_search():
store = FaissStore(DATA_PATH / 'qcs/qcs_4w.json')
store.add(['油皮洗面奶'])
store = FaissStore(DATA_PATH / "qcs/qcs_4w.json")
store.add(["油皮洗面奶"])
role = Sales(store=store)
queries = ['油皮洗面奶', '介绍下欧莱雅的']
queries = ["油皮洗面奶", "介绍下欧莱雅的"]
for query in queries:
rsp = await role.run(query)
assert rsp
@ -60,7 +60,10 @@ def customer_service():
async def test_faiss_store_customer_service():
allq = [
# ["我的餐怎么两小时都没到", "退货吧"],
["你好收不到取餐码,麻烦帮我开箱", "14750187158", ]
[
"你好收不到取餐码,麻烦帮我开箱",
"14750187158",
]
]
role = customer_service()
for queries in allq:
@ -71,4 +74,4 @@ async def test_faiss_store_customer_service():
def test_faiss_store_no_file():
with pytest.raises(FileNotFoundError):
FaissStore(DATA_PATH / 'wtf.json')
FaissStore(DATA_PATH / "wtf.json")

View file

@ -5,27 +5,33 @@
@Author : unkn-wn (Leon Yee)
@File : test_lancedb_store.py
"""
from metagpt.document_store.lancedb_store import LanceStore
import pytest
import random
import pytest
from metagpt.document_store.lancedb_store import LanceStore
@pytest
def test_lance_store():
# This simply establishes the connection to the database, so we can drop the table if it exists
store = LanceStore('test')
store = LanceStore("test")
store.drop('test')
store.drop("test")
store.write(data=[[random.random() for _ in range(100)] for _ in range(2)],
metadatas=[{"source": "google-docs"}, {"source": "notion"}],
ids=["doc1", "doc2"])
store.write(
data=[[random.random() for _ in range(100)] for _ in range(2)],
metadatas=[{"source": "google-docs"}, {"source": "notion"}],
ids=["doc1", "doc2"],
)
store.add(data=[random.random() for _ in range(100)], metadata={"source": "notion"}, _id="doc3")
result = store.search([random.random() for _ in range(100)], n_results=3)
assert(len(result) == 3)
assert len(result) == 3
store.delete("doc2")
result = store.search([random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric='cosine')
assert(len(result) == 1)
result = store.search(
[random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric="cosine"
)
assert len(result) == 1

View file

@ -12,7 +12,7 @@ import numpy as np
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
from metagpt.logs import logger
book_columns = {'idx': int, 'name': str, 'desc': str, 'emb': np.ndarray, 'price': float}
book_columns = {"idx": int, "name": str, "desc": str, "emb": np.ndarray, "price": float}
book_data = [
[i for i in range(10)],
[f"book-{i}" for i in range(10)],
@ -25,12 +25,12 @@ book_data = [
def test_milvus_store():
milvus_connection = MilvusConnection(alias="default", host="192.168.50.161", port="30530")
milvus_store = MilvusStore(milvus_connection)
milvus_store.drop('Book')
milvus_store.create_collection('Book', book_columns)
milvus_store.drop("Book")
milvus_store.create_collection("Book", book_columns)
milvus_store.add(book_data)
milvus_store.build_index('emb')
milvus_store.build_index("emb")
milvus_store.load_collection()
results = milvus_store.search([[1.0, 1.0]], field='emb')
results = milvus_store.search([[1.0, 1.0]], field="emb")
logger.info(results)
assert results

View file

@ -24,9 +24,7 @@ random.seed(seed_value)
vectors = [[random.random() for _ in range(2)] for _ in range(10)]
points = [
PointStruct(
id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10}
)
PointStruct(id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10})
for idx, vector in enumerate(vectors)
]
@ -57,9 +55,7 @@ def test_milvus_store():
results = qdrant_store.search(
"Book",
query=[1.0, 1.0],
query_filter=Filter(
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
),
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
)
assert results[0]["id"] == 8
assert results[0]["score"] == 0.9100373450784073
@ -68,9 +64,7 @@ def test_milvus_store():
results = qdrant_store.search(
"Book",
query=[1.0, 1.0],
query_filter=Filter(
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
),
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
return_vector=True,
)
assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915]

View file

@ -30,7 +30,7 @@ def test_skill_manager():
rsp = manager.retrieve_skill("写测试用例")
logger.info(rsp)
assert rsp[0] == 'WriteTest'
assert rsp[0] == "WriteTest"
rsp = manager.retrieve_skill_scored("写PRD")
logger.info(rsp)

View file

@ -2,11 +2,11 @@
# -*- coding: utf-8 -*-
# @Desc : unittest of `metagpt/memory/longterm_memory.py`
from metagpt.config import CONFIG
from metagpt.schema import Message
from metagpt.actions import BossRequirement
from metagpt.roles.role import RoleContext
from metagpt.config import CONFIG
from metagpt.memory import LongTermMemory
from metagpt.roles.role import RoleContext
from metagpt.schema import Message
def test_ltm_search():
@ -14,25 +14,25 @@ def test_ltm_search():
openai_api_key = CONFIG.openai_api_key
assert len(openai_api_key) > 20
role_id = 'UTUserLtm(Product Manager)'
role_id = "UTUserLtm(Product Manager)"
rc = RoleContext(watch=[BossRequirement])
ltm = LongTermMemory()
ltm.recover_memory(role_id, rc)
idea = 'Write a cli snake game'
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
idea = "Write a cli snake game"
message = Message(role="BOSS", content=idea, cause_by=BossRequirement)
news = ltm.find_news([message])
assert len(news) == 1
ltm.add(message)
sim_idea = 'Write a game of cli snake'
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
sim_idea = "Write a game of cli snake"
sim_message = Message(role="BOSS", content=sim_idea, cause_by=BossRequirement)
news = ltm.find_news([sim_message])
assert len(news) == 0
ltm.add(sim_message)
new_idea = 'Write a 2048 web game'
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
new_idea = "Write a 2048 web game"
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
news = ltm.find_news([new_message])
assert len(news) == 1
ltm.add(new_message)
@ -47,8 +47,8 @@ def test_ltm_search():
news = ltm_new.find_news([sim_message])
assert len(news) == 0
new_idea = 'Write a Battle City'
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
new_idea = "Write a Battle City"
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
news = ltm_new.find_news([new_message])
assert len(news) == 1

View file

@ -4,17 +4,16 @@
from typing import List
from metagpt.actions import BossRequirement, WritePRD
from metagpt.actions.action_output import ActionOutput
from metagpt.memory.memory_storage import MemoryStorage
from metagpt.schema import Message
from metagpt.actions import BossRequirement
from metagpt.actions import WritePRD
from metagpt.actions.action_output import ActionOutput
def test_idea_message():
idea = 'Write a cli snake game'
role_id = 'UTUser1(Product Manager)'
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
idea = "Write a cli snake game"
role_id = "UTUser1(Product Manager)"
message = Message(role="BOSS", content=idea, cause_by=BossRequirement)
memory_storage: MemoryStorage = MemoryStorage()
messages = memory_storage.recover_memory(role_id)
@ -23,13 +22,13 @@ def test_idea_message():
memory_storage.add(message)
assert memory_storage.is_initialized is True
sim_idea = 'Write a game of cli snake'
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
sim_idea = "Write a game of cli snake"
sim_message = Message(role="BOSS", content=sim_idea, cause_by=BossRequirement)
new_messages = memory_storage.search(sim_message)
assert len(new_messages) == 0 # similar, return []
assert len(new_messages) == 0 # similar, return []
new_idea = 'Write a 2048 web game'
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
new_idea = "Write a 2048 web game"
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
new_messages = memory_storage.search(new_message)
assert new_messages[0].content == message.content
@ -38,22 +37,15 @@ def test_idea_message():
def test_actionout_message():
out_mapping = {
'field1': (str, ...),
'field2': (List[str], ...)
}
out_data = {
'field1': 'field1 value',
'field2': ['field2 value1', 'field2 value2']
}
ic_obj = ActionOutput.create_model_class('prd', out_mapping)
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
ic_obj = ActionOutput.create_model_class("prd", out_mapping)
role_id = 'UTUser2(Architect)'
content = 'The boss has requested the creation of a command-line interface (CLI) snake game'
message = Message(content=content,
instruct_content=ic_obj(**out_data),
role='user',
cause_by=WritePRD) # WritePRD as test action
role_id = "UTUser2(Architect)"
content = "The boss has requested the creation of a command-line interface (CLI) snake game"
message = Message(
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
) # WritePRD as test action
memory_storage: MemoryStorage = MemoryStorage()
messages = memory_storage.recover_memory(role_id)
@ -62,19 +54,13 @@ def test_actionout_message():
memory_storage.add(message)
assert memory_storage.is_initialized is True
sim_conent = 'The request is command-line interface (CLI) snake game'
sim_message = Message(content=sim_conent,
instruct_content=ic_obj(**out_data),
role='user',
cause_by=WritePRD)
sim_conent = "The request is command-line interface (CLI) snake game"
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search(sim_message)
assert len(new_messages) == 0 # similar, return []
assert len(new_messages) == 0 # similar, return []
new_conent = 'Incorporate basic features of a snake game such as scoring and increasing difficulty'
new_message = Message(content=new_conent,
instruct_content=ic_obj(**out_data),
role='user',
cause_by=WritePRD)
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search(new_message)
assert new_messages[0].content == message.content

View file

@ -10,6 +10,6 @@ from metagpt.schema import Message
def test_message():
message = Message(role='user', content='wtf')
assert 'role' in message.to_dict()
assert 'user' in str(message)
message = Message(role="user", content="wtf")
assert "role" in message.to_dict()
assert "user" in str(message)

View file

@ -6,6 +6,6 @@ def test_message():
llm = SparkAPI()
logger.info(llm.ask('只回答"收到了"这三个字。'))
result = llm.ask('写一篇五百字的日记')
result = llm.ask("写一篇五百字的日记")
logger.info(result)
assert len(result) > 100

View file

@ -71,7 +71,7 @@ PRD = '''## 原始需求
```
'''
SYSTEM_DESIGN = '''## Python package name
SYSTEM_DESIGN = """## Python package name
```python
"smart_search_engine"
```
@ -149,10 +149,10 @@ sequenceDiagram
S-->>SE: return summary
SE-->>M: return summary
```
'''
"""
TASKS = '''## Logic Analysis
TASKS = """## Logic Analysis
在这个项目中所有的模块都依赖于SearchEngine这是主入口其他的模块IndexRanking和Summary都通过它交互另外"Index"类又依赖于"KnowledgeBase"因为它需要从知识库中获取数据
@ -181,7 +181,7 @@ task_list = [
]
```
这个任务列表首先定义了最基础的模块然后是依赖这些模块的模块最后是辅助模块可以根据团队的能力和资源同时开发多个任务只要满足依赖关系例如在开发"search.py"之前可以同时开发"knowledge_base.py""index.py""ranking.py""summary.py"
'''
"""
TASKS_TOMATO_CLOCK = '''## Required Python third-party packages: Provided in requirements.txt format
@ -224,30 +224,30 @@ task_list = [
TASK = """smart_search_engine/knowledge_base.py"""
STRS_FOR_PARSING = [
"""
"""
## 1
```python
a
```
""",
"""
"""
##2
```python
"a"
```
""",
"""
"""
## 3
```python
a = "a"
```
""",
"""
"""
## 4
```python
a = 'a'
```
"""
""",
]

View file

@ -35,13 +35,13 @@ def test_parse_str():
for idx, i in enumerate(STRS_FOR_PARSING):
text = CodeParser.parse_str(f"{idx+1}", i)
# logger.info(text)
assert text == 'a'
assert text == "a"
def test_parse_blocks():
tasks = CodeParser.parse_blocks(TASKS)
logger.info(tasks.keys())
assert 'Task list' in tasks.keys()
assert "Task list" in tasks.keys()
target_list = [

View file

@ -9,8 +9,8 @@
from pathlib import Path
import pytest
import pandas as pd
import pytest
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
from metagpt.schema import Message
@ -24,82 +24,39 @@ from metagpt.schema import Message
"Invoicing date",
Path("../../data/invoices/invoice-1.pdf"),
Path("../../../data/invoice_table/invoice-1.xlsx"),
[
{
"收款人": "小明",
"城市": "深圳市",
"总费用/元": 412.00,
"开票日期": "2023年02月03日"
}
]
[{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}],
),
(
"Invoicing date",
Path("../../data/invoices/invoice-2.png"),
Path("../../../data/invoice_table/invoice-2.xlsx"),
[
{
"收款人": "铁头",
"城市": "广州市",
"总费用/元": 898.00,
"开票日期": "2023年03月17日"
}
]
[{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}],
),
(
"Invoicing date",
Path("../../data/invoices/invoice-3.jpg"),
Path("../../../data/invoice_table/invoice-3.xlsx"),
[
{
"收款人": "夏天",
"城市": "福州市",
"总费用/元": 2462.00,
"开票日期": "2023年08月26日"
}
]
[{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}],
),
(
"Invoicing date",
Path("../../data/invoices/invoice-4.zip"),
Path("../../../data/invoice_table/invoice-4.xlsx"),
[
{
"收款人": "小明",
"城市": "深圳市",
"总费用/元": 412.00,
"开票日期": "2023年02月03日"
},
{
"收款人": "铁头",
"城市": "广州市",
"总费用/元": 898.00,
"开票日期": "2023年03月17日"
},
{
"收款人": "夏天",
"城市": "福州市",
"总费用/元": 2462.00,
"开票日期": "2023年08月26日"
}
]
{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
],
),
]
],
)
async def test_invoice_ocr_assistant(
query: str,
invoice_path: Path,
invoice_table_path: Path,
expected_result: list[dict]
query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict]
):
invoice_path = Path.cwd() / invoice_path
role = InvoiceOCRAssistant()
await role.run(Message(
content=query,
instruct_content={"file_path": invoice_path}
))
await role.run(Message(content=query, instruct_content={"file_path": invoice_path}))
invoice_table_path = Path.cwd() / invoice_table_path
df = pd.read_excel(invoice_table_path)
dict_result = df.to_dict(orient='records')
dict_result = df.to_dict(orient="records")
assert dict_result == expected_result

View file

@ -11,10 +11,12 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
if "Please provide up to 2 necessary keywords" in prompt:
return '["dataiku", "datarobot"]'
elif "Provide up to 4 queries related to your research topic" in prompt:
return '["Dataiku machine learning platform", "DataRobot AI platform comparison", ' \
return (
'["Dataiku machine learning platform", "DataRobot AI platform comparison", '
'"Dataiku vs DataRobot features", "Dataiku and DataRobot use cases"]'
)
elif "sort the remaining search results" in prompt:
return '[1,2]'
return "[1,2]"
elif "Not relevant." in prompt:
return "Not relevant" if random() > 0.5 else prompt[-100:]
elif "provide a detailed research report" in prompt:

View file

@ -12,10 +12,7 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant
@pytest.mark.asyncio
@pytest.mark.parametrize(
("language", "topic"),
[("Chinese", "Write a tutorial about Python")]
)
@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about Python")])
async def test_tutorial_assistant(language: str, topic: str):
topic = "Write a tutorial about MySQL"
role = TutorialAssistant(language=language)
@ -24,4 +21,4 @@ async def test_tutorial_assistant(language: str, topic: str):
title = filename.split("/")[-1].split(".")[0]
async with aiofiles.open(filename, mode="r") as reader:
content = await reader.read()
assert content.startswith(f"# {title}")
assert content.startswith(f"# {title}")

View file

@ -2,9 +2,8 @@
# @Date : 2023/7/22 02:40
# @Author : stellahong (stellahong@fuzhi.ai)
#
from metagpt.software_company import SoftwareCompany
from metagpt.roles import ProductManager
from metagpt.software_company import SoftwareCompany
from tests.metagpt.roles.ui_role import UI

View file

@ -14,7 +14,7 @@ from metagpt.logs import logger
@pytest.mark.usefixtures("llm_api")
class TestGPT:
def test_llm_api_ask(self, llm_api):
answer = llm_api.ask('hello chatgpt')
answer = llm_api.ask("hello chatgpt")
assert len(answer) > 0
# def test_gptapi_ask_batch(self, llm_api):
@ -22,22 +22,22 @@ class TestGPT:
# assert len(answer) > 0
def test_llm_api_ask_code(self, llm_api):
answer = llm_api.ask_code(['请扮演一个Google Python专家工程师如果理解回复明白', '写一个hello world'])
answer = llm_api.ask_code(["请扮演一个Google Python专家工程师如果理解回复明白", "写一个hello world"])
assert len(answer) > 0
@pytest.mark.asyncio
async def test_llm_api_aask(self, llm_api):
answer = await llm_api.aask('hello chatgpt')
answer = await llm_api.aask("hello chatgpt")
assert len(answer) > 0
@pytest.mark.asyncio
async def test_llm_api_aask_code(self, llm_api):
answer = await llm_api.aask_code(['请扮演一个Google Python专家工程师如果理解回复明白', '写一个hello world'])
answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师如果理解回复明白", "写一个hello world"])
assert len(answer) > 0
@pytest.mark.asyncio
async def test_llm_api_costs(self, llm_api):
await llm_api.aask('hello chatgpt')
await llm_api.aask("hello chatgpt")
costs = llm_api.get_costs()
logger.info(costs)
assert costs.total_cost > 0

View file

@ -18,17 +18,17 @@ def llm():
@pytest.mark.asyncio
async def test_llm_aask(llm):
assert len(await llm.aask('hello world')) > 0
assert len(await llm.aask("hello world")) > 0
@pytest.mark.asyncio
async def test_llm_aask_batch(llm):
assert len(await llm.aask_batch(['hi', 'write python hello world.'])) > 0
assert len(await llm.aask_batch(["hi", "write python hello world."])) > 0
@pytest.mark.asyncio
async def test_llm_acompletion(llm):
hello_msg = [{'role': 'user', 'content': 'hello'}]
hello_msg = [{"role": "user", "content": "hello"}]
assert len(await llm.acompletion(hello_msg)) > 0
assert len(await llm.acompletion_batch([hello_msg])) > 0
assert len(await llm.acompletion_batch_text([hello_msg])) > 0

View file

@ -11,26 +11,26 @@ from metagpt.schema import AIMessage, Message, RawMessage, SystemMessage, UserMe
def test_message():
msg = Message(role='User', content='WTF')
assert msg.to_dict()['role'] == 'User'
assert 'User' in str(msg)
msg = Message(role="User", content="WTF")
assert msg.to_dict()["role"] == "User"
assert "User" in str(msg)
def test_all_messages():
test_content = 'test_message'
test_content = "test_message"
msgs = [
UserMessage(test_content),
SystemMessage(test_content),
AIMessage(test_content),
Message(test_content, role='QA')
Message(test_content, role="QA"),
]
for msg in msgs:
assert msg.content == test_content
def test_raw_message():
msg = RawMessage(role='user', content='raw')
assert msg['role'] == 'user'
assert msg['content'] == 'raw'
msg = RawMessage(role="user", content="raw")
assert msg["role"] == "user"
assert msg["content"] == "raw"
with pytest.raises(KeyError):
assert msg['1'] == 1, "KeyError: '1'"
assert msg["1"] == 1, "KeyError: '1'"

View file

@ -9,6 +9,6 @@ from metagpt.roles import Role
def test_role_desc():
i = Role(profile='Sales', desc='Best Seller')
assert i.profile == 'Sales'
assert i._setting.desc == 'Best Seller'
i = Role(profile="Sales", desc="Best Seller")
assert i.profile == "Sales"
assert i._setting.desc == "Best Seller"

View file

@ -9,13 +9,13 @@ from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage
def test_messages():
test_content = 'test_message'
test_content = "test_message"
msgs = [
UserMessage(test_content),
SystemMessage(test_content),
AIMessage(test_content),
Message(test_content, role='QA')
Message(test_content, role="QA"),
]
text = str(msgs)
roles = ['user', 'system', 'assistant', 'QA']
roles = ["user", "system", "assistant", "QA"]
assert all([i in text for i in roles])

View file

@ -1,23 +1,22 @@
import pytest
import pandas as pd
from pathlib import Path
from tests.data import sales_desc, store_desc
from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
import pandas as pd
import pytest
from metagpt.actions import Action
from metagpt.logs import logger
from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
logger.add('./tests/data/test_ci.log')
logger.add("./tests/data/test_ci.log")
stock = "./tests/data/baba_stock.csv"
# TODO: 需要一种表格数据格式能够支持schame管理的标注字段类型和字段含义。
class CreateStockIndicators(Action):
@OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py")
async def run(self, stock_path: str, indicators=['Simple Moving Average', 'BollingerBands']) -> pd.DataFrame:
async def run(self, stock_path: str, indicators=["Simple Moving Average", "BollingerBands"]) -> pd.DataFrame:
"""对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包;
指标生成对应的三列: SMA, BB_upper, BB_lower
指标生成对应的三列: SMA, BB_upper, BB_lower
"""
...
@ -25,18 +24,20 @@ class CreateStockIndicators(Action):
@pytest.mark.asyncio
async def test_actions():
# 计算指标
indicators = ['Simple Moving Average', 'BollingerBands']
indicators = ["Simple Moving Average", "BollingerBands"]
stocker = CreateStockIndicators()
df, msg = await stocker.run(stock, indicators=indicators)
assert isinstance(df, pd.DataFrame)
assert 'Close' in df.columns
assert 'Date' in df.columns
assert "Close" in df.columns
assert "Date" in df.columns
# 将df保存为文件将文件路径传入到下一个action
df_path = './tests/data/stock_indicators.csv'
df_path = "./tests/data/stock_indicators.csv"
df.to_csv(df_path)
assert Path(df_path).is_file()
# 可视化指标结果
figure_path = './tests/data/figure_ci.png'
figure_path = "./tests/data/figure_ci.png"
ci_ploter = OpenCodeInterpreter()
ci_ploter.chat(f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper布林带上界, BB_lower布林带下界进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算把Date列转换为日期类型。要求图片优美BB_upper, BB_lower之间使用合适的颜色填充。")
ci_ploter.chat(
f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper布林带上界, BB_lower布林带下界进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算把Date列转换为日期类型。要求图片优美BB_upper, BB_lower之间使用合适的颜色填充。"
)
assert Path(figure_path).is_file()

View file

@ -20,8 +20,9 @@ from metagpt.tools.prompt_writer import (
@pytest.mark.usefixtures("llm_api")
def test_gpt_prompt_generator(llm_api):
generator = GPTPromptGenerator()
example = "商品名称:WonderLab 新肌果味代餐奶昔 小胖瓶 胶原蛋白升级版 饱腹代餐粉6瓶 75g/瓶(6瓶/盒) 店铺名称:金力宁食品专营店 " \
"品牌:WonderLab 保质期:1年 产地:中国 净含量:450g"
example = (
"商品名称:WonderLab 新肌果味代餐奶昔 小胖瓶 胶原蛋白升级版 饱腹代餐粉6瓶 75g/瓶(6瓶/盒) 店铺名称:金力宁食品专营店 " "品牌:WonderLab 保质期:1年 产地:中国 净含量:450g"
)
results = llm_api.ask_batch(generator.gen(example))
logger.info(results)
@ -46,7 +47,7 @@ def test_enron_template(llm_api):
results = template.gen(subj)
assert len(results) > 0
assert any("Write an email with the subject \"Meeting Agenda\"." in r for r in results)
assert any('Write an email with the subject "Meeting Agenda".' in r for r in results)
def test_beagec_template():
@ -54,5 +55,6 @@ def test_beagec_template():
results = template.gen()
assert len(results) > 0
assert any("Edit and revise this document to improve its grammar, vocabulary, spelling, and style."
in r for r in results)
assert any(
"Edit and revise this document to improve its grammar, vocabulary, spelling, and style." in r for r in results
)

View file

@ -4,7 +4,7 @@
#
import os
from metagpt.tools.sd_engine import SDEngine, WORKSPACE_ROOT
from metagpt.tools.sd_engine import WORKSPACE_ROOT, SDEngine
def test_sd_engine_init():

View file

@ -16,7 +16,9 @@ from metagpt.tools.search_engine import SearchEngine
class MockSearchEnine:
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)]
rets = [
{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)
]
return "\n".join(rets) if as_string else rets
@ -34,10 +36,14 @@ class MockSearchEnine:
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
],
)
async def test_search_engine(search_engine_typpe, run_func, max_results, as_string, ):
async def test_search_engine(
search_engine_typpe,
run_func,
max_results,
as_string,
):
search_engine = SearchEngine(search_engine_typpe, run_func)
rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
logger.info(rsp)

View file

@ -13,7 +13,7 @@ import pytest
from metagpt.logs import logger
from metagpt.tools.search_engine_meilisearch import DataSource, MeilisearchEngine
MASTER_KEY = '116Qavl2qpCYNEJNv5-e0RC9kncev1nr1gt7ybEGVLk'
MASTER_KEY = "116Qavl2qpCYNEJNv5-e0RC9kncev1nr1gt7ybEGVLk"
@pytest.fixture()
@ -29,7 +29,7 @@ def test_meilisearch(search_engine_server):
search_engine = MeilisearchEngine(url="http://localhost:7700", token=MASTER_KEY)
# 假设有一个名为"books"的数据源,包含要添加的文档库
books_data_source = DataSource(name='books', url='https://example.com/books')
books_data_source = DataSource(name="books", url="https://example.com/books")
# 假设有一个名为"documents"的文档库,包含要添加的文档
documents = [
@ -43,4 +43,4 @@ def test_meilisearch(search_engine_server):
# 添加文档库到搜索引擎
search_engine.add_documents(books_data_source, documents)
logger.info(search_engine.search('Book 1'))
logger.info(search_engine.search("Book 1"))

File diff suppressed because one or more lines are too long

View file

@ -16,7 +16,7 @@ from metagpt.tools.translator import Translator
def test_translate(llm_api):
poetries = [
("Let life be beautiful like summer flowers", ""),
("The ancient Chinese poetries are all songs.", "中国")
("The ancient Chinese poetries are all songs.", "中国"),
]
for i, j in poetries:
prompt = Translator.translate_prompt(i)

View file

@ -16,8 +16,12 @@ class TestUTWriter:
tags = ["测试"] # "智能合同导入", "律师审查", "ai合同审查", "草拟合同&律师在线审查", "合同审批", "履约管理", "签约公司"]
# 这里在文件中手动加入了两个测试标签的API
utg = UTGenerator(swagger_file=swagger_file, ut_py_path=UT_PY_PATH, questions_path=API_QUESTIONS_PATH,
template_prefix=YFT_PROMPT_PREFIX)
utg = UTGenerator(
swagger_file=swagger_file,
ut_py_path=UT_PY_PATH,
questions_path=API_QUESTIONS_PATH,
template_prefix=YFT_PROMPT_PREFIX,
)
ret = utg.generate_ut(include_tags=tags)
# 后续加入对文件生成内容与数量的检验
assert ret

View file

@ -131,10 +131,10 @@ class TestCodeParser:
def test_parse_file_list(self, parser, text):
result = parser.parse_file_list("Task list", text)
print(result)
assert result == ['task1', 'task2']
assert result == ["task1", "task2"]
if __name__ == '__main__':
if __name__ == "__main__":
t = TestCodeParser()
t.test_parse_file_list(CodeParser(), t_text)
# TestCodeParser.test_parse_file_list()

View file

@ -16,12 +16,12 @@ from metagpt.const import get_project_root
class TestGetProjectRoot:
def change_etc_dir(self):
# current_directory = Path.cwd()
abs_root = '/etc'
abs_root = "/etc"
os.chdir(abs_root)
def test_get_project_root(self):
project_root = get_project_root()
assert project_root.name == 'metagpt'
assert project_root.name == "metagpt"
def test_get_root_exception(self):
with pytest.raises(Exception) as exc_info:

View file

@ -20,12 +20,12 @@ def test_config_class_is_singleton():
def test_config_class_get_key_exception():
with pytest.raises(Exception) as exc_info:
config = Config()
config.get('wtf')
config.get("wtf")
assert str(exc_info.value) == "Key 'wtf' not found in environment variables or in the YAML file"
def test_config_yaml_file_not_exists():
config = Config('wtf.yaml')
config = Config("wtf.yaml")
with pytest.raises(Exception) as exc_info:
config.get('OPENAI_BASE_URL')
config.get("OPENAI_BASE_URL")
assert str(exc_info.value) == "Key 'OPENAI_BASE_URL' not found in environment variables or in the YAML file"

View file

@ -10,12 +10,12 @@ from metagpt.provider.openai_api import OpenAIGPTAPI
async def try_hello(api):
batch = [[{'role': 'user', 'content': 'hello'}]]
batch = [[{"role": "user", "content": "hello"}]]
results = await api.acompletion_batch_text(batch)
return results
async def aask_batch(api: OpenAIGPTAPI):
results = await api.aask_batch(['hi', 'write python hello world.'])
results = await api.aask_batch(["hi", "write python hello world."])
logger.info(results)
return results

View file

@ -15,12 +15,11 @@ from metagpt.utils.file import File
@pytest.mark.asyncio
@pytest.mark.parametrize(
("root_path", "filename", "content"),
[(Path("/code/MetaGPT/data/tutorial_docx/2023-09-07_17-05-20"), "test.md", "Hello World!")]
[(Path("/code/MetaGPT/data/tutorial_docx/2023-09-07_17-05-20"), "test.md", "Hello World!")],
)
async def test_write_and_read_file(root_path: Path, filename: str, content: bytes):
full_file_name = await File.write(root_path=root_path, filename=filename, content=content.encode('utf-8'))
full_file_name = await File.write(root_path=root_path, filename=filename, content=content.encode("utf-8"))
assert isinstance(full_file_name, Path)
assert root_path / filename == full_file_name
file_data = await File.read(full_file_name)
assert file_data.decode("utf-8") == content

View file

@ -14,17 +14,17 @@ from metagpt.utils.common import OutputParser
def test_parse_blocks():
test_text = "##block1\nThis is block 1.\n##block2\nThis is block 2."
expected_result = {'block1': 'This is block 1.', 'block2': 'This is block 2.'}
expected_result = {"block1": "This is block 1.", "block2": "This is block 2."}
assert OutputParser.parse_blocks(test_text) == expected_result
def test_parse_code():
test_text = "```python\nprint('Hello, world!')```"
expected_result = "print('Hello, world!')"
assert OutputParser.parse_code(test_text, 'python') == expected_result
assert OutputParser.parse_code(test_text, "python") == expected_result
with pytest.raises(Exception):
OutputParser.parse_code(test_text, 'java')
OutputParser.parse_code(test_text, "java")
def test_parse_python_code():
@ -45,13 +45,13 @@ def test_parse_python_code():
def test_parse_str():
test_text = "name = 'Alice'"
expected_result = 'Alice'
expected_result = "Alice"
assert OutputParser.parse_str(test_text) == expected_result
def test_parse_file_list():
test_text = "files=['file1', 'file2', 'file3']"
expected_result = ['file1', 'file2', 'file3']
expected_result = ["file1", "file2", "file3"]
assert OutputParser.parse_file_list(test_text) == expected_result
with pytest.raises(Exception):
@ -60,7 +60,7 @@ def test_parse_file_list():
def test_parse_data():
test_data = "##block1\n```python\nprint('Hello, world!')\n```\n##block2\nfiles=['file1', 'file2', 'file3']"
expected_result = {'block1': "print('Hello, world!')", 'block2': ['file1', 'file2', 'file3']}
expected_result = {"block1": "print('Hello, world!')", "block2": ["file1", "file2", "file3"]}
assert OutputParser.parse_data(test_data) == expected_result
@ -103,9 +103,11 @@ def test_parse_data():
None,
Exception,
),
]
],
)
def test_extract_struct(text: str, data_type: Union[type(list), type(dict)], parsed_data: Union[list, dict], expected_exception):
def test_extract_struct(
text: str, data_type: Union[type(list), type(dict)], parsed_data: Union[list, dict], expected_exception
):
def case():
resp = OutputParser.extract_struct(text, data_type)
assert resp == parsed_data
@ -117,7 +119,7 @@ def test_extract_struct(text: str, data_type: Union[type(list), type(dict)], par
case()
if __name__ == '__main__':
if __name__ == "__main__":
t_text = '''
## Required Python third-party packages
```python
@ -216,7 +218,7 @@ We need clarification on how the high score should be stored. Should it persist
"Requirement Pool": (List[Tuple[str, str]], ...),
"Anything UNCLEAR": (str, ...),
}
t_text1 = '''## Original Requirements:
t_text1 = """## Original Requirements:
The boss wants to create a web-based version of the game "Fly Bird".
@ -284,7 +286,7 @@ The product should be a web-based version of the game "Fly Bird" that is engagin
## Anything UNCLEAR:
There are no unclear points.
'''
"""
d = OutputParser.parse_data_with_mapping(t_text1, OUTPUT_MAPPING)
import json

View file

@ -52,9 +52,11 @@ PAGE = """
</html>
"""
CONTENT = 'This is a HeadingThis is a paragraph witha linkand someemphasizedtext.Item 1Item 2Item 3Numbered Item 1Numbered '\
'Item 2Numbered Item 3Header 1Header 2Row 1, Cell 1Row 1, Cell 2Row 2, Cell 1Row 2, Cell 2Name:Email:SubmitThis is a div '\
'with a class "box".a link'
CONTENT = (
"This is a HeadingThis is a paragraph witha linkand someemphasizedtext.Item 1Item 2Item 3Numbered Item 1Numbered "
"Item 2Numbered Item 3Header 1Header 2Row 1, Cell 1Row 1, Cell 2Row 2, Cell 1Row 2, Cell 2Name:Email:SubmitThis is a div "
'with a class "box".a link'
)
def test_web_page():

View file

@ -1,6 +1,6 @@
from metagpt.utils import pycst
code = '''
code = """
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import overload
@ -24,7 +24,7 @@ class Person:
def greet(self):
return f"Hello, my name is {self.name} and I am {self.age} years old."
'''
"""
documented_code = '''
"""

View file

@ -29,7 +29,7 @@ def _paragraphs(n):
(_msgs(), "gpt-4", "Hello," * 1000, 2000, 2),
(_msgs(), "gpt-4-32k", "System", 4000, 14),
(_msgs(), "gpt-4-32k", "Hello," * 2000, 4000, 12),
]
],
)
def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
assert len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000 == expected
@ -42,7 +42,7 @@ def test_reduce_message_length(msgs, model_name, system_text, reserved, expected
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
(" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
(" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
]
],
)
def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
ret = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved))
@ -58,7 +58,7 @@ def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, r
("......", ".", 2, ["...", "..."]),
("......", ".", 3, ["..", "..", ".."]),
(".......", ".", 2, ["....", "..."]),
]
],
)
def test_split_paragraph(paragraph, sep, count, expected):
ret = split_paragraph(paragraph, sep, count)
@ -71,7 +71,7 @@ def test_split_paragraph(paragraph, sep, count, expected):
("Hello\\nWorld", "Hello\nWorld"),
("Hello\\tWorld", "Hello\tWorld"),
("Hello\\u0020World", "Hello World"),
]
],
)
def test_decode_unicode_escape(text, expected):
assert decode_unicode_escape(text) == expected