mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-20 15:38:09 +02:00
refactor: pre-commit run --all-files
This commit is contained in:
parent
d8adba99d4
commit
cda032948f
129 changed files with 812 additions and 831 deletions
|
|
@ -6,14 +6,14 @@
|
|||
@File : conftest.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
|
||||
class Context:
|
||||
|
|
|
|||
|
|
@ -311,12 +311,10 @@ TASKS = [
|
|||
"添加数据API:接受用户输入的文档库,对文档库进行索引\n- 使用MeiliSearch连接并添加文档库",
|
||||
"搜索API:接收用户输入的关键词,返回相关的搜索结果\n- 使用MeiliSearch连接并使用接口获得对应数据",
|
||||
"多条件筛选API:接收用户选择的筛选条件,返回符合条件的搜索结果。\n- 使用MeiliSearch进行筛选并返回符合条件的搜索结果",
|
||||
"智能推荐API:根据用户的搜索历史记录和搜索行为,推荐相关的搜索结果。"
|
||||
"智能推荐API:根据用户的搜索历史记录和搜索行为,推荐相关的搜索结果。",
|
||||
]
|
||||
|
||||
TASKS_2 = [
|
||||
"完成main.py的功能"
|
||||
]
|
||||
TASKS_2 = ["完成main.py的功能"]
|
||||
|
||||
SEARCH_CODE_SAMPLE = """
|
||||
import requests
|
||||
|
|
@ -460,7 +458,7 @@ if __name__ == '__main__':
|
|||
print('No results found.')
|
||||
'''
|
||||
|
||||
MEILI_CODE = '''import meilisearch
|
||||
MEILI_CODE = """import meilisearch
|
||||
from typing import List
|
||||
|
||||
|
||||
|
|
@ -496,9 +494,9 @@ if __name__ == '__main__':
|
|||
|
||||
# 添加文档库到搜索引擎
|
||||
search_engine.add_documents(books_data_source, documents)
|
||||
'''
|
||||
"""
|
||||
|
||||
MEILI_ERROR = '''/usr/local/bin/python3.9 /Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py
|
||||
MEILI_ERROR = """/usr/local/bin/python3.9 /Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py
|
||||
Traceback (most recent call last):
|
||||
File "/Users/alexanderwu/git/metagpt/examples/search/meilisearch_index.py", line 44, in <module>
|
||||
search_engine.add_documents(books_data_source, documents)
|
||||
|
|
@ -506,7 +504,7 @@ Traceback (most recent call last):
|
|||
index = self.client.get_or_create_index(index_name)
|
||||
AttributeError: 'Client' object has no attribute 'get_or_create_index'
|
||||
|
||||
Process finished with exit code 1'''
|
||||
Process finished with exit code 1"""
|
||||
|
||||
MEILI_CODE_REFINED = """
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -9,18 +9,21 @@ from typing import List, Tuple
|
|||
|
||||
from metagpt.actions import ActionOutput
|
||||
|
||||
t_dict = {"Required Python third-party packages": "\"\"\"\nflask==1.1.2\npygame==2.0.1\n\"\"\"\n",
|
||||
"Required Other language third-party packages": "\"\"\"\nNo third-party packages required for other languages.\n\"\"\"\n",
|
||||
"Full API spec": "\"\"\"\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n '200':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n '200':\n description: A JSON object of the updated game state\n\"\"\"\n",
|
||||
"Logic Analysis": [
|
||||
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
|
||||
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
|
||||
["static/js/script.js", "Handles user interactions and updates the game UI."],
|
||||
["static/css/styles.css", "Defines the styles for the game UI."],
|
||||
["templates/index.html", "The main page of the web application. Displays the game UI."]],
|
||||
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
|
||||
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
|
||||
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?"}
|
||||
t_dict = {
|
||||
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
|
||||
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
|
||||
"Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n',
|
||||
"Logic Analysis": [
|
||||
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
|
||||
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
|
||||
["static/js/script.js", "Handles user interactions and updates the game UI."],
|
||||
["static/css/styles.css", "Defines the styles for the game UI."],
|
||||
["templates/index.html", "The main page of the web application. Displays the game UI."],
|
||||
],
|
||||
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
|
||||
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
|
||||
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
|
||||
}
|
||||
|
||||
WRITE_TASKS_OUTPUT_MAPPING = {
|
||||
"Required Python third-party packages": (str, ...),
|
||||
|
|
@ -45,6 +48,6 @@ def test_create_model_class_with_mapping():
|
|||
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_create_model_class()
|
||||
test_create_model_class_with_mapping()
|
||||
|
|
|
|||
|
|
@ -10,12 +10,7 @@ from metagpt.actions.azure_tts import AzureTTS
|
|||
|
||||
def test_azure_tts():
|
||||
azure_tts = AzureTTS("azure_tts")
|
||||
azure_tts.synthesize_speech(
|
||||
"zh-CN",
|
||||
"zh-CN-YunxiNeural",
|
||||
"Boy",
|
||||
"你好,我是卡卡",
|
||||
"output.wav")
|
||||
azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "你好,我是卡卡", "output.wav")
|
||||
|
||||
# 运行需要先配置 SUBSCRIPTION_KEY
|
||||
# TODO: 这里如果要检验,还要额外加上对应的asr,才能确保前后生成是接近一致的,但现在还没有
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import pytest
|
|||
|
||||
from metagpt.actions.clone_function import CloneFunction, run_function_code
|
||||
|
||||
|
||||
source_code = """
|
||||
import pandas as pd
|
||||
import ta
|
||||
|
|
@ -31,14 +30,18 @@ def get_expected_res():
|
|||
import ta
|
||||
|
||||
# 读取股票数据
|
||||
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
|
||||
stock_data = pd.read_csv("./tests/data/baba_stock.csv")
|
||||
stock_data.head()
|
||||
# 计算简单移动平均线
|
||||
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
|
||||
stock_data[['Date', 'Close', 'SMA']].head()
|
||||
stock_data["SMA"] = ta.trend.sma_indicator(stock_data["Close"], window=6)
|
||||
stock_data[["Date", "Close", "SMA"]].head()
|
||||
# 计算布林带
|
||||
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
|
||||
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
|
||||
stock_data["bb_upper"], stock_data["bb_middle"], stock_data["bb_lower"] = (
|
||||
ta.volatility.bollinger_hband_indicator(stock_data["Close"], window=20),
|
||||
ta.volatility.bollinger_mavg(stock_data["Close"], window=20),
|
||||
ta.volatility.bollinger_lband_indicator(stock_data["Close"], window=20),
|
||||
)
|
||||
stock_data[["Date", "Close", "bb_upper", "bb_middle", "bb_lower"]].head()
|
||||
return stock_data
|
||||
|
||||
|
||||
|
|
@ -46,9 +49,9 @@ def get_expected_res():
|
|||
async def test_clone_function():
|
||||
clone = CloneFunction()
|
||||
code = await clone.run(template_code, source_code)
|
||||
assert 'def ' in code
|
||||
stock_path = './tests/data/baba_stock.csv'
|
||||
df, msg = run_function_code(code, 'stock_indicator', stock_path)
|
||||
assert "def " in code
|
||||
stock_path = "./tests/data/baba_stock.csv"
|
||||
df, msg = run_function_code(code, "stock_indicator", stock_path)
|
||||
assert not msg
|
||||
expected_df = get_expected_res()
|
||||
assert df.equals(expected_df)
|
||||
|
|
|
|||
|
|
@ -144,12 +144,12 @@ Engineer
|
|||
---
|
||||
'''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debug_error():
|
||||
|
||||
debug_error = DebugError("debug_error")
|
||||
|
||||
file_name, rewritten_code = await debug_error.run(context=EXAMPLE_MSG_CONTENT)
|
||||
|
||||
assert "class Player" in rewritten_code # rewrite the same class
|
||||
assert "while self.score > 21" in rewritten_code # a key logic to rewrite to (original one is "if self.score > 12")
|
||||
assert "class Player" in rewritten_code # rewrite the same class
|
||||
assert "while self.score > 21" in rewritten_code # a key logic to rewrite to (original one is "if self.score > 12")
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import pytest
|
|||
from metagpt.actions.detail_mining import DetailMining
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detail_mining():
|
||||
topic = "如何做一个生日蛋糕"
|
||||
|
|
@ -17,7 +18,6 @@ async def test_detail_mining():
|
|||
detail_mining = DetailMining("detail_mining")
|
||||
rsp = await detail_mining.run(topic=topic, record=record)
|
||||
logger.info(f"{rsp.content=}")
|
||||
|
||||
assert '##OUTPUT' in rsp.content
|
||||
assert '蛋糕' in rsp.content
|
||||
|
||||
assert "##OUTPUT" in rsp.content
|
||||
assert "蛋糕" in rsp.content
|
||||
|
|
|
|||
|
|
@ -8,12 +8,11 @@
|
|||
"""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -22,7 +21,7 @@ from metagpt.actions.invoice_ocr import InvoiceOCR, GenerateTable, ReplyQuestion
|
|||
[
|
||||
"../../data/invoices/invoice-3.jpg",
|
||||
"../../data/invoices/invoice-4.zip",
|
||||
]
|
||||
],
|
||||
)
|
||||
async def test_invoice_ocr(invoice_path: str):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
|
|
@ -35,18 +34,8 @@ async def test_invoice_ocr(invoice_path: str):
|
|||
@pytest.mark.parametrize(
|
||||
("invoice_path", "expected_result"),
|
||||
[
|
||||
(
|
||||
"../../data/invoices/invoice-1.pdf",
|
||||
[
|
||||
{
|
||||
"收款人": "小明",
|
||||
"城市": "深圳市",
|
||||
"总费用/元": "412.00",
|
||||
"开票日期": "2023年02月03日"
|
||||
}
|
||||
]
|
||||
),
|
||||
]
|
||||
("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
|
||||
],
|
||||
)
|
||||
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
|
|
@ -59,9 +48,7 @@ async def test_generate_table(invoice_path: str, expected_result: list[dict]):
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("invoice_path", "query", "expected_result"),
|
||||
[
|
||||
("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")
|
||||
]
|
||||
[("../../data/invoices/invoice-1.pdf", "Invoicing date", "2023年02月03日")],
|
||||
)
|
||||
async def test_reply_question(invoice_path: str, query: dict, expected_result: str):
|
||||
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
|
||||
|
|
@ -69,4 +56,3 @@ async def test_reply_question(invoice_path: str, query: dict, expected_result: s
|
|||
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
|
||||
result = await ReplyQuestion().run(query=query, ocr_result=ocr_result)
|
||||
assert expected_result in result
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
#
|
||||
from tests.metagpt.roles.ui_role import UIDesign
|
||||
|
||||
llm_resp= '''
|
||||
llm_resp = """
|
||||
# UI Design Description
|
||||
```The user interface for the snake game will be designed in a way that is simple, clean, and intuitive. The main elements of the game such as the game grid, snake, food, score, and game over message will be clearly defined and easy to understand. The game grid will be centered on the screen with the score displayed at the top. The game controls will be intuitive and easy to use. The design will be modern and minimalist with a pleasing color scheme.```
|
||||
|
||||
|
|
@ -98,12 +98,13 @@ body {
|
|||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
font-size: 3em;
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def test_ui_design_parse_css():
|
||||
ui_design_work = UIDesign(name="UI design action")
|
||||
|
||||
css = '''
|
||||
css = """
|
||||
body {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
|
@ -160,14 +161,14 @@ def test_ui_design_parse_css():
|
|||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
font-size: 3em;
|
||||
'''
|
||||
assert ui_design_work.parse_css_code(context=llm_resp)==css
|
||||
"""
|
||||
assert ui_design_work.parse_css_code(context=llm_resp) == css
|
||||
|
||||
|
||||
def test_ui_design_parse_html():
|
||||
ui_design_work = UIDesign(name="UI design action")
|
||||
|
||||
html = '''
|
||||
html = """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
|
|
@ -184,8 +185,5 @@ def test_ui_design_parse_html():
|
|||
<div class="game-over">Game Over</div>
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
assert ui_design_work.parse_css_code(context=llm_resp)==html
|
||||
|
||||
|
||||
|
||||
"""
|
||||
assert ui_design_work.parse_css_code(context=llm_resp) == html
|
||||
|
|
|
|||
|
|
@ -22,13 +22,13 @@ async def test_write_code():
|
|||
logger.info(code)
|
||||
|
||||
# 我们不能精确地预测生成的代码,但我们可以检查某些关键字
|
||||
assert 'def add' in code
|
||||
assert 'return' in code
|
||||
assert "def add" in code
|
||||
assert "return" in code
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_directly():
|
||||
prompt = WRITE_CODE_PROMPT_SAMPLE + '\n' + TASKS_2[0]
|
||||
prompt = WRITE_CODE_PROMPT_SAMPLE + "\n" + TASKS_2[0]
|
||||
llm = LLM()
|
||||
rsp = await llm.aask(prompt)
|
||||
logger.info(rsp)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import pytest
|
|||
|
||||
from metagpt.actions.write_docstring import WriteDocstring
|
||||
|
||||
code = '''
|
||||
code = """
|
||||
def add_numbers(a: int, b: int):
|
||||
return a + b
|
||||
|
||||
|
|
@ -14,7 +14,7 @@ class Person:
|
|||
|
||||
def greet(self):
|
||||
return f"Hello, my name is {self.name} and I am {self.age} years old."
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -25,7 +25,7 @@ class Person:
|
|||
("numpy", "Parameters"),
|
||||
("sphinx", ":param name:"),
|
||||
],
|
||||
ids=["google", "numpy", "sphinx"]
|
||||
ids=["google", "numpy", "sphinx"],
|
||||
)
|
||||
async def test_write_docstring(style: str, part: str):
|
||||
ret = await WriteDocstring().run(code, style=style)
|
||||
|
|
|
|||
|
|
@ -9,14 +9,11 @@ from typing import Dict
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.write_tutorial import WriteDirectory, WriteContent
|
||||
from metagpt.actions.write_tutorial import WriteContent, WriteDirectory
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("language", "topic"),
|
||||
[("English", "Write a tutorial about Python")]
|
||||
)
|
||||
@pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")])
|
||||
async def test_write_directory(language: str, topic: str):
|
||||
ret = await WriteDirectory(language=language).run(topic=topic)
|
||||
assert isinstance(ret, dict)
|
||||
|
|
@ -30,7 +27,7 @@ async def test_write_directory(language: str, topic: str):
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("language", "topic", "directory"),
|
||||
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})]
|
||||
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})],
|
||||
)
|
||||
async def test_write_content(language: str, topic: str, directory: Dict):
|
||||
ret = await WriteContent(language=language, directory=directory).run(topic=topic)
|
||||
|
|
|
|||
|
|
@ -12,12 +12,12 @@ from metagpt.document_store.chromadb_store import ChromaStore
|
|||
def test_chroma_store():
|
||||
"""FIXME:chroma使用感觉很诡异,一用Python就挂,测试用例里也是"""
|
||||
# 创建 ChromaStore 实例,使用 'sample_collection' 集合
|
||||
document_store = ChromaStore('sample_collection_1')
|
||||
document_store = ChromaStore("sample_collection_1")
|
||||
|
||||
# 使用 write 方法添加多个文档
|
||||
document_store.write(["This is document1", "This is document2"],
|
||||
[{"source": "google-docs"}, {"source": "notion"}],
|
||||
["doc1", "doc2"])
|
||||
document_store.write(
|
||||
["This is document1", "This is document2"], [{"source": "google-docs"}, {"source": "notion"}], ["doc1", "doc2"]
|
||||
)
|
||||
|
||||
# 使用 add 方法添加一个文档
|
||||
document_store.add("This is document3", {"source": "notion"}, "doc3")
|
||||
|
|
|
|||
|
|
@ -39,11 +39,11 @@ user: 没有了
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_faiss_store_search():
|
||||
store = FaissStore(DATA_PATH / 'qcs/qcs_4w.json')
|
||||
store.add(['油皮洗面奶'])
|
||||
store = FaissStore(DATA_PATH / "qcs/qcs_4w.json")
|
||||
store.add(["油皮洗面奶"])
|
||||
role = Sales(store=store)
|
||||
|
||||
queries = ['油皮洗面奶', '介绍下欧莱雅的']
|
||||
queries = ["油皮洗面奶", "介绍下欧莱雅的"]
|
||||
for query in queries:
|
||||
rsp = await role.run(query)
|
||||
assert rsp
|
||||
|
|
@ -60,7 +60,10 @@ def customer_service():
|
|||
async def test_faiss_store_customer_service():
|
||||
allq = [
|
||||
# ["我的餐怎么两小时都没到", "退货吧"],
|
||||
["你好收不到取餐码,麻烦帮我开箱", "14750187158", ]
|
||||
[
|
||||
"你好收不到取餐码,麻烦帮我开箱",
|
||||
"14750187158",
|
||||
]
|
||||
]
|
||||
role = customer_service()
|
||||
for queries in allq:
|
||||
|
|
@ -71,4 +74,4 @@ async def test_faiss_store_customer_service():
|
|||
|
||||
def test_faiss_store_no_file():
|
||||
with pytest.raises(FileNotFoundError):
|
||||
FaissStore(DATA_PATH / 'wtf.json')
|
||||
FaissStore(DATA_PATH / "wtf.json")
|
||||
|
|
|
|||
|
|
@ -5,27 +5,33 @@
|
|||
@Author : unkn-wn (Leon Yee)
|
||||
@File : test_lancedb_store.py
|
||||
"""
|
||||
from metagpt.document_store.lancedb_store import LanceStore
|
||||
import pytest
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.document_store.lancedb_store import LanceStore
|
||||
|
||||
|
||||
@pytest
|
||||
def test_lance_store():
|
||||
|
||||
# This simply establishes the connection to the database, so we can drop the table if it exists
|
||||
store = LanceStore('test')
|
||||
store = LanceStore("test")
|
||||
|
||||
store.drop('test')
|
||||
store.drop("test")
|
||||
|
||||
store.write(data=[[random.random() for _ in range(100)] for _ in range(2)],
|
||||
metadatas=[{"source": "google-docs"}, {"source": "notion"}],
|
||||
ids=["doc1", "doc2"])
|
||||
store.write(
|
||||
data=[[random.random() for _ in range(100)] for _ in range(2)],
|
||||
metadatas=[{"source": "google-docs"}, {"source": "notion"}],
|
||||
ids=["doc1", "doc2"],
|
||||
)
|
||||
|
||||
store.add(data=[random.random() for _ in range(100)], metadata={"source": "notion"}, _id="doc3")
|
||||
|
||||
result = store.search([random.random() for _ in range(100)], n_results=3)
|
||||
assert(len(result) == 3)
|
||||
assert len(result) == 3
|
||||
|
||||
store.delete("doc2")
|
||||
result = store.search([random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric='cosine')
|
||||
assert(len(result) == 1)
|
||||
result = store.search(
|
||||
[random.random() for _ in range(100)], n_results=3, where="source = 'notion'", metric="cosine"
|
||||
)
|
||||
assert len(result) == 1
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import numpy as np
|
|||
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
|
||||
from metagpt.logs import logger
|
||||
|
||||
book_columns = {'idx': int, 'name': str, 'desc': str, 'emb': np.ndarray, 'price': float}
|
||||
book_columns = {"idx": int, "name": str, "desc": str, "emb": np.ndarray, "price": float}
|
||||
book_data = [
|
||||
[i for i in range(10)],
|
||||
[f"book-{i}" for i in range(10)],
|
||||
|
|
@ -25,12 +25,12 @@ book_data = [
|
|||
def test_milvus_store():
|
||||
milvus_connection = MilvusConnection(alias="default", host="192.168.50.161", port="30530")
|
||||
milvus_store = MilvusStore(milvus_connection)
|
||||
milvus_store.drop('Book')
|
||||
milvus_store.create_collection('Book', book_columns)
|
||||
milvus_store.drop("Book")
|
||||
milvus_store.create_collection("Book", book_columns)
|
||||
milvus_store.add(book_data)
|
||||
milvus_store.build_index('emb')
|
||||
milvus_store.build_index("emb")
|
||||
milvus_store.load_collection()
|
||||
|
||||
results = milvus_store.search([[1.0, 1.0]], field='emb')
|
||||
results = milvus_store.search([[1.0, 1.0]], field="emb")
|
||||
logger.info(results)
|
||||
assert results
|
||||
|
|
|
|||
|
|
@ -24,9 +24,7 @@ random.seed(seed_value)
|
|||
vectors = [[random.random() for _ in range(2)] for _ in range(10)]
|
||||
|
||||
points = [
|
||||
PointStruct(
|
||||
id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10}
|
||||
)
|
||||
PointStruct(id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10})
|
||||
for idx, vector in enumerate(vectors)
|
||||
]
|
||||
|
||||
|
|
@ -57,9 +55,7 @@ def test_milvus_store():
|
|||
results = qdrant_store.search(
|
||||
"Book",
|
||||
query=[1.0, 1.0],
|
||||
query_filter=Filter(
|
||||
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
|
||||
),
|
||||
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
|
||||
)
|
||||
assert results[0]["id"] == 8
|
||||
assert results[0]["score"] == 0.9100373450784073
|
||||
|
|
@ -68,9 +64,7 @@ def test_milvus_store():
|
|||
results = qdrant_store.search(
|
||||
"Book",
|
||||
query=[1.0, 1.0],
|
||||
query_filter=Filter(
|
||||
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
|
||||
),
|
||||
query_filter=Filter(must=[FieldCondition(key="rand_number", range=Range(gte=8))]),
|
||||
return_vector=True,
|
||||
)
|
||||
assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915]
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ def test_skill_manager():
|
|||
|
||||
rsp = manager.retrieve_skill("写测试用例")
|
||||
logger.info(rsp)
|
||||
assert rsp[0] == 'WriteTest'
|
||||
assert rsp[0] == "WriteTest"
|
||||
|
||||
rsp = manager.retrieve_skill_scored("写PRD")
|
||||
logger.info(rsp)
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of `metagpt/memory/longterm_memory.py`
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.schema import Message
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.memory import LongTermMemory
|
||||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_ltm_search():
|
||||
|
|
@ -14,25 +14,25 @@ def test_ltm_search():
|
|||
openai_api_key = CONFIG.openai_api_key
|
||||
assert len(openai_api_key) > 20
|
||||
|
||||
role_id = 'UTUserLtm(Product Manager)'
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
rc = RoleContext(watch=[BossRequirement])
|
||||
ltm = LongTermMemory()
|
||||
ltm.recover_memory(role_id, rc)
|
||||
|
||||
idea = 'Write a cli snake game'
|
||||
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
|
||||
idea = "Write a cli snake game"
|
||||
message = Message(role="BOSS", content=idea, cause_by=BossRequirement)
|
||||
news = ltm.find_news([message])
|
||||
assert len(news) == 1
|
||||
ltm.add(message)
|
||||
|
||||
sim_idea = 'Write a game of cli snake'
|
||||
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_message = Message(role="BOSS", content=sim_idea, cause_by=BossRequirement)
|
||||
news = ltm.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
ltm.add(sim_message)
|
||||
|
||||
new_idea = 'Write a 2048 web game'
|
||||
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
|
||||
news = ltm.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
ltm.add(new_message)
|
||||
|
|
@ -47,8 +47,8 @@ def test_ltm_search():
|
|||
news = ltm_new.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
|
||||
new_idea = 'Write a Battle City'
|
||||
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
|
||||
new_idea = "Write a Battle City"
|
||||
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
|
||||
news = ltm_new.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
|
||||
|
|
|
|||
|
|
@ -4,17 +4,16 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from metagpt.actions import BossRequirement, WritePRD
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
from metagpt.actions import BossRequirement
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
|
||||
|
||||
def test_idea_message():
|
||||
idea = 'Write a cli snake game'
|
||||
role_id = 'UTUser1(Product Manager)'
|
||||
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
|
||||
idea = "Write a cli snake game"
|
||||
role_id = "UTUser1(Product Manager)"
|
||||
message = Message(role="BOSS", content=idea, cause_by=BossRequirement)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
|
|
@ -23,13 +22,13 @@ def test_idea_message():
|
|||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_idea = 'Write a game of cli snake'
|
||||
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
|
||||
sim_idea = "Write a game of cli snake"
|
||||
sim_message = Message(role="BOSS", content=sim_idea, cause_by=BossRequirement)
|
||||
new_messages = memory_storage.search(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_idea = 'Write a 2048 web game'
|
||||
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
|
||||
new_idea = "Write a 2048 web game"
|
||||
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
|
||||
new_messages = memory_storage.search(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
|
|
@ -38,22 +37,15 @@ def test_idea_message():
|
|||
|
||||
|
||||
def test_actionout_message():
|
||||
out_mapping = {
|
||||
'field1': (str, ...),
|
||||
'field2': (List[str], ...)
|
||||
}
|
||||
out_data = {
|
||||
'field1': 'field1 value',
|
||||
'field2': ['field2 value1', 'field2 value2']
|
||||
}
|
||||
ic_obj = ActionOutput.create_model_class('prd', out_mapping)
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionOutput.create_model_class("prd", out_mapping)
|
||||
|
||||
role_id = 'UTUser2(Architect)'
|
||||
content = 'The boss has requested the creation of a command-line interface (CLI) snake game'
|
||||
message = Message(content=content,
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD) # WritePRD as test action
|
||||
role_id = "UTUser2(Architect)"
|
||||
content = "The boss has requested the creation of a command-line interface (CLI) snake game"
|
||||
message = Message(
|
||||
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
) # WritePRD as test action
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
|
|
@ -62,19 +54,13 @@ def test_actionout_message():
|
|||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_conent = 'The request is command-line interface (CLI) snake game'
|
||||
sim_message = Message(content=sim_conent,
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD)
|
||||
sim_conent = "The request is command-line interface (CLI) snake game"
|
||||
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
|
||||
new_conent = 'Incorporate basic features of a snake game such as scoring and increasing difficulty'
|
||||
new_message = Message(content=new_conent,
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD)
|
||||
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ from metagpt.schema import Message
|
|||
|
||||
|
||||
def test_message():
|
||||
message = Message(role='user', content='wtf')
|
||||
assert 'role' in message.to_dict()
|
||||
assert 'user' in str(message)
|
||||
message = Message(role="user", content="wtf")
|
||||
assert "role" in message.to_dict()
|
||||
assert "user" in str(message)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,6 @@ def test_message():
|
|||
llm = SparkAPI()
|
||||
|
||||
logger.info(llm.ask('只回答"收到了"这三个字。'))
|
||||
result = llm.ask('写一篇五百字的日记')
|
||||
result = llm.ask("写一篇五百字的日记")
|
||||
logger.info(result)
|
||||
assert len(result) > 100
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ PRD = '''## 原始需求
|
|||
```
|
||||
'''
|
||||
|
||||
SYSTEM_DESIGN = '''## Python package name
|
||||
SYSTEM_DESIGN = """## Python package name
|
||||
```python
|
||||
"smart_search_engine"
|
||||
```
|
||||
|
|
@ -149,10 +149,10 @@ sequenceDiagram
|
|||
S-->>SE: return summary
|
||||
SE-->>M: return summary
|
||||
```
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
TASKS = '''## Logic Analysis
|
||||
TASKS = """## Logic Analysis
|
||||
|
||||
在这个项目中,所有的模块都依赖于“SearchEngine”类,这是主入口,其他的模块(Index、Ranking和Summary)都通过它交互。另外,"Index"类又依赖于"KnowledgeBase"类,因为它需要从知识库中获取数据。
|
||||
|
||||
|
|
@ -181,7 +181,7 @@ task_list = [
|
|||
]
|
||||
```
|
||||
这个任务列表首先定义了最基础的模块,然后是依赖这些模块的模块,最后是辅助模块。可以根据团队的能力和资源,同时开发多个任务,只要满足依赖关系。例如,在开发"search.py"之前,可以同时开发"knowledge_base.py"、"index.py"、"ranking.py"和"summary.py"。
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
TASKS_TOMATO_CLOCK = '''## Required Python third-party packages: Provided in requirements.txt format
|
||||
|
|
@ -224,30 +224,30 @@ task_list = [
|
|||
TASK = """smart_search_engine/knowledge_base.py"""
|
||||
|
||||
STRS_FOR_PARSING = [
|
||||
"""
|
||||
"""
|
||||
## 1
|
||||
```python
|
||||
a
|
||||
```
|
||||
""",
|
||||
"""
|
||||
"""
|
||||
##2
|
||||
```python
|
||||
"a"
|
||||
```
|
||||
""",
|
||||
"""
|
||||
"""
|
||||
## 3
|
||||
```python
|
||||
a = "a"
|
||||
```
|
||||
""",
|
||||
"""
|
||||
"""
|
||||
## 4
|
||||
```python
|
||||
a = 'a'
|
||||
```
|
||||
"""
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -35,13 +35,13 @@ def test_parse_str():
|
|||
for idx, i in enumerate(STRS_FOR_PARSING):
|
||||
text = CodeParser.parse_str(f"{idx+1}", i)
|
||||
# logger.info(text)
|
||||
assert text == 'a'
|
||||
assert text == "a"
|
||||
|
||||
|
||||
def test_parse_blocks():
|
||||
tasks = CodeParser.parse_blocks(TASKS)
|
||||
logger.info(tasks.keys())
|
||||
assert 'Task list' in tasks.keys()
|
||||
assert "Task list" in tasks.keys()
|
||||
|
||||
|
||||
target_list = [
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from metagpt.roles.invoice_ocr_assistant import InvoiceOCRAssistant
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -24,82 +24,39 @@ from metagpt.schema import Message
|
|||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-1.pdf"),
|
||||
Path("../../../data/invoice_table/invoice-1.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "小明",
|
||||
"城市": "深圳市",
|
||||
"总费用/元": 412.00,
|
||||
"开票日期": "2023年02月03日"
|
||||
}
|
||||
]
|
||||
[{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}],
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-2.png"),
|
||||
Path("../../../data/invoice_table/invoice-2.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "铁头",
|
||||
"城市": "广州市",
|
||||
"总费用/元": 898.00,
|
||||
"开票日期": "2023年03月17日"
|
||||
}
|
||||
]
|
||||
[{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}],
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-3.jpg"),
|
||||
Path("../../../data/invoice_table/invoice-3.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "夏天",
|
||||
"城市": "福州市",
|
||||
"总费用/元": 2462.00,
|
||||
"开票日期": "2023年08月26日"
|
||||
}
|
||||
]
|
||||
[{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}],
|
||||
),
|
||||
(
|
||||
"Invoicing date",
|
||||
Path("../../data/invoices/invoice-4.zip"),
|
||||
Path("../../../data/invoice_table/invoice-4.xlsx"),
|
||||
[
|
||||
{
|
||||
"收款人": "小明",
|
||||
"城市": "深圳市",
|
||||
"总费用/元": 412.00,
|
||||
"开票日期": "2023年02月03日"
|
||||
},
|
||||
{
|
||||
"收款人": "铁头",
|
||||
"城市": "广州市",
|
||||
"总费用/元": 898.00,
|
||||
"开票日期": "2023年03月17日"
|
||||
},
|
||||
{
|
||||
"收款人": "夏天",
|
||||
"城市": "福州市",
|
||||
"总费用/元": 2462.00,
|
||||
"开票日期": "2023年08月26日"
|
||||
}
|
||||
]
|
||||
{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
|
||||
{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
|
||||
{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
|
||||
],
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
async def test_invoice_ocr_assistant(
|
||||
query: str,
|
||||
invoice_path: Path,
|
||||
invoice_table_path: Path,
|
||||
expected_result: list[dict]
|
||||
query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict]
|
||||
):
|
||||
invoice_path = Path.cwd() / invoice_path
|
||||
role = InvoiceOCRAssistant()
|
||||
await role.run(Message(
|
||||
content=query,
|
||||
instruct_content={"file_path": invoice_path}
|
||||
))
|
||||
await role.run(Message(content=query, instruct_content={"file_path": invoice_path}))
|
||||
invoice_table_path = Path.cwd() / invoice_table_path
|
||||
df = pd.read_excel(invoice_table_path)
|
||||
dict_result = df.to_dict(orient='records')
|
||||
dict_result = df.to_dict(orient="records")
|
||||
assert dict_result == expected_result
|
||||
|
||||
|
|
|
|||
|
|
@ -11,10 +11,12 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
|
|||
if "Please provide up to 2 necessary keywords" in prompt:
|
||||
return '["dataiku", "datarobot"]'
|
||||
elif "Provide up to 4 queries related to your research topic" in prompt:
|
||||
return '["Dataiku machine learning platform", "DataRobot AI platform comparison", ' \
|
||||
return (
|
||||
'["Dataiku machine learning platform", "DataRobot AI platform comparison", '
|
||||
'"Dataiku vs DataRobot features", "Dataiku and DataRobot use cases"]'
|
||||
)
|
||||
elif "sort the remaining search results" in prompt:
|
||||
return '[1,2]'
|
||||
return "[1,2]"
|
||||
elif "Not relevant." in prompt:
|
||||
return "Not relevant" if random() > 0.5 else prompt[-100:]
|
||||
elif "provide a detailed research report" in prompt:
|
||||
|
|
|
|||
|
|
@ -12,10 +12,7 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("language", "topic"),
|
||||
[("Chinese", "Write a tutorial about Python")]
|
||||
)
|
||||
@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about Python")])
|
||||
async def test_tutorial_assistant(language: str, topic: str):
|
||||
topic = "Write a tutorial about MySQL"
|
||||
role = TutorialAssistant(language=language)
|
||||
|
|
@ -24,4 +21,4 @@ async def test_tutorial_assistant(language: str, topic: str):
|
|||
title = filename.split("/")[-1].split(".")[0]
|
||||
async with aiofiles.open(filename, mode="r") as reader:
|
||||
content = await reader.read()
|
||||
assert content.startswith(f"# {title}")
|
||||
assert content.startswith(f"# {title}")
|
||||
|
|
|
|||
|
|
@ -2,9 +2,8 @@
|
|||
# @Date : 2023/7/22 02:40
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
#
|
||||
from metagpt.software_company import SoftwareCompany
|
||||
from metagpt.roles import ProductManager
|
||||
|
||||
from metagpt.software_company import SoftwareCompany
|
||||
from tests.metagpt.roles.ui_role import UI
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from metagpt.logs import logger
|
|||
@pytest.mark.usefixtures("llm_api")
|
||||
class TestGPT:
|
||||
def test_llm_api_ask(self, llm_api):
|
||||
answer = llm_api.ask('hello chatgpt')
|
||||
answer = llm_api.ask("hello chatgpt")
|
||||
assert len(answer) > 0
|
||||
|
||||
# def test_gptapi_ask_batch(self, llm_api):
|
||||
|
|
@ -22,22 +22,22 @@ class TestGPT:
|
|||
# assert len(answer) > 0
|
||||
|
||||
def test_llm_api_ask_code(self, llm_api):
|
||||
answer = llm_api.ask_code(['请扮演一个Google Python专家工程师,如果理解,回复明白', '写一个hello world'])
|
||||
answer = llm_api.ask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"])
|
||||
assert len(answer) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_api_aask(self, llm_api):
|
||||
answer = await llm_api.aask('hello chatgpt')
|
||||
answer = await llm_api.aask("hello chatgpt")
|
||||
assert len(answer) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_api_aask_code(self, llm_api):
|
||||
answer = await llm_api.aask_code(['请扮演一个Google Python专家工程师,如果理解,回复明白', '写一个hello world'])
|
||||
answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"])
|
||||
assert len(answer) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_api_costs(self, llm_api):
|
||||
await llm_api.aask('hello chatgpt')
|
||||
await llm_api.aask("hello chatgpt")
|
||||
costs = llm_api.get_costs()
|
||||
logger.info(costs)
|
||||
assert costs.total_cost > 0
|
||||
|
|
|
|||
|
|
@ -18,17 +18,17 @@ def llm():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_aask(llm):
|
||||
assert len(await llm.aask('hello world')) > 0
|
||||
assert len(await llm.aask("hello world")) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_aask_batch(llm):
|
||||
assert len(await llm.aask_batch(['hi', 'write python hello world.'])) > 0
|
||||
assert len(await llm.aask_batch(["hi", "write python hello world."])) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_acompletion(llm):
|
||||
hello_msg = [{'role': 'user', 'content': 'hello'}]
|
||||
hello_msg = [{"role": "user", "content": "hello"}]
|
||||
assert len(await llm.acompletion(hello_msg)) > 0
|
||||
assert len(await llm.acompletion_batch([hello_msg])) > 0
|
||||
assert len(await llm.acompletion_batch_text([hello_msg])) > 0
|
||||
|
|
|
|||
|
|
@ -11,26 +11,26 @@ from metagpt.schema import AIMessage, Message, RawMessage, SystemMessage, UserMe
|
|||
|
||||
|
||||
def test_message():
|
||||
msg = Message(role='User', content='WTF')
|
||||
assert msg.to_dict()['role'] == 'User'
|
||||
assert 'User' in str(msg)
|
||||
msg = Message(role="User", content="WTF")
|
||||
assert msg.to_dict()["role"] == "User"
|
||||
assert "User" in str(msg)
|
||||
|
||||
|
||||
def test_all_messages():
|
||||
test_content = 'test_message'
|
||||
test_content = "test_message"
|
||||
msgs = [
|
||||
UserMessage(test_content),
|
||||
SystemMessage(test_content),
|
||||
AIMessage(test_content),
|
||||
Message(test_content, role='QA')
|
||||
Message(test_content, role="QA"),
|
||||
]
|
||||
for msg in msgs:
|
||||
assert msg.content == test_content
|
||||
|
||||
|
||||
def test_raw_message():
|
||||
msg = RawMessage(role='user', content='raw')
|
||||
assert msg['role'] == 'user'
|
||||
assert msg['content'] == 'raw'
|
||||
msg = RawMessage(role="user", content="raw")
|
||||
assert msg["role"] == "user"
|
||||
assert msg["content"] == "raw"
|
||||
with pytest.raises(KeyError):
|
||||
assert msg['1'] == 1, "KeyError: '1'"
|
||||
assert msg["1"] == 1, "KeyError: '1'"
|
||||
|
|
|
|||
|
|
@ -9,6 +9,6 @@ from metagpt.roles import Role
|
|||
|
||||
|
||||
def test_role_desc():
|
||||
i = Role(profile='Sales', desc='Best Seller')
|
||||
assert i.profile == 'Sales'
|
||||
assert i._setting.desc == 'Best Seller'
|
||||
i = Role(profile="Sales", desc="Best Seller")
|
||||
assert i.profile == "Sales"
|
||||
assert i._setting.desc == "Best Seller"
|
||||
|
|
|
|||
|
|
@ -9,13 +9,13 @@ from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage
|
|||
|
||||
|
||||
def test_messages():
|
||||
test_content = 'test_message'
|
||||
test_content = "test_message"
|
||||
msgs = [
|
||||
UserMessage(test_content),
|
||||
SystemMessage(test_content),
|
||||
AIMessage(test_content),
|
||||
Message(test_content, role='QA')
|
||||
Message(test_content, role="QA"),
|
||||
]
|
||||
text = str(msgs)
|
||||
roles = ['user', 'system', 'assistant', 'QA']
|
||||
roles = ["user", "system", "assistant", "QA"]
|
||||
assert all([i in text for i in roles])
|
||||
|
|
|
|||
|
|
@ -1,23 +1,22 @@
|
|||
import pytest
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
from tests.data import sales_desc, store_desc
|
||||
from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.code_interpreter import OpenCodeInterpreter, OpenInterpreterDecorator
|
||||
|
||||
|
||||
logger.add('./tests/data/test_ci.log')
|
||||
logger.add("./tests/data/test_ci.log")
|
||||
stock = "./tests/data/baba_stock.csv"
|
||||
|
||||
|
||||
# TODO: 需要一种表格数据格式,能够支持schame管理的,标注字段类型和字段含义。
|
||||
class CreateStockIndicators(Action):
|
||||
@OpenInterpreterDecorator(save_code=True, code_file_path="./tests/data/stock_indicators.py")
|
||||
async def run(self, stock_path: str, indicators=['Simple Moving Average', 'BollingerBands']) -> pd.DataFrame:
|
||||
async def run(self, stock_path: str, indicators=["Simple Moving Average", "BollingerBands"]) -> pd.DataFrame:
|
||||
"""对stock_path中的股票数据, 使用pandas和ta计算indicators中的技术指标, 返回带有技术指标的股票数据,不需要去除空值, 不需要安装任何包;
|
||||
指标生成对应的三列: SMA, BB_upper, BB_lower
|
||||
指标生成对应的三列: SMA, BB_upper, BB_lower
|
||||
"""
|
||||
...
|
||||
|
||||
|
|
@ -25,18 +24,20 @@ class CreateStockIndicators(Action):
|
|||
@pytest.mark.asyncio
|
||||
async def test_actions():
|
||||
# 计算指标
|
||||
indicators = ['Simple Moving Average', 'BollingerBands']
|
||||
indicators = ["Simple Moving Average", "BollingerBands"]
|
||||
stocker = CreateStockIndicators()
|
||||
df, msg = await stocker.run(stock, indicators=indicators)
|
||||
assert isinstance(df, pd.DataFrame)
|
||||
assert 'Close' in df.columns
|
||||
assert 'Date' in df.columns
|
||||
assert "Close" in df.columns
|
||||
assert "Date" in df.columns
|
||||
# 将df保存为文件,将文件路径传入到下一个action
|
||||
df_path = './tests/data/stock_indicators.csv'
|
||||
df_path = "./tests/data/stock_indicators.csv"
|
||||
df.to_csv(df_path)
|
||||
assert Path(df_path).is_file()
|
||||
# 可视化指标结果
|
||||
figure_path = './tests/data/figure_ci.png'
|
||||
figure_path = "./tests/data/figure_ci.png"
|
||||
ci_ploter = OpenCodeInterpreter()
|
||||
ci_ploter.chat(f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。")
|
||||
ci_ploter.chat(
|
||||
f"使用seaborn对{df_path}中与股票布林带有关的数据列的Date, Close, SMA, BB_upper(布林带上界), BB_lower(布林带下界)进行可视化, 可视化图片保存在{figure_path}中。不需要任何指标计算,把Date列转换为日期类型。要求图片优美,BB_upper, BB_lower之间使用合适的颜色填充。"
|
||||
)
|
||||
assert Path(figure_path).is_file()
|
||||
|
|
|
|||
|
|
@ -20,8 +20,9 @@ from metagpt.tools.prompt_writer import (
|
|||
@pytest.mark.usefixtures("llm_api")
|
||||
def test_gpt_prompt_generator(llm_api):
|
||||
generator = GPTPromptGenerator()
|
||||
example = "商品名称:WonderLab 新肌果味代餐奶昔 小胖瓶 胶原蛋白升级版 饱腹代餐粉6瓶 75g/瓶(6瓶/盒) 店铺名称:金力宁食品专营店 " \
|
||||
"品牌:WonderLab 保质期:1年 产地:中国 净含量:450g"
|
||||
example = (
|
||||
"商品名称:WonderLab 新肌果味代餐奶昔 小胖瓶 胶原蛋白升级版 饱腹代餐粉6瓶 75g/瓶(6瓶/盒) 店铺名称:金力宁食品专营店 " "品牌:WonderLab 保质期:1年 产地:中国 净含量:450g"
|
||||
)
|
||||
|
||||
results = llm_api.ask_batch(generator.gen(example))
|
||||
logger.info(results)
|
||||
|
|
@ -46,7 +47,7 @@ def test_enron_template(llm_api):
|
|||
|
||||
results = template.gen(subj)
|
||||
assert len(results) > 0
|
||||
assert any("Write an email with the subject \"Meeting Agenda\"." in r for r in results)
|
||||
assert any('Write an email with the subject "Meeting Agenda".' in r for r in results)
|
||||
|
||||
|
||||
def test_beagec_template():
|
||||
|
|
@ -54,5 +55,6 @@ def test_beagec_template():
|
|||
|
||||
results = template.gen()
|
||||
assert len(results) > 0
|
||||
assert any("Edit and revise this document to improve its grammar, vocabulary, spelling, and style."
|
||||
in r for r in results)
|
||||
assert any(
|
||||
"Edit and revise this document to improve its grammar, vocabulary, spelling, and style." in r for r in results
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
#
|
||||
import os
|
||||
|
||||
from metagpt.tools.sd_engine import SDEngine, WORKSPACE_ROOT
|
||||
from metagpt.tools.sd_engine import WORKSPACE_ROOT, SDEngine
|
||||
|
||||
|
||||
def test_sd_engine_init():
|
||||
|
|
|
|||
|
|
@ -16,7 +16,9 @@ from metagpt.tools.search_engine import SearchEngine
|
|||
|
||||
class MockSearchEnine:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
|
||||
rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)]
|
||||
rets = [
|
||||
{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)
|
||||
]
|
||||
return "\n".join(rets) if as_string else rets
|
||||
|
||||
|
||||
|
|
@ -34,10 +36,14 @@ class MockSearchEnine:
|
|||
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
|
||||
|
||||
],
|
||||
)
|
||||
async def test_search_engine(search_engine_typpe, run_func, max_results, as_string, ):
|
||||
async def test_search_engine(
|
||||
search_engine_typpe,
|
||||
run_func,
|
||||
max_results,
|
||||
as_string,
|
||||
):
|
||||
search_engine = SearchEngine(search_engine_typpe, run_func)
|
||||
rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
|
||||
logger.info(rsp)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import pytest
|
|||
from metagpt.logs import logger
|
||||
from metagpt.tools.search_engine_meilisearch import DataSource, MeilisearchEngine
|
||||
|
||||
MASTER_KEY = '116Qavl2qpCYNEJNv5-e0RC9kncev1nr1gt7ybEGVLk'
|
||||
MASTER_KEY = "116Qavl2qpCYNEJNv5-e0RC9kncev1nr1gt7ybEGVLk"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
@ -29,7 +29,7 @@ def test_meilisearch(search_engine_server):
|
|||
search_engine = MeilisearchEngine(url="http://localhost:7700", token=MASTER_KEY)
|
||||
|
||||
# 假设有一个名为"books"的数据源,包含要添加的文档库
|
||||
books_data_source = DataSource(name='books', url='https://example.com/books')
|
||||
books_data_source = DataSource(name="books", url="https://example.com/books")
|
||||
|
||||
# 假设有一个名为"documents"的文档库,包含要添加的文档
|
||||
documents = [
|
||||
|
|
@ -43,4 +43,4 @@ def test_meilisearch(search_engine_server):
|
|||
|
||||
# 添加文档库到搜索引擎
|
||||
search_engine.add_documents(books_data_source, documents)
|
||||
logger.info(search_engine.search('Book 1'))
|
||||
logger.info(search_engine.search("Book 1"))
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -16,7 +16,7 @@ from metagpt.tools.translator import Translator
|
|||
def test_translate(llm_api):
|
||||
poetries = [
|
||||
("Let life be beautiful like summer flowers", "花"),
|
||||
("The ancient Chinese poetries are all songs.", "中国")
|
||||
("The ancient Chinese poetries are all songs.", "中国"),
|
||||
]
|
||||
for i, j in poetries:
|
||||
prompt = Translator.translate_prompt(i)
|
||||
|
|
|
|||
|
|
@ -16,8 +16,12 @@ class TestUTWriter:
|
|||
tags = ["测试"] # "智能合同导入", "律师审查", "ai合同审查", "草拟合同&律师在线审查", "合同审批", "履约管理", "签约公司"]
|
||||
# 这里在文件中手动加入了两个测试标签的API
|
||||
|
||||
utg = UTGenerator(swagger_file=swagger_file, ut_py_path=UT_PY_PATH, questions_path=API_QUESTIONS_PATH,
|
||||
template_prefix=YFT_PROMPT_PREFIX)
|
||||
utg = UTGenerator(
|
||||
swagger_file=swagger_file,
|
||||
ut_py_path=UT_PY_PATH,
|
||||
questions_path=API_QUESTIONS_PATH,
|
||||
template_prefix=YFT_PROMPT_PREFIX,
|
||||
)
|
||||
ret = utg.generate_ut(include_tags=tags)
|
||||
# 后续加入对文件生成内容与数量的检验
|
||||
assert ret
|
||||
|
|
|
|||
|
|
@ -131,10 +131,10 @@ class TestCodeParser:
|
|||
def test_parse_file_list(self, parser, text):
|
||||
result = parser.parse_file_list("Task list", text)
|
||||
print(result)
|
||||
assert result == ['task1', 'task2']
|
||||
assert result == ["task1", "task2"]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
t = TestCodeParser()
|
||||
t.test_parse_file_list(CodeParser(), t_text)
|
||||
# TestCodeParser.test_parse_file_list()
|
||||
|
|
|
|||
|
|
@ -16,12 +16,12 @@ from metagpt.const import get_project_root
|
|||
class TestGetProjectRoot:
|
||||
def change_etc_dir(self):
|
||||
# current_directory = Path.cwd()
|
||||
abs_root = '/etc'
|
||||
abs_root = "/etc"
|
||||
os.chdir(abs_root)
|
||||
|
||||
def test_get_project_root(self):
|
||||
project_root = get_project_root()
|
||||
assert project_root.name == 'metagpt'
|
||||
assert project_root.name == "metagpt"
|
||||
|
||||
def test_get_root_exception(self):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
|
|
|
|||
|
|
@ -20,12 +20,12 @@ def test_config_class_is_singleton():
|
|||
def test_config_class_get_key_exception():
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
config = Config()
|
||||
config.get('wtf')
|
||||
config.get("wtf")
|
||||
assert str(exc_info.value) == "Key 'wtf' not found in environment variables or in the YAML file"
|
||||
|
||||
|
||||
def test_config_yaml_file_not_exists():
|
||||
config = Config('wtf.yaml')
|
||||
config = Config("wtf.yaml")
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
config.get('OPENAI_BASE_URL')
|
||||
config.get("OPENAI_BASE_URL")
|
||||
assert str(exc_info.value) == "Key 'OPENAI_BASE_URL' not found in environment variables or in the YAML file"
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@ from metagpt.provider.openai_api import OpenAIGPTAPI
|
|||
|
||||
|
||||
async def try_hello(api):
|
||||
batch = [[{'role': 'user', 'content': 'hello'}]]
|
||||
batch = [[{"role": "user", "content": "hello"}]]
|
||||
results = await api.acompletion_batch_text(batch)
|
||||
return results
|
||||
|
||||
|
||||
async def aask_batch(api: OpenAIGPTAPI):
|
||||
results = await api.aask_batch(['hi', 'write python hello world.'])
|
||||
results = await api.aask_batch(["hi", "write python hello world."])
|
||||
logger.info(results)
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -15,12 +15,11 @@ from metagpt.utils.file import File
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("root_path", "filename", "content"),
|
||||
[(Path("/code/MetaGPT/data/tutorial_docx/2023-09-07_17-05-20"), "test.md", "Hello World!")]
|
||||
[(Path("/code/MetaGPT/data/tutorial_docx/2023-09-07_17-05-20"), "test.md", "Hello World!")],
|
||||
)
|
||||
async def test_write_and_read_file(root_path: Path, filename: str, content: bytes):
|
||||
full_file_name = await File.write(root_path=root_path, filename=filename, content=content.encode('utf-8'))
|
||||
full_file_name = await File.write(root_path=root_path, filename=filename, content=content.encode("utf-8"))
|
||||
assert isinstance(full_file_name, Path)
|
||||
assert root_path / filename == full_file_name
|
||||
file_data = await File.read(full_file_name)
|
||||
assert file_data.decode("utf-8") == content
|
||||
|
||||
|
|
|
|||
|
|
@ -14,17 +14,17 @@ from metagpt.utils.common import OutputParser
|
|||
|
||||
def test_parse_blocks():
|
||||
test_text = "##block1\nThis is block 1.\n##block2\nThis is block 2."
|
||||
expected_result = {'block1': 'This is block 1.', 'block2': 'This is block 2.'}
|
||||
expected_result = {"block1": "This is block 1.", "block2": "This is block 2."}
|
||||
assert OutputParser.parse_blocks(test_text) == expected_result
|
||||
|
||||
|
||||
def test_parse_code():
|
||||
test_text = "```python\nprint('Hello, world!')```"
|
||||
expected_result = "print('Hello, world!')"
|
||||
assert OutputParser.parse_code(test_text, 'python') == expected_result
|
||||
assert OutputParser.parse_code(test_text, "python") == expected_result
|
||||
|
||||
with pytest.raises(Exception):
|
||||
OutputParser.parse_code(test_text, 'java')
|
||||
OutputParser.parse_code(test_text, "java")
|
||||
|
||||
|
||||
def test_parse_python_code():
|
||||
|
|
@ -45,13 +45,13 @@ def test_parse_python_code():
|
|||
|
||||
def test_parse_str():
|
||||
test_text = "name = 'Alice'"
|
||||
expected_result = 'Alice'
|
||||
expected_result = "Alice"
|
||||
assert OutputParser.parse_str(test_text) == expected_result
|
||||
|
||||
|
||||
def test_parse_file_list():
|
||||
test_text = "files=['file1', 'file2', 'file3']"
|
||||
expected_result = ['file1', 'file2', 'file3']
|
||||
expected_result = ["file1", "file2", "file3"]
|
||||
assert OutputParser.parse_file_list(test_text) == expected_result
|
||||
|
||||
with pytest.raises(Exception):
|
||||
|
|
@ -60,7 +60,7 @@ def test_parse_file_list():
|
|||
|
||||
def test_parse_data():
|
||||
test_data = "##block1\n```python\nprint('Hello, world!')\n```\n##block2\nfiles=['file1', 'file2', 'file3']"
|
||||
expected_result = {'block1': "print('Hello, world!')", 'block2': ['file1', 'file2', 'file3']}
|
||||
expected_result = {"block1": "print('Hello, world!')", "block2": ["file1", "file2", "file3"]}
|
||||
assert OutputParser.parse_data(test_data) == expected_result
|
||||
|
||||
|
||||
|
|
@ -103,9 +103,11 @@ def test_parse_data():
|
|||
None,
|
||||
Exception,
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_extract_struct(text: str, data_type: Union[type(list), type(dict)], parsed_data: Union[list, dict], expected_exception):
|
||||
def test_extract_struct(
|
||||
text: str, data_type: Union[type(list), type(dict)], parsed_data: Union[list, dict], expected_exception
|
||||
):
|
||||
def case():
|
||||
resp = OutputParser.extract_struct(text, data_type)
|
||||
assert resp == parsed_data
|
||||
|
|
@ -117,7 +119,7 @@ def test_extract_struct(text: str, data_type: Union[type(list), type(dict)], par
|
|||
case()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
t_text = '''
|
||||
## Required Python third-party packages
|
||||
```python
|
||||
|
|
@ -216,7 +218,7 @@ We need clarification on how the high score should be stored. Should it persist
|
|||
"Requirement Pool": (List[Tuple[str, str]], ...),
|
||||
"Anything UNCLEAR": (str, ...),
|
||||
}
|
||||
t_text1 = '''## Original Requirements:
|
||||
t_text1 = """## Original Requirements:
|
||||
|
||||
The boss wants to create a web-based version of the game "Fly Bird".
|
||||
|
||||
|
|
@ -284,7 +286,7 @@ The product should be a web-based version of the game "Fly Bird" that is engagin
|
|||
## Anything UNCLEAR:
|
||||
|
||||
There are no unclear points.
|
||||
'''
|
||||
"""
|
||||
d = OutputParser.parse_data_with_mapping(t_text1, OUTPUT_MAPPING)
|
||||
import json
|
||||
|
||||
|
|
|
|||
|
|
@ -52,9 +52,11 @@ PAGE = """
|
|||
</html>
|
||||
"""
|
||||
|
||||
CONTENT = 'This is a HeadingThis is a paragraph witha linkand someemphasizedtext.Item 1Item 2Item 3Numbered Item 1Numbered '\
|
||||
'Item 2Numbered Item 3Header 1Header 2Row 1, Cell 1Row 1, Cell 2Row 2, Cell 1Row 2, Cell 2Name:Email:SubmitThis is a div '\
|
||||
'with a class "box".a link'
|
||||
CONTENT = (
|
||||
"This is a HeadingThis is a paragraph witha linkand someemphasizedtext.Item 1Item 2Item 3Numbered Item 1Numbered "
|
||||
"Item 2Numbered Item 3Header 1Header 2Row 1, Cell 1Row 1, Cell 2Row 2, Cell 1Row 2, Cell 2Name:Email:SubmitThis is a div "
|
||||
'with a class "box".a link'
|
||||
)
|
||||
|
||||
|
||||
def test_web_page():
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from metagpt.utils import pycst
|
||||
|
||||
code = '''
|
||||
code = """
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import overload
|
||||
|
|
@ -24,7 +24,7 @@ class Person:
|
|||
|
||||
def greet(self):
|
||||
return f"Hello, my name is {self.name} and I am {self.age} years old."
|
||||
'''
|
||||
"""
|
||||
|
||||
documented_code = '''
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ def _paragraphs(n):
|
|||
(_msgs(), "gpt-4", "Hello," * 1000, 2000, 2),
|
||||
(_msgs(), "gpt-4-32k", "System", 4000, 14),
|
||||
(_msgs(), "gpt-4-32k", "Hello," * 2000, 4000, 12),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
|
||||
assert len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000 == expected
|
||||
|
|
@ -42,7 +42,7 @@ def test_reduce_message_length(msgs, model_name, system_text, reserved, expected
|
|||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
|
||||
(" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
|
||||
(" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
|
||||
ret = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved))
|
||||
|
|
@ -58,7 +58,7 @@ def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, r
|
|||
("......", ".", 2, ["...", "..."]),
|
||||
("......", ".", 3, ["..", "..", ".."]),
|
||||
(".......", ".", 2, ["....", "..."]),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_split_paragraph(paragraph, sep, count, expected):
|
||||
ret = split_paragraph(paragraph, sep, count)
|
||||
|
|
@ -71,7 +71,7 @@ def test_split_paragraph(paragraph, sep, count, expected):
|
|||
("Hello\\nWorld", "Hello\nWorld"),
|
||||
("Hello\\tWorld", "Hello\tWorld"),
|
||||
("Hello\\u0020World", "Hello World"),
|
||||
]
|
||||
],
|
||||
)
|
||||
def test_decode_unicode_escape(text, expected):
|
||||
assert decode_unicode_escape(text) == expected
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue