mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-01 11:56:24 +02:00
Merge branch 'feature/teacher' into feature/fork_meta_role
This commit is contained in:
commit
145ffc7048
66 changed files with 2093 additions and 547 deletions
|
|
@ -6,24 +6,23 @@
|
|||
@File : test_run_code.py
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
|
||||
from metagpt.actions.run_code import RunCode
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_text():
|
||||
action = RunCode()
|
||||
result, errs = await RunCode.run_text('result = 1 + 1')
|
||||
result, errs = await RunCode.run_text("result = 1 + 1")
|
||||
assert result == 2
|
||||
assert errs == ""
|
||||
|
||||
result, errs = await RunCode.run_text('result = 1 / 0')
|
||||
result, errs = await RunCode.run_text("result = 1 / 0")
|
||||
assert result == ""
|
||||
assert "ZeroDivisionError" in errs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_script():
|
||||
action = RunCode()
|
||||
|
||||
# Successful command
|
||||
out, err = await RunCode.run_script(".", command=["echo", "Hello World"])
|
||||
assert out.strip() == "Hello World"
|
||||
|
|
@ -33,6 +32,7 @@ async def test_run_script():
|
|||
out, err = await RunCode.run_script(".", command=["python", "-c", "print(1/0)"])
|
||||
assert "ZeroDivisionError" in err
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run():
|
||||
action = RunCode()
|
||||
|
|
@ -47,10 +47,11 @@ async def test_run():
|
|||
test_file_name="",
|
||||
command=["echo", "Hello World"],
|
||||
working_directory=".",
|
||||
additional_python_paths=[]
|
||||
additional_python_paths=[],
|
||||
)
|
||||
assert "PASS" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_failure():
|
||||
action = RunCode()
|
||||
|
|
@ -65,6 +66,6 @@ async def test_run_failure():
|
|||
test_file_name="",
|
||||
command=["python", "-c", "print(1/0)"],
|
||||
working_directory=".",
|
||||
additional_python_paths=[]
|
||||
additional_python_paths=[],
|
||||
)
|
||||
assert "FAIL" in result
|
||||
assert "FAIL" in result
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.write_code_review import WriteCodeReview
|
||||
from metagpt.logs import logger
|
||||
from tests.metagpt.actions.mock import SEARCH_CODE_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -20,11 +18,7 @@ def add(a, b):
|
|||
"""
|
||||
# write_code_review = WriteCodeReview("write_code_review")
|
||||
|
||||
code = await WriteCodeReview().run(
|
||||
context="编写一个从a加b的函数,返回a+b",
|
||||
code=code,
|
||||
filename="math.py"
|
||||
)
|
||||
code = await WriteCodeReview().run(context="编写一个从a加b的函数,返回a+b", code=code, filename="math.py")
|
||||
|
||||
# 我们不能精确地预测生成的代码评审,但我们可以检查返回的是否为字符串
|
||||
assert isinstance(code, str)
|
||||
|
|
@ -33,6 +27,7 @@ def add(a, b):
|
|||
captured = capfd.readouterr()
|
||||
print(f"输出内容: {captured.out}")
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_write_code_review_directly():
|
||||
# code = SEARCH_CODE_SAMPLE
|
||||
|
|
|
|||
77
tests/metagpt/document_store/test_qdrant_store.py
Normal file
77
tests/metagpt/document_store/test_qdrant_store.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/6/11 21:08
|
||||
@Author : hezhaozhao
|
||||
@File : test_qdrant_store.py
|
||||
"""
|
||||
import random
|
||||
|
||||
from qdrant_client.models import (
|
||||
Distance,
|
||||
FieldCondition,
|
||||
Filter,
|
||||
PointStruct,
|
||||
Range,
|
||||
VectorParams,
|
||||
)
|
||||
|
||||
from metagpt.document_store.qdrant_store import QdrantConnection, QdrantStore
|
||||
|
||||
seed_value = 42
|
||||
random.seed(seed_value)
|
||||
|
||||
vectors = [[random.random() for _ in range(2)] for _ in range(10)]
|
||||
|
||||
points = [
|
||||
PointStruct(
|
||||
id=idx, vector=vector, payload={"color": "red", "rand_number": idx % 10}
|
||||
)
|
||||
for idx, vector in enumerate(vectors)
|
||||
]
|
||||
|
||||
|
||||
def test_milvus_store():
|
||||
qdrant_connection = QdrantConnection(memory=True)
|
||||
vectors_config = VectorParams(size=2, distance=Distance.COSINE)
|
||||
qdrant_store = QdrantStore(qdrant_connection)
|
||||
qdrant_store.create_collection("Book", vectors_config, force_recreate=True)
|
||||
assert qdrant_store.has_collection("Book") is True
|
||||
qdrant_store.delete_collection("Book")
|
||||
assert qdrant_store.has_collection("Book") is False
|
||||
qdrant_store.create_collection("Book", vectors_config)
|
||||
assert qdrant_store.has_collection("Book") is True
|
||||
qdrant_store.add("Book", points)
|
||||
results = qdrant_store.search("Book", query=[1.0, 1.0])
|
||||
assert results[0]["id"] == 2
|
||||
assert results[0]["score"] == 0.999106722578389
|
||||
assert results[1]["score"] == 7
|
||||
assert results[1]["score"] == 0.9961650411397226
|
||||
results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
|
||||
assert results[0]["id"] == 2
|
||||
assert results[0]["score"] == 0.999106722578389
|
||||
assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125]
|
||||
assert results[1]["score"] == 7
|
||||
assert results[1]["score"] == 0.9961650411397226
|
||||
assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618]
|
||||
results = qdrant_store.search(
|
||||
"Book",
|
||||
query=[1.0, 1.0],
|
||||
query_filter=Filter(
|
||||
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
|
||||
),
|
||||
)
|
||||
assert results[0]["id"] == 8
|
||||
assert results[0]["score"] == 0.9100373450784073
|
||||
assert results[1]["id"] == 9
|
||||
assert results[1]["score"] == 0.7127610621127889
|
||||
results = qdrant_store.search(
|
||||
"Book",
|
||||
query=[1.0, 1.0],
|
||||
query_filter=Filter(
|
||||
must=[FieldCondition(key="rand_number", range=Range(gte=8))]
|
||||
),
|
||||
return_vector=True,
|
||||
)
|
||||
assert results[0]["vector"] == [0.35037919878959656, 0.9366079568862915]
|
||||
assert results[1]["vector"] == [0.9999677538871765, 0.00802854634821415]
|
||||
32
tests/metagpt/roles/test_researcher.py
Normal file
32
tests/metagpt/roles/test_researcher.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from pathlib import Path
|
||||
from random import random
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.roles import researcher
|
||||
|
||||
|
||||
async def mock_llm_ask(self, prompt: str, system_msgs):
|
||||
if "Please provide up to 2 necessary keywords" in prompt:
|
||||
return '["dataiku", "datarobot"]'
|
||||
elif "Provide up to 4 queries related to your research topic" in prompt:
|
||||
return '["Dataiku machine learning platform", "DataRobot AI platform comparison", ' \
|
||||
'"Dataiku vs DataRobot features", "Dataiku and DataRobot use cases"]'
|
||||
elif "sort the remaining search results" in prompt:
|
||||
return '[1,2]'
|
||||
elif "Not relevant." in prompt:
|
||||
return "Not relevant" if random() > 0.5 else prompt[-100:]
|
||||
elif "provide a detailed research report" in prompt:
|
||||
return f"# Research Report\n## Introduction\n{prompt}"
|
||||
return ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_researcher(mocker):
|
||||
with TemporaryDirectory() as dirname:
|
||||
topic = "dataiku vs. datarobot"
|
||||
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
|
||||
researcher.RESEARCH_PATH = Path(dirname)
|
||||
await researcher.Researcher().run(topic)
|
||||
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
|
||||
|
|
@ -2,22 +2,19 @@
|
|||
# @Date : 2023/7/15 16:40
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import re
|
||||
import os
|
||||
from importlib import import_module
|
||||
import re
|
||||
from functools import wraps
|
||||
from importlib import import_module
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.roles import ProductManager, Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.actions import Action, ActionOutput, WritePRD
|
||||
from metagpt.const import WORKSPACE_ROOT
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.software_company import SoftwareCompany
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools.sd_engine import SDEngine
|
||||
|
||||
PROMPT_TEMPLATE = '''
|
||||
PROMPT_TEMPLATE = """
|
||||
# Context
|
||||
{context}
|
||||
|
||||
|
|
@ -34,9 +31,9 @@ Attention: Use '##' to split sections, not '#', and '## <SECTION_NAME>' SHOULD W
|
|||
## CSS Styles (styles.css):Provide as Plain text,use standard css code
|
||||
## Anything UNCLEAR:Provide as Plain text. Make clear here.
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
FORMAT_EXAMPLE = '''
|
||||
FORMAT_EXAMPLE = """
|
||||
|
||||
## UI Design Description
|
||||
```Snake games are classic and addictive games with simple yet engaging elements. Here are the main elements commonly found in snake games ```
|
||||
|
|
@ -126,7 +123,7 @@ body {
|
|||
## Anything UNCLEAR
|
||||
There are no unclear points.
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
OUTPUT_MAPPING = {
|
||||
"UI Design Description": (str, ...),
|
||||
|
|
@ -139,25 +136,25 @@ OUTPUT_MAPPING = {
|
|||
|
||||
def load_engine(func):
|
||||
"""Decorator to load an engine by file name and engine name."""
|
||||
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
file_name, engine_name = func(*args, **kwargs)
|
||||
engine_file = import_module(file_name, package='metagpt')
|
||||
engine_file = import_module(file_name, package="metagpt")
|
||||
ip_module_cls = getattr(engine_file, engine_name)
|
||||
try:
|
||||
engine = ip_module_cls()
|
||||
except:
|
||||
engine = None
|
||||
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def parse(func):
|
||||
"""Decorator to parse information using regex pattern."""
|
||||
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
context, pattern = func(*args, **kwargs)
|
||||
|
|
@ -168,30 +165,30 @@ def parse(func):
|
|||
else:
|
||||
text_info = context
|
||||
logger.info("未找到匹配的内容")
|
||||
|
||||
|
||||
return text_info
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class UIDesign(Action):
|
||||
"""Class representing the UI Design action."""
|
||||
|
||||
|
||||
def __init__(self, name, context=None, llm=None):
|
||||
super().__init__(name, context, llm) # 需要调用LLM进一步丰富UI设计的prompt
|
||||
|
||||
|
||||
@parse
|
||||
def parse_requirement(self, context: str):
|
||||
"""Parse UI Design draft from the context using regex."""
|
||||
pattern = r"## UI Design draft.*?\n(.*?)## Anything UNCLEAR"
|
||||
return context, pattern
|
||||
|
||||
|
||||
@parse
|
||||
def parse_ui_elements(self, context: str):
|
||||
"""Parse Selected Elements from the context using regex."""
|
||||
pattern = r"## Selected Elements.*?\n(.*?)## HTML Layout"
|
||||
return context, pattern
|
||||
|
||||
|
||||
@parse
|
||||
def parse_css_code(self, context: str):
|
||||
pattern = r"```css.*?\n(.*?)## Anything UNCLEAR"
|
||||
|
|
@ -201,7 +198,7 @@ class UIDesign(Action):
|
|||
def parse_html_code(self, context: str):
|
||||
pattern = r"```html.*?\n(.*?)```"
|
||||
return context, pattern
|
||||
|
||||
|
||||
async def draw_icons(self, context, *args, **kwargs):
|
||||
"""Draw icons using SDEngine."""
|
||||
engine = SDEngine()
|
||||
|
|
@ -215,20 +212,20 @@ class UIDesign(Action):
|
|||
prompts_batch.append(prompt)
|
||||
await engine.run_t2i(prompts_batch)
|
||||
logger.info("Finish icon design using StableDiffusion API")
|
||||
|
||||
|
||||
async def _save(self, css_content, html_content):
|
||||
save_dir = WORKSPACE_ROOT / "resources" / 'codes'
|
||||
save_dir = WORKSPACE_ROOT / "resources" / "codes"
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# Save CSS and HTML content to files
|
||||
css_file_path = save_dir / f"ui_design.css"
|
||||
html_file_path = save_dir / f"ui_design.html"
|
||||
|
||||
with open(css_file_path, 'w') as css_file:
|
||||
css_file_path = save_dir / "ui_design.css"
|
||||
html_file_path = save_dir / "ui_design.html"
|
||||
|
||||
with open(css_file_path, "w") as css_file:
|
||||
css_file.write(css_content)
|
||||
with open(html_file_path, 'w') as html_file:
|
||||
with open(html_file_path, "w") as html_file:
|
||||
html_file.write(html_content)
|
||||
|
||||
|
||||
async def run(self, requirements: list[Message], *args, **kwargs) -> ActionOutput:
|
||||
"""Run the UI Design action."""
|
||||
# fixme: update prompt (根据需求细化prompt)
|
||||
|
|
@ -249,23 +246,27 @@ class UIDesign(Action):
|
|||
|
||||
class UI(Role):
|
||||
"""Class representing the UI Role."""
|
||||
|
||||
def __init__(self, name="Catherine", profile="UI Design",
|
||||
goal="Finish a workable and good User Interface design based on a product design",
|
||||
constraints="Give clear layout description and use standard icons to finish the design",
|
||||
skills=["SD"]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name="Catherine",
|
||||
profile="UI Design",
|
||||
goal="Finish a workable and good User Interface design based on a product design",
|
||||
constraints="Give clear layout description and use standard icons to finish the design",
|
||||
skills=["SD"],
|
||||
):
|
||||
super().__init__(name, profile, goal, constraints)
|
||||
self.load_skills(skills)
|
||||
self._init_actions([UIDesign])
|
||||
self._watch([WritePRD])
|
||||
|
||||
|
||||
@load_engine
|
||||
def load_sd_engine(self):
|
||||
"""Load the SDEngine."""
|
||||
file_name = ".tools.sd_engine"
|
||||
engine_name = "SDEngine"
|
||||
return file_name, engine_name
|
||||
|
||||
|
||||
def load_skills(self, skills):
|
||||
"""Load skills for the UI Role."""
|
||||
# todo: 添加其他出图engine
|
||||
|
|
@ -273,4 +274,3 @@ class UI(Role):
|
|||
if skill == "SD":
|
||||
self.sd_engine = self.load_sd_engine()
|
||||
logger.info(f"load skill engine {self.sd_engine}")
|
||||
|
||||
|
|
|
|||
|
|
@ -5,24 +5,44 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_search_engine.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
|
||||
class MockSearchEnine:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
|
||||
rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)]
|
||||
return "\n".join(rets) if as_string else rets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("llm_api")
|
||||
async def test_search_engine(llm_api):
|
||||
search_engine = SearchEngine()
|
||||
poetries = [
|
||||
# ("北京美食", "北京"),
|
||||
("屈臣氏", "屈臣氏")
|
||||
]
|
||||
for i, j in poetries:
|
||||
rsp = await search_engine.run(i)
|
||||
# rsp = context.llm.ask_batch([prompt])
|
||||
logger.info(rsp)
|
||||
# assert any(j in k['body'] for k in rsp)
|
||||
assert len(rsp) > 0
|
||||
@pytest.mark.parametrize(
|
||||
("search_engine_typpe", "run_func", "max_results", "as_string"),
|
||||
[
|
||||
(SearchEngineType.SERPAPI_GOOGLE, None, 8, True),
|
||||
(SearchEngineType.SERPAPI_GOOGLE, None, 4, False),
|
||||
(SearchEngineType.DIRECT_GOOGLE, None, 8, True),
|
||||
(SearchEngineType.DIRECT_GOOGLE, None, 6, False),
|
||||
(SearchEngineType.SERPER_GOOGLE, None, 8, True),
|
||||
(SearchEngineType.SERPER_GOOGLE, None, 6, False),
|
||||
(SearchEngineType.DUCK_DUCK_GO, None, 8, True),
|
||||
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
|
||||
|
||||
],
|
||||
)
|
||||
async def test_search_engine(search_engine_typpe, run_func, max_results, as_string, ):
|
||||
search_engine = SearchEngine(search_engine_typpe, run_func)
|
||||
rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
|
||||
logger.info(rsp)
|
||||
if as_string:
|
||||
assert isinstance(rsp, str)
|
||||
else:
|
||||
assert isinstance(rsp, list)
|
||||
assert len(rsp) == max_results
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
from metagpt.config import Config
|
||||
from metagpt.tools import web_browser_engine, WebBrowserEngineType
|
||||
|
||||
from metagpt.tools import WebBrowserEngineType, web_browser_engine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import web_browser_engine_playwright
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy
|
|||
CONFIG.global_proxy = proxy
|
||||
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type, **kwagrs)
|
||||
result = await browser.run(url)
|
||||
result = result.inner_text
|
||||
assert isinstance(result, str)
|
||||
assert "Deepwisdom" in result
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import web_browser_engine_selenium
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd)
|
|||
CONFIG.global_proxy = proxy
|
||||
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type)
|
||||
result = await browser.run(url)
|
||||
result = result.inner_text
|
||||
assert isinstance(result, str)
|
||||
assert "Deepwisdom" in result
|
||||
|
||||
|
|
@ -27,7 +29,7 @@ async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd)
|
|||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("Deepwisdom" in i) for i in results)
|
||||
assert all(("Deepwisdom" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
finally:
|
||||
|
|
|
|||
68
tests/metagpt/utils/test_parse_html.py
Normal file
68
tests/metagpt/utils/test_parse_html.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
from metagpt.utils import parse_html
|
||||
|
||||
PAGE = """
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Random HTML Example</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>This is a Heading</h1>
|
||||
<p>This is a paragraph with <a href="test">a link</a> and some <em>emphasized</em> text.</p>
|
||||
<ul>
|
||||
<li>Item 1</li>
|
||||
<li>Item 2</li>
|
||||
<li>Item 3</li>
|
||||
</ul>
|
||||
<ol>
|
||||
<li>Numbered Item 1</li>
|
||||
<li>Numbered Item 2</li>
|
||||
<li>Numbered Item 3</li>
|
||||
</ol>
|
||||
<table>
|
||||
<tr>
|
||||
<th>Header 1</th>
|
||||
<th>Header 2</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Row 1, Cell 1</td>
|
||||
<td>Row 1, Cell 2</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Row 2, Cell 1</td>
|
||||
<td>Row 2, Cell 2</td>
|
||||
</tr>
|
||||
</table>
|
||||
<img src="image.jpg" alt="Sample Image">
|
||||
<form action="/submit" method="post">
|
||||
<label for="name">Name:</label>
|
||||
<input type="text" id="name" name="name" required>
|
||||
<label for="email">Email:</label>
|
||||
<input type="email" id="email" name="email" required>
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
<div class="box">
|
||||
<p>This is a div with a class "box".</p>
|
||||
<p><a href="https://metagpt.com">a link</a></p>
|
||||
<p><a href="#section2"></a></p>
|
||||
<p><a href="ftp://192.168.1.1:8080"></a></p>
|
||||
<p><a href="javascript:alert('Hello');"></a></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
CONTENT = 'This is a HeadingThis is a paragraph witha linkand someemphasizedtext.Item 1Item 2Item 3Numbered Item 1Numbered '\
|
||||
'Item 2Numbered Item 3Header 1Header 2Row 1, Cell 1Row 1, Cell 2Row 2, Cell 1Row 2, Cell 2Name:Email:SubmitThis is a div '\
|
||||
'with a class "box".a link'
|
||||
|
||||
|
||||
def test_web_page():
|
||||
page = parse_html.WebPage(inner_text=CONTENT, html=PAGE, url="http://example.com")
|
||||
assert page.title == "Random HTML Example"
|
||||
assert list(page.get_links()) == ["http://example.com/test", "https://metagpt.com"]
|
||||
|
||||
|
||||
def test_get_page_content():
|
||||
ret = parse_html.get_html_content(PAGE, "http://example.com")
|
||||
assert ret == CONTENT
|
||||
|
|
@ -3,94 +3,64 @@
|
|||
# @Desc : the unittest of serialize
|
||||
|
||||
from typing import List, Tuple
|
||||
import pytest
|
||||
|
||||
from pydantic import create_model
|
||||
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.serialize import actionoutout_schema_to_mapping, serialize_message, deserialize_message
|
||||
from metagpt.utils.serialize import (
|
||||
actionoutout_schema_to_mapping,
|
||||
deserialize_message,
|
||||
serialize_message,
|
||||
)
|
||||
|
||||
|
||||
def test_actionoutout_schema_to_mapping():
|
||||
schema = {
|
||||
'title': 'test',
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'field': {
|
||||
'title': 'field',
|
||||
'type': 'string'
|
||||
}
|
||||
}
|
||||
}
|
||||
schema = {"title": "test", "type": "object", "properties": {"field": {"title": "field", "type": "string"}}}
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
assert mapping['field'] == (str, ...)
|
||||
assert mapping["field"] == (str, ...)
|
||||
|
||||
schema = {
|
||||
'title': 'test',
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'field': {
|
||||
'title': 'field',
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'type': 'string'
|
||||
}
|
||||
}
|
||||
}
|
||||
"title": "test",
|
||||
"type": "object",
|
||||
"properties": {"field": {"title": "field", "type": "array", "items": {"type": "string"}}},
|
||||
}
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
assert mapping['field'] == (List[str], ...)
|
||||
assert mapping["field"] == (List[str], ...)
|
||||
|
||||
schema = {
|
||||
'title': 'test',
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'field': {
|
||||
'title': 'field',
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'type': 'array',
|
||||
'minItems': 2,
|
||||
'maxItems': 2,
|
||||
'items': [
|
||||
{
|
||||
'type': 'string'
|
||||
},
|
||||
{
|
||||
'type': 'string'
|
||||
}
|
||||
]
|
||||
}
|
||||
"title": "test",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"title": "field",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "array",
|
||||
"minItems": 2,
|
||||
"maxItems": 2,
|
||||
"items": [{"type": "string"}, {"type": "string"}],
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
assert mapping['field'] == (List[Tuple[str, str]], ...)
|
||||
assert mapping["field"] == (List[Tuple[str, str]], ...)
|
||||
|
||||
assert True, True
|
||||
|
||||
|
||||
def test_serialize_and_deserialize_message():
|
||||
out_mapping = {
|
||||
'field1': (str, ...),
|
||||
'field2': (List[str], ...)
|
||||
}
|
||||
out_data = {
|
||||
'field1': 'field1 value',
|
||||
'field2': ['field2 value1', 'field2 value2']
|
||||
}
|
||||
ic_obj = ActionOutput.create_model_class('prd', out_mapping)
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionOutput.create_model_class("prd", out_mapping)
|
||||
|
||||
message = Message(content='prd demand',
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role='user',
|
||||
cause_by=WritePRD) # WritePRD as test action
|
||||
message = Message(
|
||||
content="prd demand", instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
) # WritePRD as test action
|
||||
|
||||
message_ser = serialize_message(message)
|
||||
|
||||
new_message = deserialize_message(message_ser)
|
||||
assert new_message.content == message.content
|
||||
assert new_message.cause_by == message.cause_by
|
||||
assert new_message.instruct_content.field1 == out_data['field1']
|
||||
assert new_message.instruct_content.field1 == out_data["field1"]
|
||||
|
|
|
|||
77
tests/metagpt/utils/test_text.py
Normal file
77
tests/metagpt/utils/test_text.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.utils.text import (
|
||||
decode_unicode_escape,
|
||||
generate_prompt_chunk,
|
||||
reduce_message_length,
|
||||
split_paragraph,
|
||||
)
|
||||
|
||||
|
||||
def _msgs():
|
||||
length = 20
|
||||
while length:
|
||||
yield "Hello," * 1000 * length
|
||||
length -= 1
|
||||
|
||||
|
||||
def _paragraphs(n):
|
||||
return " ".join("Hello World." for _ in range(n))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"msgs, model_name, system_text, reserved, expected",
|
||||
[
|
||||
(_msgs(), "gpt-3.5-turbo", "System", 1500, 1),
|
||||
(_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
|
||||
(_msgs(), "gpt-3.5-turbo-16k", "Hello," * 1000, 3000, 5),
|
||||
(_msgs(), "gpt-4", "System", 2000, 3),
|
||||
(_msgs(), "gpt-4", "Hello," * 1000, 2000, 2),
|
||||
(_msgs(), "gpt-4-32k", "System", 4000, 14),
|
||||
(_msgs(), "gpt-4-32k", "Hello," * 2000, 4000, 12),
|
||||
]
|
||||
)
|
||||
def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
|
||||
assert len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000 == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, prompt_template, model_name, system_text, reserved, expected",
|
||||
[
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1500, 2),
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
|
||||
(" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
|
||||
(" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
|
||||
]
|
||||
)
|
||||
def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
|
||||
ret = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved))
|
||||
assert len(ret) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"paragraph, sep, count, expected",
|
||||
[
|
||||
(_paragraphs(10), ".", 2, [_paragraphs(5), f" {_paragraphs(5)}"]),
|
||||
(_paragraphs(10), ".", 3, [_paragraphs(4), f" {_paragraphs(3)}", f" {_paragraphs(3)}"]),
|
||||
(f"{_paragraphs(5)}\n{_paragraphs(3)}", "\n.", 2, [f"{_paragraphs(5)}\n", _paragraphs(3)]),
|
||||
("......", ".", 2, ["...", "..."]),
|
||||
("......", ".", 3, ["..", "..", ".."]),
|
||||
(".......", ".", 2, ["....", "..."]),
|
||||
]
|
||||
)
|
||||
def test_split_paragraph(paragraph, sep, count, expected):
|
||||
ret = split_paragraph(paragraph, sep, count)
|
||||
assert ret == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, expected",
|
||||
[
|
||||
("Hello\\nWorld", "Hello\nWorld"),
|
||||
("Hello\\tWorld", "Hello\tWorld"),
|
||||
("Hello\\u0020World", "Hello World"),
|
||||
]
|
||||
)
|
||||
def test_decode_unicode_escape(text, expected):
|
||||
assert decode_unicode_escape(text) == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue