mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 00:36:55 +02:00
Merge pull request #736 from garylin2099/code_intepreter
Integrate CodeIntepreter
This commit is contained in:
commit
23c27627ce
61 changed files with 5076 additions and 84 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -172,8 +172,10 @@ tests/metagpt/utils/file_repo_git
|
|||
*.png
|
||||
htmlcov
|
||||
htmlcov.*
|
||||
cov.xml
|
||||
*.dot
|
||||
*.pkl
|
||||
*.faiss
|
||||
*-structure.csv
|
||||
*-structure.json
|
||||
metagpt/tools/schemas
|
||||
22
examples/crawl_webpage.py
Normal file
22
examples/crawl_webpage.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
"""
|
||||
@Date : 2024/01/24 15:11:27
|
||||
@Author : orange-crow
|
||||
@File : crawl_webpage.py
|
||||
"""
|
||||
|
||||
from metagpt.roles.ci.code_interpreter import CodeInterpreter
|
||||
|
||||
|
||||
async def main():
|
||||
prompt = """Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/,
|
||||
and save it to a csv file. paper title must include `multiagent` or `large language model`. *notice: print key data*"""
|
||||
ci = CodeInterpreter(goal=prompt, use_tools=True)
|
||||
|
||||
await ci.run(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
26
examples/imitate_webpage.py
Normal file
26
examples/imitate_webpage.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/01/15
|
||||
@Author : mannaandpoem
|
||||
@File : imitate_webpage.py
|
||||
"""
|
||||
from metagpt.roles.ci.code_interpreter import CodeInterpreter
|
||||
|
||||
|
||||
async def main():
|
||||
web_url = "https://pytorch.org/"
|
||||
prompt = f"""This is a URL of webpage: '{web_url}' .
|
||||
Firstly, utilize Selenium and WebDriver for rendering.
|
||||
Secondly, convert image to a webpage including HTML, CSS and JS in one go.
|
||||
Finally, save webpage in a text file.
|
||||
Note: All required dependencies and environments have been fully installed and configured."""
|
||||
ci = CodeInterpreter(goal=prompt, use_tools=True)
|
||||
|
||||
await ci.run(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
21
examples/sd_tool_usage.py
Normal file
21
examples/sd_tool_usage.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 1/11/2024 7:06 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import asyncio
|
||||
|
||||
from metagpt.roles.ci.code_interpreter import CodeInterpreter
|
||||
|
||||
|
||||
async def main(requirement: str = ""):
|
||||
code_interpreter = CodeInterpreter(use_tools=True, goal=requirement)
|
||||
await code_interpreter.run(requirement)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sd_url = "http://your.sd.service.ip:port"
|
||||
requirement = (
|
||||
f"I want to generate an image of a beautiful girl using the stable diffusion text2image tool, sd_url={sd_url}"
|
||||
)
|
||||
|
||||
asyncio.run(main(requirement))
|
||||
|
|
@ -22,6 +22,9 @@ from metagpt.actions.write_code_review import WriteCodeReview
|
|||
from metagpt.actions.write_prd import WritePRD
|
||||
from metagpt.actions.write_prd_review import WritePRDReview
|
||||
from metagpt.actions.write_test import WriteTest
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.ci.write_analysis_code import WriteCodeWithoutTools, WriteCodeWithTools
|
||||
from metagpt.actions.ci.write_plan import WritePlan
|
||||
|
||||
|
||||
class ActionType(Enum):
|
||||
|
|
@ -42,6 +45,10 @@ class ActionType(Enum):
|
|||
COLLECT_LINKS = CollectLinks
|
||||
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
|
||||
CONDUCT_RESEARCH = ConductResearch
|
||||
EXECUTE_NB_CODE = ExecuteNbCode
|
||||
WRITE_CODE_WITHOUT_TOOLS = WriteCodeWithoutTools
|
||||
WRITE_CODE_WITH_TOOLS = WriteCodeWithTools
|
||||
WRITE_PLAN = WritePlan
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
|
|
|||
62
metagpt/actions/ci/ask_review.py
Normal file
62
metagpt/actions/ci/ask_review.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message, Plan
|
||||
|
||||
|
||||
class ReviewConst:
|
||||
TASK_REVIEW_TRIGGER = "task"
|
||||
CODE_REVIEW_TRIGGER = "code"
|
||||
CONTINUE_WORDS = ["confirm", "continue", "c", "yes", "y"]
|
||||
CHANGE_WORDS = ["change"]
|
||||
EXIT_WORDS = ["exit"]
|
||||
TASK_REVIEW_INSTRUCTION = (
|
||||
f"If you want to change, add, delete a task or merge tasks in the plan, say '{CHANGE_WORDS[0]} task task_id or current task, ... (things to change)' "
|
||||
f"If you confirm the output from the current task and wish to continue, type: {CONTINUE_WORDS[0]}"
|
||||
)
|
||||
CODE_REVIEW_INSTRUCTION = (
|
||||
f"If you want the codes to be rewritten, say '{CHANGE_WORDS[0]} ... (your change advice)' "
|
||||
f"If you want to leave it as is, type: {CONTINUE_WORDS[0]} or {CONTINUE_WORDS[1]}"
|
||||
)
|
||||
EXIT_INSTRUCTION = f"If you want to terminate the process, type: {EXIT_WORDS[0]}"
|
||||
|
||||
|
||||
class AskReview(Action):
|
||||
async def run(
|
||||
self, context: list[Message] = [], plan: Plan = None, trigger: str = ReviewConst.TASK_REVIEW_TRIGGER
|
||||
) -> Tuple[str, bool]:
|
||||
if plan:
|
||||
logger.info("Current overall plan:")
|
||||
logger.info(
|
||||
"\n".join(
|
||||
[f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}" for task in plan.tasks]
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Most recent context:")
|
||||
latest_action = context[-1].cause_by if context and context[-1].cause_by else ""
|
||||
review_instruction = (
|
||||
ReviewConst.TASK_REVIEW_INSTRUCTION
|
||||
if trigger == ReviewConst.TASK_REVIEW_TRIGGER
|
||||
else ReviewConst.CODE_REVIEW_INSTRUCTION
|
||||
)
|
||||
prompt = (
|
||||
f"This is a <{trigger}> review. Please review output from {latest_action}\n"
|
||||
f"{review_instruction}\n"
|
||||
f"{ReviewConst.EXIT_INSTRUCTION}\n"
|
||||
"Please type your review below:\n"
|
||||
)
|
||||
|
||||
rsp = input(prompt)
|
||||
|
||||
if rsp.lower() in ReviewConst.EXIT_WORDS:
|
||||
exit()
|
||||
|
||||
# Confirmation can be one of "confirm", "continue", "c", "yes", "y" exactly, or sentences containing "confirm".
|
||||
# One could say "confirm this task, but change the next task to ..."
|
||||
confirmed = rsp.lower() in ReviewConst.CONTINUE_WORDS or ReviewConst.CONTINUE_WORDS[0] in rsp.lower()
|
||||
|
||||
return rsp, confirmed
|
||||
109
metagpt/actions/ci/debug_code.py
Normal file
109
metagpt/actions/ci/debug_code.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from metagpt.actions.ci.write_analysis_code import BaseWriteAnalysisCode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import create_func_call_config
|
||||
|
||||
DEBUG_REFLECTION_EXAMPLE = '''
|
||||
Example 1:
|
||||
[previous impl]:
|
||||
```python
|
||||
def add(a: int, b: int) -> int:
|
||||
"""
|
||||
Given integers a and b, return the total value of a and b.
|
||||
"""
|
||||
return a - b
|
||||
```
|
||||
|
||||
[runtime Error]:
|
||||
Tested passed:
|
||||
|
||||
Tests failed:
|
||||
assert add(1, 2) == 3 # output: -1
|
||||
assert add(1, 2) == 4 # output: -1
|
||||
|
||||
[reflection on previous impl]:
|
||||
The implementation failed the test cases where the input integers are 1 and 2. The issue arises because the code does not add the two integers together, but instead subtracts the second integer from the first. To fix this issue, we should change the operator from `-` to `+` in the return statement. This will ensure that the function returns the correct output for the given input.
|
||||
|
||||
[improved impl]:
|
||||
```python
|
||||
def add(a: int, b: int) -> int:
|
||||
"""
|
||||
Given integers a and b, return the total value of a and b.
|
||||
"""
|
||||
return a + b
|
||||
```
|
||||
'''
|
||||
|
||||
REFLECTION_PROMPT = """
|
||||
Here is an example for you.
|
||||
{debug_example}
|
||||
[context]
|
||||
{context}
|
||||
|
||||
[previous impl]
|
||||
{code}
|
||||
[runtime Error]
|
||||
{runtime_result}
|
||||
|
||||
Analysis the error step by step, provide me improve method and code. Remember to follow [context] requirement. Don't forget write code for steps behind the error step.
|
||||
[reflection on previous impl]:
|
||||
xxx
|
||||
"""
|
||||
|
||||
CODE_REFLECTION = {
|
||||
"name": "execute_reflection_code",
|
||||
"description": "Execute reflection code.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reflection": {
|
||||
"type": "string",
|
||||
"description": "Reflection on previous impl.",
|
||||
},
|
||||
"improved_impl": {
|
||||
"type": "string",
|
||||
"description": "Refined code after reflection.",
|
||||
},
|
||||
},
|
||||
"required": ["reflection", "improved_impl"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DebugCode(BaseWriteAnalysisCode):
|
||||
async def run(
|
||||
self,
|
||||
context: list[Message] = None,
|
||||
code: str = "",
|
||||
runtime_result: str = "",
|
||||
) -> str:
|
||||
"""
|
||||
Execute the debugging process based on the provided context, code, and runtime_result.
|
||||
|
||||
Args:
|
||||
context (list[Message]): A list of Message objects representing the context.
|
||||
code (str): The code to be debugged.
|
||||
runtime_result (str): The result of the code execution.
|
||||
|
||||
Returns:
|
||||
str: The improved implementation based on the debugging process.
|
||||
"""
|
||||
|
||||
info = []
|
||||
reflection_prompt = REFLECTION_PROMPT.format(
|
||||
debug_example=DEBUG_REFLECTION_EXAMPLE,
|
||||
context=context,
|
||||
code=code,
|
||||
runtime_result=runtime_result,
|
||||
)
|
||||
system_prompt = "You are an AI Python assistant. You will be given your previous implementation code of a task, runtime error results, and a hint to change the implementation appropriately. Write your full implementation "
|
||||
info.append(Message(role="system", content=system_prompt))
|
||||
info.append(Message(role="user", content=reflection_prompt))
|
||||
|
||||
tool_config = create_func_call_config(CODE_REFLECTION)
|
||||
reflection = await self.llm.aask_code(messages=info, **tool_config)
|
||||
logger.info(f"reflection is {reflection}")
|
||||
|
||||
return {"code": reflection["improved_impl"]}
|
||||
249
metagpt/actions/ci/execute_nb_code.py
Normal file
249
metagpt/actions/ci/execute_nb_code.py
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
"""
|
||||
@Date : 2023/11/17 14:22:15
|
||||
@Author : orange-crow
|
||||
@File : execute_nb_code.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import re
|
||||
import traceback
|
||||
from typing import Literal, Tuple
|
||||
|
||||
import nbformat
|
||||
from nbclient import NotebookClient
|
||||
from nbclient.exceptions import CellTimeoutError, DeadKernelError
|
||||
from nbformat import NotebookNode
|
||||
from nbformat.v4 import new_code_cell, new_markdown_cell, new_output
|
||||
from rich.box import MINIMAL
|
||||
from rich.console import Console, Group
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.syntax import Syntax
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class ExecuteNbCode(Action):
|
||||
"""execute notebook code block, return result to llm, and display it."""
|
||||
|
||||
nb: NotebookNode
|
||||
nb_client: NotebookClient
|
||||
console: Console
|
||||
interaction: str
|
||||
timeout: int = 600
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nb=nbformat.v4.new_notebook(),
|
||||
timeout=600,
|
||||
):
|
||||
super().__init__(
|
||||
nb=nb,
|
||||
nb_client=NotebookClient(nb, timeout=timeout),
|
||||
timeout=timeout,
|
||||
console=Console(),
|
||||
interaction=("ipython" if self.is_ipython() else "terminal"),
|
||||
)
|
||||
|
||||
async def build(self):
|
||||
if self.nb_client.kc is None or not await self.nb_client.kc.is_alive():
|
||||
self.nb_client.create_kernel_manager()
|
||||
self.nb_client.start_new_kernel()
|
||||
self.nb_client.start_new_kernel_client()
|
||||
|
||||
async def terminate(self):
|
||||
"""kill NotebookClient"""
|
||||
await self.nb_client._async_cleanup_kernel()
|
||||
|
||||
async def reset(self):
|
||||
"""reset NotebookClient"""
|
||||
await self.terminate()
|
||||
|
||||
# sleep 1s to wait for the kernel to be cleaned up completely
|
||||
await asyncio.sleep(1)
|
||||
await self.build()
|
||||
self.nb_client = NotebookClient(self.nb, timeout=self.timeout)
|
||||
|
||||
def add_code_cell(self, code: str):
|
||||
self.nb.cells.append(new_code_cell(source=code))
|
||||
|
||||
def add_markdown_cell(self, markdown: str):
|
||||
self.nb.cells.append(new_markdown_cell(source=markdown))
|
||||
|
||||
def _display(self, code: str, language: Literal["python", "markdown"] = "python"):
|
||||
if language == "python":
|
||||
code = Syntax(code, "python", theme="paraiso-dark", line_numbers=True)
|
||||
self.console.print(code)
|
||||
elif language == "markdown":
|
||||
display_markdown(code)
|
||||
else:
|
||||
raise ValueError(f"Only support for python, markdown, but got {language}")
|
||||
|
||||
def add_output_to_cell(self, cell: NotebookNode, output: str):
|
||||
"""add outputs of code execution to notebook cell."""
|
||||
if "outputs" not in cell:
|
||||
cell["outputs"] = []
|
||||
else:
|
||||
cell["outputs"].append(new_output(output_type="stream", name="stdout", text=str(output)))
|
||||
|
||||
def parse_outputs(self, outputs: list[str]) -> str:
|
||||
"""Parses the outputs received from notebook execution."""
|
||||
assert isinstance(outputs, list)
|
||||
parsed_output = ""
|
||||
|
||||
for i, output in enumerate(outputs):
|
||||
if output["output_type"] == "stream" and not any(
|
||||
tag in output["text"]
|
||||
for tag in ["| INFO | metagpt", "| ERROR | metagpt", "| WARNING | metagpt"]
|
||||
):
|
||||
parsed_output += output["text"]
|
||||
elif output["output_type"] == "display_data":
|
||||
if "image/png" in output["data"]:
|
||||
self.show_bytes_figure(output["data"]["image/png"], self.interaction)
|
||||
else:
|
||||
logger.info(
|
||||
f"{i}th output['data'] from nbclient outputs dont have image/png, continue next output ..."
|
||||
)
|
||||
elif output["output_type"] == "execute_result":
|
||||
parsed_output += output["data"]["text/plain"]
|
||||
return parsed_output
|
||||
|
||||
def show_bytes_figure(self, image_base64: str, interaction_type: Literal["ipython", None]):
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
if interaction_type == "ipython":
|
||||
from IPython.display import Image, display
|
||||
|
||||
display(Image(data=image_bytes))
|
||||
else:
|
||||
import io
|
||||
|
||||
from PIL import Image
|
||||
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image.show()
|
||||
|
||||
def is_ipython(self) -> bool:
|
||||
try:
|
||||
# 如果在Jupyter Notebook中运行,__file__ 变量不存在
|
||||
from IPython import get_ipython
|
||||
|
||||
if get_ipython() is not None and "IPKernelApp" in get_ipython().config:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except NameError:
|
||||
return False
|
||||
|
||||
async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str]:
|
||||
"""set timeout for run code.
|
||||
returns the success or failure of the cell execution, and an optional error message.
|
||||
"""
|
||||
try:
|
||||
await self.nb_client.async_execute_cell(cell, cell_index)
|
||||
return True, ""
|
||||
except CellTimeoutError:
|
||||
assert self.nb_client.km is not None
|
||||
await self.nb_client.km.interrupt_kernel()
|
||||
await asyncio.sleep(1)
|
||||
error_msg = "Cell execution timed out: Execution exceeded the time limit and was stopped; consider optimizing your code for better performance."
|
||||
return False, error_msg
|
||||
except DeadKernelError:
|
||||
await self.reset()
|
||||
return False, "DeadKernelError"
|
||||
except Exception:
|
||||
return False, f"{traceback.format_exc()}"
|
||||
|
||||
async def run(self, code: str, language: Literal["python", "markdown"] = "python") -> Tuple[str, bool]:
|
||||
"""
|
||||
return the output of code execution, and a success indicator (bool) of code execution.
|
||||
"""
|
||||
self._display(code, language)
|
||||
|
||||
if language == "python":
|
||||
# add code to the notebook
|
||||
self.add_code_cell(code=code)
|
||||
|
||||
# build code executor
|
||||
await self.build()
|
||||
|
||||
# run code
|
||||
cell_index = len(self.nb.cells) - 1
|
||||
success, error_message = await self.run_cell(self.nb.cells[-1], cell_index)
|
||||
|
||||
if not success:
|
||||
return truncate(remove_escape_and_color_codes(error_message), is_success=success)
|
||||
|
||||
# code success
|
||||
outputs = self.parse_outputs(self.nb.cells[-1].outputs)
|
||||
outputs, success = truncate(remove_escape_and_color_codes(outputs), is_success=success)
|
||||
|
||||
if "!pip" in outputs:
|
||||
success = False
|
||||
|
||||
return outputs, success
|
||||
|
||||
elif language == "markdown":
|
||||
# add markdown content to markdown cell in a notebook.
|
||||
self.add_markdown_cell(code)
|
||||
# return True, beacuse there is no execution failure for markdown cell.
|
||||
return code, True
|
||||
else:
|
||||
raise ValueError(f"Only support for language: python, markdown, but got {language}, ")
|
||||
|
||||
|
||||
def truncate(result: str, keep_len: int = 2000, is_success: bool = True):
|
||||
"""对于超出keep_len个字符的result: 执行失败的代码, 展示result后keep_len个字符; 执行成功的代码, 展示result前keep_len个字符。"""
|
||||
if is_success:
|
||||
desc = f"Executed code successfully. Truncated to show only first {keep_len} characters\n"
|
||||
else:
|
||||
desc = f"Executed code failed, please reflect the cause of bug and then debug. Truncated to show only last {keep_len} characters\n"
|
||||
|
||||
if result.strip().startswith("<coroutine object"):
|
||||
result = "Executed code failed, you need use key word 'await' to run a async code."
|
||||
return result, False
|
||||
|
||||
if len(result) > keep_len:
|
||||
result = result[-keep_len:] if not is_success else result[:keep_len]
|
||||
return desc + result, is_success
|
||||
|
||||
return result, is_success
|
||||
|
||||
|
||||
def remove_escape_and_color_codes(input_str: str):
|
||||
# 使用正则表达式去除转义字符和颜色代码
|
||||
pattern = re.compile(r"\x1b\[[0-9;]*[mK]")
|
||||
result = pattern.sub("", input_str)
|
||||
return result
|
||||
|
||||
|
||||
def display_markdown(content: str):
|
||||
# 使用正则表达式逐个匹配代码块
|
||||
matches = re.finditer(r"```(.+?)```", content, re.DOTALL)
|
||||
start_index = 0
|
||||
content_panels = []
|
||||
# 逐个打印匹配到的文本和代码
|
||||
for match in matches:
|
||||
text_content = content[start_index : match.start()].strip()
|
||||
code_content = match.group(0).strip()[3:-3] # Remove triple backticks
|
||||
|
||||
if text_content:
|
||||
content_panels.append(Panel(Markdown(text_content), box=MINIMAL))
|
||||
|
||||
if code_content:
|
||||
content_panels.append(Panel(Markdown(f"```{code_content}"), box=MINIMAL))
|
||||
start_index = match.end()
|
||||
|
||||
# 打印剩余文本(如果有)
|
||||
remaining_text = content[start_index:].strip()
|
||||
if remaining_text:
|
||||
content_panels.append(Panel(Markdown(remaining_text), box=MINIMAL))
|
||||
|
||||
# 在Live模式中显示所有Panel
|
||||
with Live(auto_refresh=False, console=Console(), vertical_overflow="visible") as live:
|
||||
live.update(Group(*content_panels))
|
||||
live.refresh()
|
||||
70
metagpt/actions/ci/ml_action.py
Normal file
70
metagpt/actions/ci/ml_action.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.ci.write_analysis_code import WriteCodeWithTools
|
||||
from metagpt.prompts.ci.ml_action import (
|
||||
ML_GENERATE_CODE_PROMPT,
|
||||
ML_TOOL_USAGE_PROMPT,
|
||||
PRINT_DATA_COLUMNS,
|
||||
UPDATE_DATA_COLUMNS,
|
||||
)
|
||||
from metagpt.prompts.ci.write_analysis_code import CODE_GENERATOR_WITH_TOOLS
|
||||
from metagpt.schema import Message, Plan
|
||||
from metagpt.utils.common import create_func_call_config, remove_comments
|
||||
|
||||
|
||||
class WriteCodeWithToolsML(WriteCodeWithTools):
|
||||
async def run(
|
||||
self,
|
||||
context: list[Message],
|
||||
plan: Plan = None,
|
||||
column_info: str = "",
|
||||
**kwargs,
|
||||
) -> Tuple[list[Message], str]:
|
||||
# prepare tool schemas and tool-type-specific instruction
|
||||
tool_schemas, tool_type_usage_prompt = await self._prepare_tools(plan=plan)
|
||||
|
||||
# ML-specific variables to be used in prompt
|
||||
finished_tasks = plan.get_finished_tasks()
|
||||
code_context = [remove_comments(task.code) for task in finished_tasks]
|
||||
code_context = "\n\n".join(code_context)
|
||||
|
||||
# prepare prompt depending on tool availability & LLM call
|
||||
if tool_schemas:
|
||||
prompt = ML_TOOL_USAGE_PROMPT.format(
|
||||
user_requirement=plan.goal,
|
||||
history_code=code_context,
|
||||
current_task=plan.current_task.instruction,
|
||||
column_info=column_info,
|
||||
tool_type_usage_prompt=tool_type_usage_prompt,
|
||||
tool_schemas=tool_schemas,
|
||||
)
|
||||
|
||||
else:
|
||||
prompt = ML_GENERATE_CODE_PROMPT.format(
|
||||
user_requirement=plan.goal,
|
||||
history_code=code_context,
|
||||
current_task=plan.current_task.instruction,
|
||||
column_info=column_info,
|
||||
tool_type_usage_prompt=tool_type_usage_prompt,
|
||||
)
|
||||
tool_config = create_func_call_config(CODE_GENERATOR_WITH_TOOLS)
|
||||
rsp = await self.llm.aask_code(prompt, **tool_config)
|
||||
|
||||
# Extra output to be used for potential debugging
|
||||
context = [Message(content=prompt, role="user")]
|
||||
|
||||
return context, rsp
|
||||
|
||||
|
||||
class UpdateDataColumns(Action):
|
||||
async def run(self, plan: Plan = None) -> dict:
|
||||
finished_tasks = plan.get_finished_tasks()
|
||||
code_context = [remove_comments(task.code) for task in finished_tasks]
|
||||
code_context = "\n\n".join(code_context)
|
||||
prompt = UPDATE_DATA_COLUMNS.format(history_code=code_context)
|
||||
tool_config = create_func_call_config(PRINT_DATA_COLUMNS)
|
||||
rsp = await self.llm.aask_code(prompt, **tool_config)
|
||||
return rsp
|
||||
155
metagpt/actions/ci/write_analysis_code.py
Normal file
155
metagpt/actions/ci/write_analysis_code.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
"""
|
||||
@Date : 2023/11/20 13:19:39
|
||||
@Author : orange-crow
|
||||
@File : write_analysis_code.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.prompts.ci.write_analysis_code import (
|
||||
CODE_GENERATOR_WITH_TOOLS,
|
||||
SELECT_FUNCTION_TOOLS,
|
||||
TOOL_RECOMMENDATION_PROMPT,
|
||||
TOOL_USAGE_PROMPT,
|
||||
)
|
||||
from metagpt.schema import Message, Plan, SystemMessage
|
||||
from metagpt.tools import TOOL_REGISTRY
|
||||
from metagpt.tools.tool_registry import validate_tool_names
|
||||
from metagpt.utils.common import create_func_call_config
|
||||
|
||||
|
||||
class BaseWriteAnalysisCode(Action):
|
||||
DEFAULT_SYSTEM_MSG: str = """You are Code Interpreter, a world-class programmer that can complete any goal by executing code. Strictly follow the plan and generate code step by step. Each step of the code will be executed on the user's machine, and the user will provide the code execution results to you.**Notice: The code for the next step depends on the code for the previous step. Must reuse variables in the lastest other code directly, dont creat it again, it is very import for you. Use !pip install in a standalone block to install missing packages.Usually the libraries you need are already installed.Dont check if packages already imported.**""" # prompt reference: https://github.com/KillianLucas/open-interpreter/blob/v0.1.4/interpreter/system_message.txt
|
||||
# REUSE_CODE_INSTRUCTION = """ATTENTION: DONT include codes from previous tasks in your current code block, include new codes only, DONT repeat codes!"""
|
||||
|
||||
def insert_system_message(self, context: list[Message], system_msg: str = None):
|
||||
system_msg = system_msg or self.DEFAULT_SYSTEM_MSG
|
||||
context.insert(0, SystemMessage(content=system_msg)) if context[0].role != "system" else None
|
||||
return context
|
||||
|
||||
async def run(self, context: list[Message], plan: Plan = None) -> dict:
|
||||
"""Run of a code writing action, used in data analysis or modeling
|
||||
|
||||
Args:
|
||||
context (list[Message]): Action output history, source action denoted by Message.cause_by
|
||||
plan (Plan, optional): Overall plan. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: code result in the format of {"code": "print('hello world')", "language": "python"}
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WriteCodeWithoutTools(BaseWriteAnalysisCode):
|
||||
"""Ask LLM to generate codes purely by itself without local user-defined tools"""
|
||||
|
||||
async def run(self, context: list[Message], plan: Plan = None, system_msg: str = None, **kwargs) -> dict:
|
||||
messages = self.insert_system_message(context, system_msg)
|
||||
rsp = await self.llm.aask_code(messages, **kwargs)
|
||||
return rsp
|
||||
|
||||
|
||||
class WriteCodeWithTools(BaseWriteAnalysisCode):
|
||||
"""Write code with help of local available tools. Choose tools first, then generate code to use the tools"""
|
||||
|
||||
# selected tools to choose from, listed by their names. An empty list means selection from all tools.
|
||||
selected_tools: list[str] = []
|
||||
|
||||
def _get_tools_by_type(self, tool_type: str) -> dict:
|
||||
"""
|
||||
Retreive tools by tool type from registry, but filtered by pre-selected tool list
|
||||
|
||||
Args:
|
||||
tool_type (str): Tool type to retrieve from the registry
|
||||
|
||||
Returns:
|
||||
dict: A dict of tool name to Tool object, representing available tools under the type
|
||||
"""
|
||||
candidate_tools = TOOL_REGISTRY.get_tools_by_type(tool_type)
|
||||
if self.selected_tools:
|
||||
candidate_tool_names = set(self.selected_tools) & candidate_tools.keys()
|
||||
candidate_tools = {tool_name: candidate_tools[tool_name] for tool_name in candidate_tool_names}
|
||||
return candidate_tools
|
||||
|
||||
async def _recommend_tool(
|
||||
self,
|
||||
task: str,
|
||||
available_tools: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Recommend tools for the specified task.
|
||||
|
||||
Args:
|
||||
task (str): the task to recommend tools for
|
||||
available_tools (dict): the available tools description
|
||||
|
||||
Returns:
|
||||
dict: schemas of recommended tools for the specified task
|
||||
"""
|
||||
prompt = TOOL_RECOMMENDATION_PROMPT.format(
|
||||
current_task=task,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
tool_config = create_func_call_config(SELECT_FUNCTION_TOOLS)
|
||||
rsp = await self.llm.aask_code(prompt, **tool_config)
|
||||
recommend_tools = rsp["recommend_tools"]
|
||||
logger.info(f"Recommended tools: \n{recommend_tools}")
|
||||
|
||||
# Parses and validates the recommended tools, for LLM might hallucinate and recommend non-existing tools
|
||||
valid_tools = validate_tool_names(recommend_tools, return_tool_object=True)
|
||||
|
||||
tool_schemas = {tool.name: tool.schemas for tool in valid_tools}
|
||||
|
||||
return tool_schemas
|
||||
|
||||
async def _prepare_tools(self, plan: Plan) -> Tuple[dict, str]:
|
||||
"""Prepare tool schemas and usage instructions according to current task
|
||||
|
||||
Args:
|
||||
plan (Plan): The overall plan containing task information.
|
||||
|
||||
Returns:
|
||||
Tuple[dict, str]: A tool schemas ({tool_name: tool_schema_dict}) and a usage prompt for the type of tools selected
|
||||
"""
|
||||
# find tool type from task type through exact match, can extend to retrieval in the future
|
||||
tool_type = plan.current_task.task_type
|
||||
|
||||
# prepare tool-type-specific instruction
|
||||
tool_type_usage_prompt = (
|
||||
TOOL_REGISTRY.get_tool_type(tool_type).usage_prompt if TOOL_REGISTRY.has_tool_type(tool_type) else ""
|
||||
)
|
||||
|
||||
# prepare schemas of available tools
|
||||
tool_schemas = {}
|
||||
available_tools = self._get_tools_by_type(tool_type)
|
||||
if available_tools:
|
||||
available_tools = {tool_name: tool.schemas["description"] for tool_name, tool in available_tools.items()}
|
||||
tool_schemas = await self._recommend_tool(plan.current_task.instruction, available_tools)
|
||||
|
||||
return tool_schemas, tool_type_usage_prompt
|
||||
|
||||
async def run(
|
||||
self,
|
||||
context: list[Message],
|
||||
plan: Plan,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
# prepare tool schemas and tool-type-specific instruction
|
||||
tool_schemas, tool_type_usage_prompt = await self._prepare_tools(plan=plan)
|
||||
|
||||
# form a complete tool usage instruction and include it as a message in context
|
||||
tools_instruction = TOOL_USAGE_PROMPT.format(
|
||||
tool_schemas=tool_schemas, tool_type_usage_prompt=tool_type_usage_prompt
|
||||
)
|
||||
context.append(Message(content=tools_instruction, role="user"))
|
||||
|
||||
# prepare prompt & LLM call
|
||||
prompt = self.insert_system_message(context)
|
||||
tool_config = create_func_call_config(CODE_GENERATOR_WITH_TOOLS)
|
||||
rsp = await self.llm.aask_code(prompt, **tool_config)
|
||||
|
||||
return rsp
|
||||
116
metagpt/actions/ci/write_plan.py
Normal file
116
metagpt/actions/ci/write_plan.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
"""
|
||||
@Date : 2023/11/20 11:24:03
|
||||
@Author : orange-crow
|
||||
@File : plan.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from typing import Tuple
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.prompts.ci.write_analysis_code import (
|
||||
ASSIGN_TASK_TYPE_CONFIG,
|
||||
ASSIGN_TASK_TYPE_PROMPT,
|
||||
)
|
||||
from metagpt.schema import Message, Plan, Task
|
||||
from metagpt.tools import TOOL_REGISTRY
|
||||
from metagpt.utils.common import CodeParser, create_func_call_config
|
||||
|
||||
|
||||
class WritePlan(Action):
|
||||
PROMPT_TEMPLATE: str = """
|
||||
# Context:
|
||||
__context__
|
||||
# Task:
|
||||
Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to __max_tasks__ tasks.
|
||||
If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes. Give the whole plan unless instructed to modify only one task of the plan.
|
||||
If you encounter errors on the current task, revise and output the current single task only.
|
||||
Output a list of jsons following the format:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"task_id": str = "unique identifier for a task in plan, can be an ordinal",
|
||||
"dependent_task_ids": list[str] = "ids of tasks prerequisite to this task",
|
||||
"instruction": "what you should do in this task, one short phrase or sentence",
|
||||
},
|
||||
...
|
||||
]
|
||||
```
|
||||
"""
|
||||
|
||||
async def assign_task_type(self, tasks: list[dict]) -> str:
|
||||
"""Assign task type to each task in tasks
|
||||
|
||||
Args:
|
||||
tasks (list[dict]): tasks to be assigned task type
|
||||
|
||||
Returns:
|
||||
str: tasks with task type assigned in a json string
|
||||
"""
|
||||
task_info = "\n".join([f"Task {task['task_id']}: {task['instruction']}" for task in tasks])
|
||||
task_type_desc = "\n".join(
|
||||
[f"- **{tool_type.name}**: {tool_type.desc}" for tool_type in TOOL_REGISTRY.get_tool_types().values()]
|
||||
) # task type are binded with tool type now, should be improved in the future
|
||||
prompt = ASSIGN_TASK_TYPE_PROMPT.format(
|
||||
task_info=task_info, task_type_desc=task_type_desc
|
||||
) # task types are set to be the same as tool types, for now
|
||||
tool_config = create_func_call_config(ASSIGN_TASK_TYPE_CONFIG)
|
||||
rsp = await self.llm.aask_code(prompt, **tool_config)
|
||||
task_type_list = rsp["task_type"]
|
||||
logger.info(f"assigned task types: {task_type_list}")
|
||||
for task, task_type in zip(tasks, task_type_list):
|
||||
task["task_type"] = task_type
|
||||
return json.dumps(tasks)
|
||||
|
||||
async def run(self, context: list[Message], max_tasks: int = 5, use_tools: bool = False) -> str:
|
||||
prompt = (
|
||||
self.PROMPT_TEMPLATE.replace("__context__", "\n".join([str(ct) for ct in context]))
|
||||
# .replace("__current_plan__", current_plan)
|
||||
.replace("__max_tasks__", str(max_tasks))
|
||||
)
|
||||
rsp = await self._aask(prompt)
|
||||
rsp = CodeParser.parse_code(block=None, text=rsp)
|
||||
if use_tools:
|
||||
rsp = await self.assign_task_type(json.loads(rsp))
|
||||
return rsp
|
||||
|
||||
|
||||
def rsp_to_tasks(rsp: str) -> list[Task]:
|
||||
rsp = json.loads(rsp)
|
||||
tasks = [Task(**task_config) for task_config in rsp]
|
||||
return tasks
|
||||
|
||||
|
||||
def update_plan_from_rsp(rsp: str, current_plan: Plan):
|
||||
tasks = rsp_to_tasks(rsp)
|
||||
if len(tasks) == 1 or tasks[0].dependent_task_ids:
|
||||
if tasks[0].dependent_task_ids and len(tasks) > 1:
|
||||
# tasks[0].dependent_task_ids means the generated tasks are not a complete plan
|
||||
# for they depend on tasks in the current plan, in this case, we only support updating one task each time
|
||||
logger.warning(
|
||||
"Current plan will take only the first generated task if the generated tasks are not a complete plan"
|
||||
)
|
||||
# handle a single task
|
||||
if current_plan.has_task_id(tasks[0].task_id):
|
||||
# replace an existing task
|
||||
current_plan.replace_task(tasks[0])
|
||||
else:
|
||||
# append one task
|
||||
current_plan.append_task(tasks[0])
|
||||
|
||||
else:
|
||||
# add tasks in general
|
||||
current_plan.add_tasks(tasks)
|
||||
|
||||
|
||||
def precheck_update_plan_from_rsp(rsp: str, current_plan: Plan) -> Tuple[bool, str]:
|
||||
temp_plan = deepcopy(current_plan)
|
||||
try:
|
||||
update_plan_from_rsp(rsp, temp_plan)
|
||||
return True, ""
|
||||
except Exception as e:
|
||||
return False, e
|
||||
|
|
@ -75,6 +75,8 @@ class Config(CLIParams, YamlModel):
|
|||
iflytek_api_key: str = ""
|
||||
azure_tts_subscription_key: str = ""
|
||||
azure_tts_region: str = ""
|
||||
openai_vision_model: str = "gpt-4-vision-preview"
|
||||
vision_max_tokens: int = 4096
|
||||
|
||||
@classmethod
|
||||
def from_home(cls, path):
|
||||
|
|
|
|||
|
|
@ -67,6 +67,8 @@ TMP = METAGPT_ROOT / "tmp"
|
|||
SOURCE_ROOT = METAGPT_ROOT / "metagpt"
|
||||
PROMPT_PATH = SOURCE_ROOT / "prompts"
|
||||
SKILL_DIRECTORY = SOURCE_ROOT / "skills"
|
||||
TOOL_SCHEMA_PATH = METAGPT_ROOT / "metagpt/tools/schemas"
|
||||
TOOL_LIBS_PATH = METAGPT_ROOT / "metagpt/tools/libs"
|
||||
|
||||
# REAL CONSTS
|
||||
|
||||
|
|
|
|||
128
metagpt/prompts/ci/ml_action.py
Normal file
128
metagpt/prompts/ci/ml_action.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2023/11/24 15:43
|
||||
# @Author : lidanyang
|
||||
# @File : ml_action
|
||||
# @Desc :
|
||||
UPDATE_DATA_COLUMNS = """
|
||||
# Background
|
||||
Keep dataset column information updated before model train.
|
||||
## Done Tasks
|
||||
```python
|
||||
{history_code}
|
||||
```end
|
||||
|
||||
# Task
|
||||
Update and print the dataset's column information only if the train or test data has changed. Use the following code:
|
||||
```python
|
||||
from metagpt.tools.libs.data_preprocess import get_column_info
|
||||
|
||||
column_info = get_column_info(df)
|
||||
print("column_info")
|
||||
print(column_info)
|
||||
```end
|
||||
|
||||
# Constraints:
|
||||
- Use the DataFrame variable from 'Done Tasks' in place of df.
|
||||
- Import `get_column_info` only if it's not already imported.
|
||||
"""
|
||||
|
||||
PRINT_DATA_COLUMNS = {
|
||||
"name": "print_column_info",
|
||||
"description": "Print the latest column information after 'Done Tasks' code if first read or data changed.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The code to be added to a new cell in jupyter.",
|
||||
},
|
||||
},
|
||||
"required": ["code"],
|
||||
},
|
||||
}
|
||||
|
||||
ML_COMMON_PROMPT = """
|
||||
# Background
|
||||
As a data scientist, you need to help user to achieve their goal [{user_requirement}] step-by-step in an continuous Jupyter notebook.
|
||||
|
||||
## Done Tasks
|
||||
```python
|
||||
{history_code}
|
||||
```end
|
||||
|
||||
## Current Task
|
||||
{current_task}
|
||||
|
||||
# Latest Data Info
|
||||
Latest data info after previous tasks:
|
||||
{column_info}
|
||||
|
||||
# Task
|
||||
Write complete code for 'Current Task'. And avoid duplicating code from 'Done Tasks', such as repeated import of packages, reading data, etc.
|
||||
Specifically, {tool_type_usage_prompt}
|
||||
"""
|
||||
|
||||
USE_NO_TOOLS_EXAMPLE = """
|
||||
# Output Example:
|
||||
when current task is "train a lightgbm model on training data", the code can be like:
|
||||
```python
|
||||
# Step 1: check data type and convert to numeric
|
||||
obj_cols = train.select_dtypes(include='object').columns.tolist()
|
||||
|
||||
for col in obj_cols:
|
||||
encoder = LabelEncoder()
|
||||
train[col] = encoder.fit_transform(train[col].unique().tolist() + ['unknown'])
|
||||
test[col] = test[col].apply(lambda x: x if x in encoder.classes_ else 'unknown')
|
||||
test[col] = encoder.transform(test[col])
|
||||
|
||||
# Step 2: train lightgbm model
|
||||
model = LGBMClassifier()
|
||||
model.fit(train, y_train)
|
||||
```end
|
||||
|
||||
# Constraints:
|
||||
- Ensure the output new code is executable in the same Jupyter notebook with previous tasks code have been executed.
|
||||
"""
|
||||
|
||||
USE_TOOLS_EXAMPLE = """
|
||||
# Capabilities
|
||||
- You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python Class.
|
||||
- You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
|
||||
|
||||
# Available Tools:
|
||||
Each Class tool is described in JSON format. When you call a tool, import the tool from its path first.
|
||||
{tool_schemas}
|
||||
|
||||
# Output Example:
|
||||
when current task is "do data preprocess, like fill missing value, handle outliers, etc.", the code can be like:
|
||||
```python
|
||||
# Step 1: fill missing value
|
||||
# Tools used: ['FillMissingValue']
|
||||
from metagpt.tools.libs.data_preprocess import FillMissingValue
|
||||
|
||||
train_processed = train.copy()
|
||||
test_processed = test.copy()
|
||||
num_cols = train_processed.select_dtypes(include='number').columns.tolist()
|
||||
if 'label' in num_cols:
|
||||
num_cols.remove('label')
|
||||
fill_missing_value = FillMissingValue(features=num_cols, strategy='mean')
|
||||
fill_missing_value.fit(train_processed)
|
||||
train_processed = fill_missing_value.transform(train_processed)
|
||||
test_processed = fill_missing_value.transform(test_processed)
|
||||
|
||||
# Step 2: handle outliers
|
||||
for col in num_cols:
|
||||
low, high = train_processed[col].quantile([0.01, 0.99])
|
||||
train_processed[col] = train_processed[col].clip(low, high)
|
||||
test_processed[col] = test_processed[col].clip(low, high)
|
||||
```end
|
||||
|
||||
# Constraints:
|
||||
- Ensure the output new code is executable in the same Jupyter notebook with previous tasks code have been executed.
|
||||
- Always prioritize using pre-defined tools for the same functionality.
|
||||
- Always copy the DataFrame before processing it and use the copy to process.
|
||||
"""
|
||||
|
||||
ML_GENERATE_CODE_PROMPT = ML_COMMON_PROMPT + USE_NO_TOOLS_EXAMPLE
|
||||
ML_TOOL_USAGE_PROMPT = ML_COMMON_PROMPT + USE_TOOLS_EXAMPLE
|
||||
93
metagpt/prompts/ci/write_analysis_code.py
Normal file
93
metagpt/prompts/ci/write_analysis_code.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
ASSIGN_TASK_TYPE_PROMPT = """
|
||||
Please assign a task type to each task in the list below from the given categories:
|
||||
{task_info}
|
||||
|
||||
## All Task Type:
|
||||
{task_type_desc}
|
||||
"""
|
||||
|
||||
ASSIGN_TASK_TYPE_CONFIG = {
|
||||
"name": "assign_task_type",
|
||||
"description": "Assign task type to each task by order.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_type": {
|
||||
"type": "array",
|
||||
"description": "List of task type. The length should as long as task list",
|
||||
"items": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["task_type"],
|
||||
},
|
||||
}
|
||||
|
||||
TOOL_RECOMMENDATION_PROMPT = """
|
||||
## User Requirement:
|
||||
{current_task}
|
||||
|
||||
## Task
|
||||
Recommend up to five tools from 'Available Tools' that can help solve the 'User Requirement'.
|
||||
|
||||
## Available Tools:
|
||||
{available_tools}
|
||||
|
||||
## Tool Selection and Instructions:
|
||||
- Select tools most relevant to completing the 'User Requirement'.
|
||||
- If you believe that no tools are suitable, indicate with an empty list.
|
||||
- Only list the names of the tools, not the full schema of each tool.
|
||||
- Ensure selected tools are listed in 'Available Tools'.
|
||||
"""
|
||||
|
||||
SELECT_FUNCTION_TOOLS = {
|
||||
"name": "select_function_tools",
|
||||
"description": "For current task, select suitable tools for it.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"recommend_tools": {
|
||||
"type": "array",
|
||||
"description": "List of tool names. Empty list if no tool is suitable.",
|
||||
"items": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["recommend_tools"],
|
||||
},
|
||||
}
|
||||
|
||||
CODE_GENERATOR_WITH_TOOLS = {
|
||||
"name": "add_subtask_code",
|
||||
"description": "Add new code cell of current task to the end of an active Jupyter notebook.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The code to be added to a new cell in jupyter.",
|
||||
},
|
||||
},
|
||||
"required": ["code"],
|
||||
},
|
||||
}
|
||||
|
||||
TOOL_USAGE_PROMPT = """
|
||||
# Instruction
|
||||
Write complete code for 'Current Task'. And avoid duplicating code from finished tasks, such as repeated import of packages, reading data, etc.
|
||||
Specifically, {tool_type_usage_prompt}
|
||||
|
||||
# Capabilities
|
||||
- You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python Class.
|
||||
- You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
|
||||
|
||||
# Available Tools (can be empty):
|
||||
Each Class tool is described in JSON format. When you call a tool, import the tool first.
|
||||
{tool_schemas}
|
||||
|
||||
# Constraints:
|
||||
- Ensure the output new code is executable in the same Jupyter notebook with previous tasks code have been executed.
|
||||
- Always prioritize using pre-defined tools for the same functionality.
|
||||
"""
|
||||
46
metagpt/prompts/tool_types.py
Normal file
46
metagpt/prompts/tool_types.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
# Prompt for using tools of "data_preprocess" type
|
||||
DATA_PREPROCESS_PROMPT = """
|
||||
The current task is about data preprocessing, please note the following:
|
||||
- Monitor data types per column, applying appropriate methods.
|
||||
- Ensure operations are on existing dataset columns.
|
||||
- Avoid writing processed data to files.
|
||||
- Avoid any change to label column, such as standardization, etc.
|
||||
- Prefer alternatives to one-hot encoding for categorical data.
|
||||
- Only encode or scale necessary columns to allow for potential feature-specific engineering tasks (like time_extract, binning, extraction, etc.) later.
|
||||
- Each step do data preprocessing to train, must do same for test separately at the same time.
|
||||
"""
|
||||
|
||||
# Prompt for using tools of "feature_engineering" type
|
||||
FEATURE_ENGINEERING_PROMPT = """
|
||||
The current task is about feature engineering. when performing it, please adhere to the following principles:
|
||||
- Generate as diverse features as possible to improve the model's performance step-by-step.
|
||||
- Use available feature engineering tools if they are potential impactful.
|
||||
- Avoid creating redundant or excessively numerous features in one step.
|
||||
- Exclude ID columns from feature generation and remove them.
|
||||
- Each feature engineering operation performed on the train set must also applies to the test separately at the same time.
|
||||
- Avoid using the label column to create features, except for cat encoding.
|
||||
- Use the data from previous task result if exist, do not mock or reload data yourself.
|
||||
"""
|
||||
|
||||
# Prompt for using tools of "model_train" type
|
||||
MODEL_TRAIN_PROMPT = """
|
||||
The current task is about training a model, please ensure high performance:
|
||||
- Keep in mind that your user prioritizes results and is highly focused on model performance. So, when needed, feel free to use models of any complexity to improve effectiveness, such as XGBoost, CatBoost, etc.
|
||||
- If non-numeric columns exist, perform label encode together with all steps.
|
||||
- Use the data from previous task result directly, do not mock or reload data yourself.
|
||||
- Set suitable hyperparameters for the model, make metrics as high as possible.
|
||||
"""
|
||||
|
||||
# Prompt for using tools of "model_evaluate" type
|
||||
MODEL_EVALUATE_PROMPT = """
|
||||
The current task is about evaluating a model, please note the following:
|
||||
- Ensure that the evaluated data is same processed as the training data. If not, remember use object in 'Done Tasks' to transform the data.
|
||||
- Use trained model from previous task result directly, do not mock or reload model yourself.
|
||||
"""
|
||||
|
||||
# Prompt for using tools of "vision" type
|
||||
IMAGE2WEBPAGE_PROMPT = """
|
||||
The current task is about converting image into webpage code. please note the following:
|
||||
- Single-Step Code Generation: Execute the entire code generation process in a single step, encompassing HTML, CSS, and JavaScript. Avoid fragmenting the code generation into multiple separate steps to maintain consistency and simplify the development workflow.
|
||||
- Save webpages: Be sure to use the save method provided.
|
||||
"""
|
||||
|
|
@ -167,4 +167,12 @@ class BaseLLM(ABC):
|
|||
:return dict: return the first function arguments of choice, for example,
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
return json.loads(self.get_choice_function(rsp)["arguments"])
|
||||
return json.loads(self.get_choice_function(rsp)["arguments"], strict=False)
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
|
||||
|
||||
def messages_to_dict(self, messages):
|
||||
"""objects to [{"role": "user", "content": msg}] etc."""
|
||||
return [i.to_dict() for i in messages]
|
||||
|
|
|
|||
|
|
@ -25,10 +25,10 @@ from tenacity import (
|
|||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import decode_image
|
||||
from metagpt.utils.common import CodeParser, decode_image
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.token_counter import (
|
||||
|
|
@ -147,37 +147,41 @@ class OpenAILLM(BaseLLM):
|
|||
def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict:
|
||||
"""Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create"""
|
||||
if "tools" not in kwargs:
|
||||
configs = {
|
||||
"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}],
|
||||
"tool_choice": GENERAL_TOOL_CHOICE,
|
||||
}
|
||||
configs = {"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}]}
|
||||
kwargs.update(configs)
|
||||
|
||||
return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
|
||||
|
||||
def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
# 全部转成list
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
||||
# 转成list[dict]
|
||||
processed_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, str):
|
||||
processed_messages.append({"role": "user", "content": msg})
|
||||
elif isinstance(msg, dict):
|
||||
assert set(msg.keys()) == set(["role", "content"])
|
||||
processed_messages.append(msg)
|
||||
elif isinstance(msg, Message):
|
||||
processed_messages.append(msg.to_dict())
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!"
|
||||
)
|
||||
return processed_messages
|
||||
|
||||
async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion:
|
||||
messages = self._process_message(messages)
|
||||
kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs)
|
||||
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
if isinstance(messages, list):
|
||||
messages = [Message(content=msg) if isinstance(msg, str) else msg for msg in messages]
|
||||
return [msg if isinstance(msg, dict) else msg.to_dict() for msg in messages]
|
||||
|
||||
if isinstance(messages, Message):
|
||||
messages = [messages.to_dict()]
|
||||
elif isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Only support messages type are: str, Message, list[dict], but got {type(messages).__name__}!"
|
||||
)
|
||||
return messages
|
||||
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
async def aask_code(self, messages: list[dict], **kwargs) -> dict:
|
||||
"""Use function of tools to ask a code.
|
||||
Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
||||
|
|
@ -187,18 +191,37 @@ class OpenAILLM(BaseLLM):
|
|||
>>> rsp = await llm.aask_code(msg)
|
||||
# -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
messages = self._process_message(messages)
|
||||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
@handle_exception
|
||||
# @handle_exception
|
||||
def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
|
||||
"""Required to provide the first function arguments of choice.
|
||||
|
||||
:param dict rsp: same as in self.get_choice_function(rsp)
|
||||
:return dict: return the first function arguments of choice, for example,
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
|
||||
message = rsp.choices[0].message
|
||||
if (
|
||||
message.tool_calls is not None
|
||||
and message.tool_calls[0].function is not None
|
||||
and message.tool_calls[0].function.arguments is not None
|
||||
):
|
||||
# reponse is code
|
||||
return json.loads(message.tool_calls[0].function.arguments, strict=False)
|
||||
elif message.tool_calls is None and message.content is not None:
|
||||
# reponse is code, fix openai tools_call respond bug,
|
||||
# The response content is `code``, but it appears in the content instead of the arguments.
|
||||
code_formats = "```"
|
||||
if message.content.startswith(code_formats) and message.content.endswith(code_formats):
|
||||
code = CodeParser.parse_code(None, message.content)
|
||||
return {"language": "python", "code": code}
|
||||
# reponse is message
|
||||
return {"language": "markdown", "code": self.get_choice_text(rsp)}
|
||||
else:
|
||||
logger.error(f"Failed to parse \n {rsp}\n")
|
||||
raise Exception(f"Failed to parse \n {rsp}\n")
|
||||
|
||||
def get_choice_text(self, rsp: ChatCompletion) -> str:
|
||||
"""Required to provide the first text of choice"""
|
||||
|
|
|
|||
89
metagpt/roles/ci/code_interpreter.py
Normal file
89
metagpt/roles/ci/code_interpreter.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.ci.ask_review import ReviewConst
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.ci.write_analysis_code import (
|
||||
WriteCodeWithoutTools,
|
||||
WriteCodeWithTools,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message, Task, TaskResult
|
||||
|
||||
|
||||
class CodeInterpreter(Role):
|
||||
name: str = "Charlie"
|
||||
profile: str = "CodeInterpreter"
|
||||
auto_run: bool = True
|
||||
use_tools: bool = False
|
||||
execute_code: ExecuteNbCode = Field(default_factory=ExecuteNbCode, exclude=True)
|
||||
tools: list[str] = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
auto_run=True,
|
||||
use_tools=False,
|
||||
tools=[],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(auto_run=auto_run, use_tools=use_tools, tools=tools, **kwargs)
|
||||
self._set_react_mode(react_mode="plan_and_act", auto_run=auto_run, use_tools=use_tools)
|
||||
if use_tools and tools:
|
||||
from metagpt.tools.tool_registry import (
|
||||
validate_tool_names, # import upon use
|
||||
)
|
||||
|
||||
self.tools = validate_tool_names(tools)
|
||||
logger.info(f"will only use {self.tools} as tools")
|
||||
|
||||
@property
|
||||
def working_memory(self):
|
||||
return self.rc.working_memory
|
||||
|
||||
async def _act_on_task(self, current_task: Task) -> TaskResult:
|
||||
code, result, is_success = await self._write_and_exec_code()
|
||||
task_result = TaskResult(code=code, result=result, is_success=is_success)
|
||||
return task_result
|
||||
|
||||
async def _write_and_exec_code(self, max_retry: int = 3):
|
||||
counter = 0
|
||||
success = False
|
||||
|
||||
while not success and counter < max_retry:
|
||||
### write code ###
|
||||
code, cause_by = await self._write_code()
|
||||
|
||||
self.working_memory.add(Message(content=code["code"], role="assistant", cause_by=cause_by))
|
||||
|
||||
### execute code ###
|
||||
result, success = await self.execute_code.run(**code)
|
||||
print(result)
|
||||
|
||||
self.working_memory.add(Message(content=result, role="user", cause_by=ExecuteNbCode))
|
||||
|
||||
### process execution result ###
|
||||
counter += 1
|
||||
|
||||
if not success and counter >= max_retry:
|
||||
logger.info("coding failed!")
|
||||
review, _ = await self.planner.ask_review(auto_run=False, trigger=ReviewConst.CODE_REVIEW_TRIGGER)
|
||||
if ReviewConst.CHANGE_WORDS[0] in review:
|
||||
counter = 0 # redo the task again with help of human suggestions
|
||||
|
||||
py_code = (
|
||||
code["code"] if code.get("language") == "python" else ""
|
||||
) # use python code as final code; for markdown, return the rendered result instead of the code itself
|
||||
|
||||
return py_code, result, success
|
||||
|
||||
async def _write_code(self):
|
||||
todo = WriteCodeWithoutTools() if not self.use_tools else WriteCodeWithTools(selected_tools=self.tools)
|
||||
logger.info(f"ready to {todo.name}")
|
||||
|
||||
context = self.planner.get_useful_memories()
|
||||
# print(*context, sep="\n***\n")
|
||||
code = await todo.run(context=context, plan=self.planner.plan, temperature=0.0)
|
||||
|
||||
return code, todo
|
||||
64
metagpt/roles/ci/ml_engineer.py
Normal file
64
metagpt/roles/ci/ml_engineer.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from metagpt.actions.ci.debug_code import DebugCode
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.ci.ml_action import UpdateDataColumns, WriteCodeWithToolsML
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.ci.code_interpreter import CodeInterpreter
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
class MLEngineer(CodeInterpreter):
|
||||
name: str = "Mark"
|
||||
profile: str = "MLEngineer"
|
||||
debug_context: list = []
|
||||
latest_code: str = ""
|
||||
|
||||
async def _write_code(self):
|
||||
if not self.use_tools:
|
||||
return await super()._write_code()
|
||||
|
||||
# In a trial and errors settings, check whether this is our first attempt to tackle the task. If there is no code execution before, then it is.
|
||||
is_first_trial = any_to_str(ExecuteNbCode) not in [msg.cause_by for msg in self.working_memory.get()]
|
||||
|
||||
if is_first_trial:
|
||||
# For the first trial, write task code from scratch
|
||||
column_info = await self._update_data_columns()
|
||||
|
||||
logger.info("Write code with tools")
|
||||
tool_context, code = await WriteCodeWithToolsML(selected_tools=self.tools).run(
|
||||
context=[], # context assembled inside the Action
|
||||
plan=self.planner.plan,
|
||||
column_info=column_info,
|
||||
)
|
||||
self.debug_context = tool_context
|
||||
cause_by = WriteCodeWithToolsML
|
||||
|
||||
else:
|
||||
# Previous trials resulted in error, debug and rewrite the code
|
||||
logger.warning("We got a bug, now start to debug...")
|
||||
code = await DebugCode().run(
|
||||
code=self.latest_code,
|
||||
runtime_result=self.working_memory.get(),
|
||||
context=self.debug_context,
|
||||
)
|
||||
logger.info(f"new code \n{code}")
|
||||
cause_by = DebugCode
|
||||
|
||||
self.latest_code = code["code"]
|
||||
|
||||
return code, cause_by
|
||||
|
||||
async def _update_data_columns(self):
|
||||
current_task = self.planner.plan.current_task
|
||||
if current_task.task_type not in [
|
||||
ToolType.DATA_PREPROCESS.type_name,
|
||||
ToolType.FEATURE_ENGINEERING.type_name,
|
||||
ToolType.MODEL_TRAIN.type_name,
|
||||
]:
|
||||
return ""
|
||||
logger.info("Check columns in updated data")
|
||||
code = await UpdateDataColumns().run(self.planner.plan)
|
||||
success = False
|
||||
result, success = await self.execute_code.run(**code)
|
||||
print(result)
|
||||
return result if success else ""
|
||||
|
|
@ -35,6 +35,7 @@ from metagpt.logs import logger
|
|||
from metagpt.memory import Memory
|
||||
from metagpt.provider import HumanProvider
|
||||
from metagpt.schema import Message, MessageQueue, SerializationMixin
|
||||
from metagpt.strategy.planner import Planner
|
||||
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
|
||||
|
|
@ -97,6 +98,7 @@ class RoleContext(BaseModel):
|
|||
) # Message Buffer with Asynchronous Updates
|
||||
memory: Memory = Field(default_factory=Memory)
|
||||
# long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory)
|
||||
working_memory: Memory = Field(default_factory=Memory)
|
||||
state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None
|
||||
todo: Action = Field(default=None, exclude=True)
|
||||
watch: set[str] = Field(default_factory=set)
|
||||
|
|
@ -152,6 +154,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
|
||||
rc: RoleContext = Field(default_factory=RoleContext)
|
||||
addresses: set[str] = set()
|
||||
planner: Planner = Field(default_factory=Planner)
|
||||
|
||||
# builtin variables
|
||||
recovered: bool = False # to tag if a recovered role
|
||||
|
|
@ -280,7 +283,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
self.actions.append(i)
|
||||
self.states.append(f"{len(self.actions)}. {action}")
|
||||
|
||||
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1):
|
||||
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True, use_tools: bool = False):
|
||||
"""Set strategy of the Role reacting to observed Message. Variation lies in how
|
||||
this Role elects action to perform during the _think stage, especially if it is capable of multiple Actions.
|
||||
|
||||
|
|
@ -300,6 +303,10 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
self.rc.react_mode = react_mode
|
||||
if react_mode == RoleReactMode.REACT:
|
||||
self.rc.max_react_loop = max_react_loop
|
||||
elif react_mode == RoleReactMode.PLAN_AND_ACT:
|
||||
self.planner = Planner(
|
||||
goal=self.goal, working_memory=self.rc.working_memory, auto_run=auto_run, use_tools=use_tools
|
||||
)
|
||||
|
||||
def _watch(self, actions: Iterable[Type[Action]] | Iterable[Action]):
|
||||
"""Watch Actions of interest. Role will select Messages caused by these Actions from its personal message
|
||||
|
|
@ -476,8 +483,41 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
|
||||
async def _plan_and_act(self) -> Message:
|
||||
"""first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... Use llm to come up with the plan dynamically."""
|
||||
# TODO: to be implemented
|
||||
return Message(content="")
|
||||
|
||||
# create initial plan and update it until confirmation
|
||||
goal = self.rc.memory.get()[-1].content # retreive latest user requirement
|
||||
await self.planner.update_plan(goal=goal)
|
||||
|
||||
# take on tasks until all finished
|
||||
while self.planner.current_task:
|
||||
task = self.planner.current_task
|
||||
logger.info(f"ready to take on task {task}")
|
||||
|
||||
# take on current task
|
||||
task_result = await self._act_on_task(task)
|
||||
|
||||
# process the result, such as reviewing, confirming, plan updating
|
||||
await self.planner.process_task_result(task_result)
|
||||
|
||||
rsp = self.planner.get_useful_memories()[0] # return the completed plan as a response
|
||||
|
||||
self.rc.memory.add(rsp) # add to persistent memory
|
||||
|
||||
return rsp
|
||||
|
||||
async def _act_on_task(self, current_task: Task) -> TaskResult:
|
||||
"""Taking specific action to handle one task in plan
|
||||
|
||||
Args:
|
||||
current_task (Task): current task to take on
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Specific Role must implement this method if expected to use planner
|
||||
|
||||
Returns:
|
||||
TaskResult: Result from the actions
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def react(self) -> Message:
|
||||
"""Entry to one of three strategies by which Role reacts to the observed Message"""
|
||||
|
|
|
|||
|
|
@ -330,6 +330,200 @@ class AIMessage(Message):
|
|||
super().__init__(content=content, role="assistant")
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
task_id: str = ""
|
||||
dependent_task_ids: list[str] = [] # Tasks prerequisite to this Task
|
||||
instruction: str = ""
|
||||
task_type: str = ""
|
||||
code: str = ""
|
||||
result: str = ""
|
||||
is_success: bool = False
|
||||
is_finished: bool = False
|
||||
|
||||
def reset(self):
|
||||
self.code = ""
|
||||
self.result = ""
|
||||
self.is_success = False
|
||||
self.is_finished = False
|
||||
|
||||
def update_task_result(self, task_result: TaskResult):
|
||||
self.code = task_result.code
|
||||
self.result = task_result.result
|
||||
self.is_success = task_result.is_success
|
||||
|
||||
|
||||
class TaskResult(BaseModel):
|
||||
"""Result of taking a task, with result and is_success required to be filled"""
|
||||
|
||||
code: str = ""
|
||||
result: str
|
||||
is_success: bool
|
||||
|
||||
|
||||
class Plan(BaseModel):
|
||||
goal: str
|
||||
context: str = ""
|
||||
tasks: list[Task] = []
|
||||
task_map: dict[str, Task] = {}
|
||||
current_task_id: str = ""
|
||||
|
||||
def _topological_sort(self, tasks: list[Task]):
|
||||
task_map = {task.task_id: task for task in tasks}
|
||||
dependencies = {task.task_id: set(task.dependent_task_ids) for task in tasks}
|
||||
sorted_tasks = []
|
||||
visited = set()
|
||||
|
||||
def visit(task_id):
|
||||
if task_id in visited:
|
||||
return
|
||||
visited.add(task_id)
|
||||
for dependent_id in dependencies.get(task_id, []):
|
||||
visit(dependent_id)
|
||||
sorted_tasks.append(task_map[task_id])
|
||||
|
||||
for task in tasks:
|
||||
visit(task.task_id)
|
||||
|
||||
return sorted_tasks
|
||||
|
||||
def add_tasks(self, tasks: list[Task]):
|
||||
"""
|
||||
Integrates new tasks into the existing plan, ensuring dependency order is maintained.
|
||||
|
||||
This method performs two primary functions based on the current state of the task list:
|
||||
1. If there are no existing tasks, it topologically sorts the provided tasks to ensure
|
||||
correct execution order based on dependencies, and sets these as the current tasks.
|
||||
2. If there are existing tasks, it merges the new tasks with the existing ones. It maintains
|
||||
any common prefix of tasks (based on task_id and instruction) and appends the remainder
|
||||
of the new tasks. The current task is updated to the first unfinished task in this merged list.
|
||||
|
||||
Args:
|
||||
tasks (list[Task]): A list of tasks (may be unordered) to add to the plan.
|
||||
|
||||
Returns:
|
||||
None: The method updates the internal state of the plan but does not return anything.
|
||||
"""
|
||||
if not tasks:
|
||||
return
|
||||
|
||||
# Topologically sort the new tasks to ensure correct dependency order
|
||||
new_tasks = self._topological_sort(tasks)
|
||||
|
||||
if not self.tasks:
|
||||
# If there are no existing tasks, set the new tasks as the current tasks
|
||||
self.tasks = new_tasks
|
||||
|
||||
else:
|
||||
# Find the length of the common prefix between existing and new tasks
|
||||
prefix_length = 0
|
||||
for old_task, new_task in zip(self.tasks, new_tasks):
|
||||
if old_task.task_id != new_task.task_id or old_task.instruction != new_task.instruction:
|
||||
break
|
||||
prefix_length += 1
|
||||
|
||||
# Combine the common prefix with the remainder of the new tasks
|
||||
final_tasks = self.tasks[:prefix_length] + new_tasks[prefix_length:]
|
||||
self.tasks = final_tasks
|
||||
|
||||
# Update current_task_id to the first unfinished task in the merged list
|
||||
self._update_current_task()
|
||||
|
||||
# Update the task map for quick access to tasks by ID
|
||||
self.task_map = {task.task_id: task for task in self.tasks}
|
||||
|
||||
def reset_task(self, task_id: str):
|
||||
"""
|
||||
Clear code and result of the task based on task_id, and set the task as unfinished.
|
||||
|
||||
Args:
|
||||
task_id (str): The ID of the task to be reset.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if task_id in self.task_map:
|
||||
task = self.task_map[task_id]
|
||||
task.reset()
|
||||
|
||||
def replace_task(self, new_task: Task):
|
||||
"""
|
||||
Replace an existing task with the new input task based on task_id, and reset all tasks depending on it.
|
||||
|
||||
Args:
|
||||
new_task (Task): The new task that will replace an existing one.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
assert new_task.task_id in self.task_map
|
||||
# Replace the task in the task map and the task list
|
||||
self.task_map[new_task.task_id] = new_task
|
||||
for i, task in enumerate(self.tasks):
|
||||
if task.task_id == new_task.task_id:
|
||||
self.tasks[i] = new_task
|
||||
break
|
||||
|
||||
# Reset dependent tasks
|
||||
for task in self.tasks:
|
||||
if new_task.task_id in task.dependent_task_ids:
|
||||
self.reset_task(task.task_id)
|
||||
|
||||
def append_task(self, new_task: Task):
|
||||
"""
|
||||
Append a new task to the end of existing task sequences
|
||||
|
||||
Args:
|
||||
new_task (Task): The new task to be appended to the existing task sequence
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
assert not self.has_task_id(new_task.task_id), "Task already in current plan, use replace_task instead"
|
||||
|
||||
assert all(
|
||||
[self.has_task_id(dep_id) for dep_id in new_task.dependent_task_ids]
|
||||
), "New task has unknown dependencies"
|
||||
|
||||
# Existing tasks do not depend on the new task, it's fine to put it to the end of the sorted task sequence
|
||||
self.tasks.append(new_task)
|
||||
self.task_map[new_task.task_id] = new_task
|
||||
self._update_current_task()
|
||||
|
||||
def has_task_id(self, task_id: str) -> bool:
|
||||
return task_id in self.task_map
|
||||
|
||||
def _update_current_task(self):
|
||||
current_task_id = ""
|
||||
for task in self.tasks:
|
||||
if not task.is_finished:
|
||||
current_task_id = task.task_id
|
||||
break
|
||||
self.current_task_id = current_task_id # all tasks finished
|
||||
|
||||
@property
|
||||
def current_task(self) -> Task:
|
||||
"""Find current task to execute
|
||||
|
||||
Returns:
|
||||
Task: the current task to be executed
|
||||
"""
|
||||
return self.task_map.get(self.current_task_id, None)
|
||||
|
||||
def finish_current_task(self):
|
||||
"""Finish current task, set Task.is_finished=True, set current task to next task"""
|
||||
if self.current_task_id:
|
||||
self.current_task.is_finished = True
|
||||
self._update_current_task() # set to next task
|
||||
|
||||
def get_finished_tasks(self) -> list[Task]:
|
||||
"""return all finished tasks in correct linearized order
|
||||
|
||||
Returns:
|
||||
list[Task]: list of finished tasks
|
||||
"""
|
||||
return [task for task in self.tasks if task.is_finished]
|
||||
|
||||
|
||||
class MessageQueue(BaseModel):
|
||||
"""Message queue which supports asynchronous updates."""
|
||||
|
||||
|
|
|
|||
139
metagpt/strategy/planner.py
Normal file
139
metagpt/strategy/planner.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions.ci.ask_review import AskReview, ReviewConst
|
||||
from metagpt.actions.ci.write_plan import (
|
||||
WritePlan,
|
||||
precheck_update_plan_from_rsp,
|
||||
update_plan_from_rsp,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.schema import Message, Plan, Task, TaskResult
|
||||
|
||||
STRUCTURAL_CONTEXT = """
|
||||
## User Requirement
|
||||
{user_requirement}
|
||||
## Context
|
||||
{context}
|
||||
## Current Plan
|
||||
{tasks}
|
||||
## Current Task
|
||||
{current_task}
|
||||
"""
|
||||
|
||||
|
||||
class Planner(BaseModel):
|
||||
plan: Plan
|
||||
working_memory: Memory = Field(
|
||||
default_factory=Memory
|
||||
) # memory for working on each task, discarded each time a task is done
|
||||
auto_run: bool = False
|
||||
use_tools: bool = False
|
||||
|
||||
def __init__(self, goal: str = "", plan: Plan = None, **kwargs):
|
||||
plan = plan or Plan(goal=goal)
|
||||
super().__init__(plan=plan, **kwargs)
|
||||
|
||||
@property
|
||||
def current_task(self):
|
||||
return self.plan.current_task
|
||||
|
||||
@property
|
||||
def current_task_id(self):
|
||||
return self.plan.current_task_id
|
||||
|
||||
async def update_plan(self, goal: str = "", max_tasks: int = 3, max_retries: int = 3):
|
||||
if goal:
|
||||
self.plan = Plan(goal=goal)
|
||||
|
||||
plan_confirmed = False
|
||||
while not plan_confirmed:
|
||||
context = self.get_useful_memories()
|
||||
rsp = await WritePlan().run(context, max_tasks=max_tasks, use_tools=self.use_tools)
|
||||
self.working_memory.add(Message(content=rsp, role="assistant", cause_by=WritePlan))
|
||||
|
||||
# precheck plan before asking reviews
|
||||
is_plan_valid, error = precheck_update_plan_from_rsp(rsp, self.plan)
|
||||
if not is_plan_valid and max_retries > 0:
|
||||
error_msg = f"The generated plan is not valid with error: {error}, try regenerating, remember to generate either the whole plan or the single changed task only"
|
||||
logger.warning(error_msg)
|
||||
self.working_memory.add(Message(content=error_msg, role="assistant", cause_by=WritePlan))
|
||||
max_retries -= 1
|
||||
continue
|
||||
|
||||
_, plan_confirmed = await self.ask_review(trigger=ReviewConst.TASK_REVIEW_TRIGGER)
|
||||
|
||||
update_plan_from_rsp(rsp=rsp, current_plan=self.plan)
|
||||
|
||||
self.working_memory.clear()
|
||||
|
||||
async def process_task_result(self, task_result: TaskResult):
|
||||
# ask for acceptance, users can other refuse and change tasks in the plan
|
||||
review, task_result_confirmed = await self.ask_review(task_result)
|
||||
|
||||
if task_result_confirmed:
|
||||
# tick off this task and record progress
|
||||
await self.confirm_task(self.current_task, task_result, review)
|
||||
|
||||
elif "redo" in review:
|
||||
# Ask the Role to redo this task with help of review feedback,
|
||||
# useful when the code run is successful but the procedure or result is not what we want
|
||||
pass # simply pass, not confirming the result
|
||||
|
||||
else:
|
||||
# update plan according to user's feedback and to take on changed tasks
|
||||
await self.update_plan()
|
||||
|
||||
async def ask_review(
|
||||
self,
|
||||
task_result: TaskResult = None,
|
||||
auto_run: bool = None,
|
||||
trigger: str = ReviewConst.TASK_REVIEW_TRIGGER,
|
||||
review_context_len: int = 5,
|
||||
):
|
||||
"""
|
||||
Ask to review the task result, reviewer needs to provide confirmation or request change.
|
||||
If human confirms the task result, then we deem the task completed, regardless of whether the code run succeeds;
|
||||
if auto mode, then the code run has to succeed for the task to be considered completed.
|
||||
"""
|
||||
auto_run = auto_run or self.auto_run
|
||||
if not auto_run:
|
||||
context = self.get_useful_memories()
|
||||
review, confirmed = await AskReview().run(
|
||||
context=context[-review_context_len:], plan=self.plan, trigger=trigger
|
||||
)
|
||||
if not confirmed:
|
||||
self.working_memory.add(Message(content=review, role="user", cause_by=AskReview))
|
||||
return review, confirmed
|
||||
confirmed = task_result.is_success if task_result else True
|
||||
return "", confirmed
|
||||
|
||||
async def confirm_task(self, task: Task, task_result: TaskResult, review: str):
|
||||
task.update_task_result(task_result=task_result)
|
||||
self.plan.finish_current_task()
|
||||
self.working_memory.clear()
|
||||
|
||||
confirmed_and_more = (
|
||||
ReviewConst.CONTINUE_WORDS[0] in review.lower() and review.lower() not in ReviewConst.CONTINUE_WORDS[0]
|
||||
) # "confirm, ... (more content, such as changing downstream tasks)"
|
||||
if confirmed_and_more:
|
||||
self.working_memory.add(Message(content=review, role="user", cause_by=AskReview))
|
||||
await self.update_plan(review)
|
||||
|
||||
def get_useful_memories(self, task_exclude_field=None) -> list[Message]:
|
||||
"""find useful memories only to reduce context length and improve performance"""
|
||||
user_requirement = self.plan.goal
|
||||
context = self.plan.context
|
||||
tasks = [task.dict(exclude=task_exclude_field) for task in self.plan.tasks]
|
||||
tasks = json.dumps(tasks, indent=4, ensure_ascii=False)
|
||||
current_task = self.plan.current_task.json() if self.plan.current_task else {}
|
||||
context = STRUCTURAL_CONTEXT.format(
|
||||
user_requirement=user_requirement, context=context, tasks=tasks, current_task=current_task
|
||||
)
|
||||
context_msg = [Message(content=context, role="user")]
|
||||
|
||||
return context_msg + self.working_memory.get()
|
||||
|
|
@ -6,8 +6,11 @@
|
|||
@File : __init__.py
|
||||
"""
|
||||
|
||||
|
||||
from enum import Enum
|
||||
from metagpt.tools import libs # this registers all tools
|
||||
from metagpt.tools.tool_registry import TOOL_REGISTRY
|
||||
|
||||
_ = libs, TOOL_REGISTRY # Avoid pre-commit error
|
||||
|
||||
|
||||
class SearchEngineType(Enum):
|
||||
|
|
|
|||
15
metagpt/tools/libs/__init__.py
Normal file
15
metagpt/tools/libs/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2023/11/16 16:32
|
||||
# @Author : lidanyang
|
||||
# @File : __init__.py
|
||||
# @Desc :
|
||||
from metagpt.tools.libs import (
|
||||
data_preprocess,
|
||||
feature_engineering,
|
||||
sd_engine,
|
||||
gpt_v_generator,
|
||||
web_scraping,
|
||||
)
|
||||
|
||||
_ = data_preprocess, feature_engineering, sd_engine, gpt_v_generator, web_scraping # Avoid pre-commit error
|
||||
249
metagpt/tools/libs/data_preprocess.py
Normal file
249
metagpt/tools/libs/data_preprocess.py
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.impute import SimpleImputer
|
||||
from sklearn.preprocessing import (
|
||||
LabelEncoder,
|
||||
MaxAbsScaler,
|
||||
MinMaxScaler,
|
||||
OneHotEncoder,
|
||||
OrdinalEncoder,
|
||||
RobustScaler,
|
||||
StandardScaler,
|
||||
)
|
||||
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
TOOL_TYPE = ToolType.DATA_PREPROCESS.type_name
|
||||
|
||||
|
||||
class MLProcess:
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit a model to be used in subsequent transform.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Fit and transform the input DataFrame.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
self.fit(df)
|
||||
return self.transform(df)
|
||||
|
||||
|
||||
class DataPreprocessTool(MLProcess):
|
||||
"""
|
||||
Completing a data preprocessing operation.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
"""
|
||||
self.features = features
|
||||
self.model = None # to be filled by specific subclass Tool
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
if len(self.features) == 0:
|
||||
return
|
||||
self.model.fit(df[self.features])
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if len(self.features) == 0:
|
||||
return df
|
||||
new_df = df.copy()
|
||||
new_df[self.features] = self.model.transform(new_df[self.features])
|
||||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class FillMissingValue(DataPreprocessTool):
|
||||
"""
|
||||
Completing missing values with simple strategies.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list, strategy: str = "mean", fill_value=None):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only
|
||||
be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
|
||||
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
|
||||
Defaults to None.
|
||||
"""
|
||||
self.features = features
|
||||
self.model = SimpleImputer(strategy=strategy, fill_value=fill_value)
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class MinMaxScale(DataPreprocessTool):
|
||||
"""
|
||||
Transform features by scaling each feature to a range, which is (0, 1).
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
self.features = features
|
||||
self.model = MinMaxScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class StandardScale(DataPreprocessTool):
|
||||
"""
|
||||
Standardize features by removing the mean and scaling to unit variance.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
self.features = features
|
||||
self.model = StandardScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class MaxAbsScale(DataPreprocessTool):
|
||||
"""
|
||||
Scale each feature by its maximum absolute value.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
self.features = features
|
||||
self.model = MaxAbsScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class RobustScale(DataPreprocessTool):
|
||||
"""
|
||||
Apply the RobustScaler to scale features using statistics that are robust to outliers.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
self.features = features
|
||||
self.model = RobustScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class OrdinalEncode(DataPreprocessTool):
|
||||
"""
|
||||
Encode categorical features as ordinal integers.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
self.features = features
|
||||
self.model = OrdinalEncoder()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class OneHotEncode(DataPreprocessTool):
|
||||
"""
|
||||
Apply one-hot encoding to specified categorical columns, the original columns will be dropped.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
self.features = features
|
||||
self.model = OneHotEncoder(handle_unknown="ignore", sparse=False)
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
ts_data = self.model.transform(df[self.features])
|
||||
new_columns = self.model.get_feature_names_out(self.features)
|
||||
ts_data = pd.DataFrame(ts_data, columns=new_columns, index=df.index)
|
||||
new_df = df.drop(self.features, axis=1)
|
||||
new_df = pd.concat([new_df, ts_data], axis=1)
|
||||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class LabelEncode(DataPreprocessTool):
|
||||
"""
|
||||
Apply label encoding to specified categorical columns in-place.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
features (list): Categorical columns to be label encoded.
|
||||
"""
|
||||
self.features = features
|
||||
self.le_encoders = []
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
if len(self.features) == 0:
|
||||
return
|
||||
for col in self.features:
|
||||
le = LabelEncoder().fit(df[col].astype(str).unique().tolist() + ["unknown"])
|
||||
self.le_encoders.append(le)
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if len(self.features) == 0:
|
||||
return df
|
||||
new_df = df.copy()
|
||||
for i in range(len(self.features)):
|
||||
data_list = df[self.features[i]].astype(str).tolist()
|
||||
for unique_item in np.unique(df[self.features[i]].astype(str)):
|
||||
if unique_item not in self.le_encoders[i].classes_:
|
||||
data_list = ["unknown" if x == unique_item else x for x in data_list]
|
||||
new_df[self.features[i]] = self.le_encoders[i].transform(data_list)
|
||||
return new_df
|
||||
|
||||
|
||||
def get_column_info(df: pd.DataFrame) -> dict:
|
||||
"""
|
||||
Analyzes a DataFrame and categorizes its columns based on data types.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The DataFrame to be analyzed.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others').
|
||||
Each key corresponds to a list of column names belonging to that category.
|
||||
"""
|
||||
column_info = {
|
||||
"Category": [],
|
||||
"Numeric": [],
|
||||
"Datetime": [],
|
||||
"Others": [],
|
||||
}
|
||||
for col in df.columns:
|
||||
data_type = str(df[col].dtype).replace("dtype('", "").replace("')", "")
|
||||
if data_type.startswith("object"):
|
||||
column_info["Category"].append(col)
|
||||
elif data_type.startswith("int") or data_type.startswith("float"):
|
||||
column_info["Numeric"].append(col)
|
||||
elif data_type.startswith("datetime"):
|
||||
column_info["Datetime"].append(col)
|
||||
else:
|
||||
column_info["Others"].append(col)
|
||||
|
||||
if len(json.dumps(column_info)) > 2000:
|
||||
column_info["Numeric"] = column_info["Numeric"][0:5] + ["Too many cols, omission here..."]
|
||||
return column_info
|
||||
435
metagpt/tools/libs/feature_engineering.py
Normal file
435
metagpt/tools/libs/feature_engineering.py
Normal file
|
|
@ -0,0 +1,435 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2023/11/17 10:33
|
||||
# @Author : lidanyang
|
||||
# @File : feature_engineering.py
|
||||
# @Desc : Feature Engineering Tools
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
|
||||
# import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from joblib import Parallel, delayed
|
||||
from pandas.core.dtypes.common import is_object_dtype
|
||||
from sklearn.feature_selection import VarianceThreshold
|
||||
from sklearn.model_selection import KFold
|
||||
from sklearn.preprocessing import KBinsDiscretizer, PolynomialFeatures
|
||||
|
||||
from metagpt.tools.libs.data_preprocess import MLProcess
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
TOOL_TYPE = ToolType.FEATURE_ENGINEERING.type_name
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class PolynomialExpansion(MLProcess):
|
||||
"""
|
||||
Add polynomial and interaction features from selected numeric columns to input DataFrame.
|
||||
"""
|
||||
|
||||
def __init__(self, cols: list, label_col: str, degree: int = 2):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
cols (list): Columns for polynomial expansion.
|
||||
label_col (str): Label column name.
|
||||
degree (int, optional): The degree of the polynomial features. Defaults to 2.
|
||||
"""
|
||||
self.cols = cols
|
||||
self.degree = degree
|
||||
self.label_col = label_col
|
||||
if self.label_col in self.cols:
|
||||
self.cols.remove(self.label_col)
|
||||
self.poly = PolynomialFeatures(degree=degree, include_bias=False)
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
if len(self.cols) == 0:
|
||||
return
|
||||
if len(self.cols) > 10:
|
||||
corr = df[self.cols + [self.label_col]].corr()
|
||||
corr = corr[self.label_col].abs().sort_values(ascending=False)
|
||||
self.cols = corr.index.tolist()[1:11]
|
||||
|
||||
self.poly.fit(df[self.cols].fillna(0))
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if len(self.cols) == 0:
|
||||
return df
|
||||
ts_data = self.poly.transform(df[self.cols].fillna(0))
|
||||
column_name = self.poly.get_feature_names_out(self.cols)
|
||||
ts_data = pd.DataFrame(ts_data, index=df.index, columns=column_name)
|
||||
new_df = df.drop(self.cols, axis=1)
|
||||
new_df = pd.concat([new_df, ts_data], axis=1)
|
||||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class CatCount(MLProcess):
|
||||
"""
|
||||
Add value counts of a categorical column as new feature.
|
||||
"""
|
||||
|
||||
def __init__(self, col: str):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
col (str): Column for value counts.
|
||||
"""
|
||||
self.col = col
|
||||
self.encoder_dict = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
self.encoder_dict = df[self.col].value_counts().to_dict()
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
new_df = df.copy()
|
||||
new_df[f"{self.col}_cnt"] = new_df[self.col].map(self.encoder_dict)
|
||||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class TargetMeanEncoder(MLProcess):
|
||||
"""
|
||||
Encode a categorical column by the mean of the label column, and adds the result as a new feature.
|
||||
"""
|
||||
|
||||
def __init__(self, col: str, label: str):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
col (str): Column to be mean encoded.
|
||||
label (str): Predicted label column.
|
||||
"""
|
||||
self.col = col
|
||||
self.label = label
|
||||
self.encoder_dict = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
self.encoder_dict = df.groupby(self.col)[self.label].mean().to_dict()
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
new_df = df.copy()
|
||||
new_df[f"{self.col}_target_mean"] = new_df[self.col].map(self.encoder_dict)
|
||||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class KFoldTargetMeanEncoder(MLProcess):
|
||||
"""
|
||||
Add a new feature to the DataFrame by k-fold mean encoding of a categorical column using the label column.
|
||||
"""
|
||||
|
||||
def __init__(self, col: str, label: str, n_splits: int = 5, random_state: int = 2021):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
col (str): Column to be k-fold mean encoded.
|
||||
label (str): Predicted label column.
|
||||
n_splits (int, optional): Number of splits for K-fold. Defaults to 5.
|
||||
random_state (int, optional): Random seed. Defaults to 2021.
|
||||
"""
|
||||
self.col = col
|
||||
self.label = label
|
||||
self.n_splits = n_splits
|
||||
self.random_state = random_state
|
||||
self.encoder_dict = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
tmp = df.copy()
|
||||
kf = KFold(n_splits=self.n_splits, shuffle=True, random_state=self.random_state)
|
||||
|
||||
global_mean = tmp[self.label].mean()
|
||||
col_name = f"{self.col}_kf_target_mean"
|
||||
for trn_idx, val_idx in kf.split(tmp, tmp[self.label]):
|
||||
_trn, _val = tmp.iloc[trn_idx], tmp.iloc[val_idx]
|
||||
tmp.loc[tmp.index[val_idx], col_name] = _val[self.col].map(_trn.groupby(self.col)[self.label].mean())
|
||||
tmp[col_name].fillna(global_mean, inplace=True)
|
||||
self.encoder_dict = tmp.groupby(self.col)[col_name].mean().to_dict()
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
new_df = df.copy()
|
||||
new_df[f"{self.col}_kf_target_mean"] = new_df[self.col].map(self.encoder_dict)
|
||||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class CatCross(MLProcess):
|
||||
"""
|
||||
Add pairwise crossed features and convert them to numerical features.
|
||||
"""
|
||||
|
||||
def __init__(self, cols: list, max_cat_num: int = 100):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
cols (list): Columns to be pairwise crossed, at least 2 columns.
|
||||
max_cat_num (int, optional): Maximum unique categories per crossed feature. Defaults to 100.
|
||||
"""
|
||||
self.cols = cols
|
||||
self.max_cat_num = max_cat_num
|
||||
self.combs = []
|
||||
self.combs_map = {}
|
||||
|
||||
@staticmethod
|
||||
def _cross_two(comb, df):
|
||||
"""
|
||||
Cross two columns and convert them to numerical features.
|
||||
|
||||
Args:
|
||||
comb (tuple): The pair of columns to be crossed.
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
tuple: The new column name and the crossed feature map.
|
||||
"""
|
||||
new_col = f"{comb[0]}_{comb[1]}"
|
||||
new_col_combs = list(itertools.product(df[comb[0]].unique(), df[comb[1]].unique()))
|
||||
ll = list(range(len(new_col_combs)))
|
||||
comb_map = dict(zip(new_col_combs, ll))
|
||||
return new_col, comb_map
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
for col in self.cols:
|
||||
if df[col].nunique() > self.max_cat_num:
|
||||
self.cols.remove(col)
|
||||
self.combs = list(itertools.combinations(self.cols, 2))
|
||||
res = Parallel(n_jobs=4, require="sharedmem")(delayed(self._cross_two)(comb, df) for comb in self.combs)
|
||||
self.combs_map = dict(res)
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
new_df = df.copy()
|
||||
for comb in self.combs:
|
||||
new_col = f"{comb[0]}_{comb[1]}"
|
||||
_map = self.combs_map[new_col]
|
||||
new_df[new_col] = pd.Series(zip(new_df[comb[0]], new_df[comb[1]])).map(_map)
|
||||
# set the unknown value to a new number
|
||||
new_df[new_col].fillna(max(_map.values()) + 1, inplace=True)
|
||||
new_df[new_col] = new_df[new_col].astype(int)
|
||||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class GroupStat(MLProcess):
|
||||
"""
|
||||
Aggregate specified column in a DataFrame grouped by another column, adding new features named '<agg_col>_<agg_func>_by_<group_col>'.
|
||||
"""
|
||||
|
||||
def __init__(self, group_col: str, agg_col: str, agg_funcs: list):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
group_col (str): Column used for grouping.
|
||||
agg_col (str): Column on which aggregation is performed.
|
||||
agg_funcs (list): List of aggregation functions to apply, such as ['mean', 'std']. Each function must be supported by pandas.
|
||||
"""
|
||||
self.group_col = group_col
|
||||
self.agg_col = agg_col
|
||||
self.agg_funcs = agg_funcs
|
||||
self.group_df = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
group_df = df.groupby(self.group_col)[self.agg_col].agg(self.agg_funcs).reset_index()
|
||||
group_df.columns = [self.group_col] + [
|
||||
f"{self.agg_col}_{agg_func}_by_{self.group_col}" for agg_func in self.agg_funcs
|
||||
]
|
||||
self.group_df = group_df
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
new_df = df.merge(self.group_df, on=self.group_col, how="left")
|
||||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class SplitBins(MLProcess):
|
||||
"""
|
||||
Inplace binning of continuous data into intervals, returning integer-encoded bin identifiers directly.
|
||||
"""
|
||||
|
||||
def __init__(self, cols: list, strategy: str = "quantile"):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
cols (list): Columns to be binned inplace.
|
||||
strategy (str, optional): Strategy used to define the widths of the bins. Enum: ['quantile', 'uniform', 'kmeans']. Defaults to 'quantile'.
|
||||
"""
|
||||
self.cols = cols
|
||||
self.strategy = strategy
|
||||
self.encoder = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
self.encoder = KBinsDiscretizer(strategy=self.strategy, encode="ordinal")
|
||||
self.encoder.fit(df[self.cols].fillna(0))
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
new_df = df.copy()
|
||||
new_df[self.cols] = self.encoder.transform(new_df[self.cols].fillna(0))
|
||||
return new_df
|
||||
|
||||
|
||||
# @register_tool(tool_type=TOOL_TYPE)
|
||||
class ExtractTimeComps(MLProcess):
|
||||
"""
|
||||
Extract time components from a datetime column and add them as new features.
|
||||
"""
|
||||
|
||||
def __init__(self, time_col: str, time_comps: list):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
time_col (str): The name of the column containing time data.
|
||||
time_comps (list): List of time components to extract. Each component must be in ['year', 'month', 'day', 'hour', 'dayofweek', 'is_weekend'].
|
||||
"""
|
||||
self.time_col = time_col
|
||||
self.time_comps = time_comps
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
pass
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
time_s = pd.to_datetime(df[self.time_col], errors="coerce")
|
||||
time_comps_df = pd.DataFrame()
|
||||
|
||||
if "year" in self.time_comps:
|
||||
time_comps_df["year"] = time_s.dt.year
|
||||
if "month" in self.time_comps:
|
||||
time_comps_df["month"] = time_s.dt.month
|
||||
if "day" in self.time_comps:
|
||||
time_comps_df["day"] = time_s.dt.day
|
||||
if "hour" in self.time_comps:
|
||||
time_comps_df["hour"] = time_s.dt.hour
|
||||
if "dayofweek" in self.time_comps:
|
||||
time_comps_df["dayofweek"] = time_s.dt.dayofweek + 1
|
||||
if "is_weekend" in self.time_comps:
|
||||
time_comps_df["is_weekend"] = time_s.dt.dayofweek.isin([5, 6]).astype(int)
|
||||
new_df = pd.concat([df, time_comps_df], axis=1)
|
||||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class GeneralSelection(MLProcess):
|
||||
"""
|
||||
Drop all nan feats and feats with only one unique value.
|
||||
"""
|
||||
|
||||
def __init__(self, label_col: str):
|
||||
self.label_col = label_col
|
||||
self.feats = []
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
feats = [f for f in df.columns if f != self.label_col]
|
||||
for col in df.columns:
|
||||
if df[col].isnull().sum() / df.shape[0] == 1:
|
||||
feats.remove(col)
|
||||
|
||||
if df[col].nunique() == 1:
|
||||
feats.remove(col)
|
||||
|
||||
if df.loc[df[col] == np.inf].shape[0] != 0 or df.loc[df[col] == np.inf].shape[0] != 0:
|
||||
feats.remove(col)
|
||||
|
||||
if is_object_dtype(df[col]) and df[col].nunique() == df.shape[0]:
|
||||
feats.remove(col)
|
||||
|
||||
self.feats = feats
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
new_df = df[self.feats + [self.label_col]]
|
||||
return new_df
|
||||
|
||||
|
||||
# skip for now because lgb is needed
|
||||
# @register_tool(tool_type=TOOL_TYPE)
|
||||
class TreeBasedSelection(MLProcess):
|
||||
"""
|
||||
Select features based on tree-based model and remove features with low importance.
|
||||
"""
|
||||
|
||||
def __init__(self, label_col: str, task_type: str):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
label_col (str): Label column name.
|
||||
task_type (str): Task type, 'cls' for classification, 'mcls' for multi-class classification, 'reg' for regression.
|
||||
"""
|
||||
self.label_col = label_col
|
||||
self.task_type = task_type
|
||||
self.feats = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
params = {
|
||||
"boosting_type": "gbdt",
|
||||
"objective": "binary",
|
||||
"learning_rate": 0.1,
|
||||
"num_leaves": 31,
|
||||
}
|
||||
|
||||
if self.task_type == "cls":
|
||||
params["objective"] = "binary"
|
||||
params["metric"] = "auc"
|
||||
elif self.task_type == "mcls":
|
||||
params["objective"] = "multiclass"
|
||||
params["num_class"] = df[self.label_col].nunique()
|
||||
params["metric"] = "auc_mu"
|
||||
elif self.task_type == "reg":
|
||||
params["objective"] = "regression"
|
||||
params["metric"] = "rmse"
|
||||
|
||||
num_cols = df.select_dtypes(include=np.number).columns.tolist()
|
||||
cols = [f for f in num_cols if f not in [self.label_col]]
|
||||
|
||||
dtrain = lgb.Dataset(df[cols], df[self.label_col])
|
||||
model = lgb.train(params, dtrain, num_boost_round=100)
|
||||
df_imp = pd.DataFrame({"feature_name": dtrain.feature_name, "importance": model.feature_importance("gain")})
|
||||
|
||||
df_imp.sort_values("importance", ascending=False, inplace=True)
|
||||
df_imp = df_imp[df_imp["importance"] > 0]
|
||||
self.feats = df_imp["feature_name"].tolist()
|
||||
self.feats.append(self.label_col)
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
new_df = df[self.feats]
|
||||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class VarianceBasedSelection(MLProcess):
|
||||
"""
|
||||
Select features based on variance and remove features with low variance.
|
||||
"""
|
||||
|
||||
def __init__(self, label_col: str, threshold: float = 0):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
label_col (str): Label column name.
|
||||
threshold (float, optional): Threshold for variance. Defaults to 0.
|
||||
"""
|
||||
self.label_col = label_col
|
||||
self.threshold = threshold
|
||||
self.feats = None
|
||||
self.selector = VarianceThreshold(threshold=self.threshold)
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
num_cols = df.select_dtypes(include=np.number).columns.tolist()
|
||||
cols = [f for f in num_cols if f not in [self.label_col]]
|
||||
|
||||
self.selector.fit(df[cols])
|
||||
self.feats = df[cols].columns[self.selector.get_support(indices=True)].tolist()
|
||||
self.feats.append(self.label_col)
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
new_df = df[self.feats]
|
||||
return new_df
|
||||
177
metagpt/tools/libs/gpt_v_generator.py
Normal file
177
metagpt/tools/libs/gpt_v_generator.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/01/12
|
||||
@Author : mannaandpoem
|
||||
@File : gpt_v_generator.py
|
||||
"""
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
ANALYZE_LAYOUT_PROMPT = """You are now a UI/UX, please generate layout information for this image:
|
||||
|
||||
NOTE: The image does not have a commercial logo or copyright information. It is just a sketch image of the design.
|
||||
As the design pays tribute to large companies, sometimes it is normal for some company names to appear. Don't worry. """
|
||||
|
||||
GENERATE_PROMPT = """You are now a UI/UX and Web Developer. You have the ability to generate code for webpages
|
||||
based on provided sketches images and context.
|
||||
Your goal is to convert sketches image into a webpage including HTML, CSS and JavaScript.
|
||||
|
||||
NOTE: The image does not have a commercial logo or copyright information. It is just a sketch image of the design.
|
||||
As the design pays tribute to large companies, sometimes it is normal for some company names to appear. Don't worry.
|
||||
|
||||
Now, please generate the corresponding webpage code including HTML, CSS and JavaScript:"""
|
||||
|
||||
|
||||
@register_tool(
|
||||
tool_type=ToolType.IMAGE2WEBPAGE.type_name, include_functions=["__init__", "generate_webpages", "save_webpages"]
|
||||
)
|
||||
class GPTvGenerator:
|
||||
"""Class for generating webpages at once.
|
||||
|
||||
This class provides methods to generate webpages including all code (HTML, CSS, and JavaScript) based on an image.
|
||||
It utilizes a vision model to analyze the layout from an image and generate webpage codes accordingly.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GPTvGenerator class with default values from the configuration."""
|
||||
from metagpt.config2 import config
|
||||
|
||||
self.api_key = config.llm.api_key
|
||||
self.api_base = config.llm.base_url
|
||||
self.model = config.openai_vision_model
|
||||
self.max_tokens = config.vision_max_tokens
|
||||
|
||||
def analyze_layout(self, image_path):
|
||||
"""Analyze the layout of the given image and return the result.
|
||||
|
||||
This is a helper method to generate a layout description based on the image.
|
||||
|
||||
Args:
|
||||
image_path (str): Path of the image to analyze.
|
||||
|
||||
Returns:
|
||||
str: The layout analysis result.
|
||||
"""
|
||||
return self.get_result(image_path, ANALYZE_LAYOUT_PROMPT)
|
||||
|
||||
def generate_webpages(self, image_path):
|
||||
"""Generate webpages including all code (HTML, CSS, and JavaScript) in one go based on the image.
|
||||
|
||||
Args:
|
||||
image_path (str): The path of the image file.
|
||||
|
||||
Returns:
|
||||
str: Generated webpages content.
|
||||
"""
|
||||
layout = self.analyze_layout(image_path)
|
||||
prompt = GENERATE_PROMPT + "\n\n # Context\n The layout information of the sketch image is: \n" + layout
|
||||
result = self.get_result(image_path, prompt)
|
||||
return result
|
||||
|
||||
def get_result(self, image_path, prompt):
|
||||
"""Get the result from the vision model based on the given image path and prompt.
|
||||
|
||||
Args:
|
||||
image_path (str): Path of the image to analyze.
|
||||
prompt (str): Prompt to use for the analysis.
|
||||
|
||||
Returns:
|
||||
str: The model's response as a string.
|
||||
"""
|
||||
base64_image = self.encode_image(image_path)
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": self.max_tokens,
|
||||
}
|
||||
response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Request failed with status {response.status_code}, {response.text}")
|
||||
else:
|
||||
return response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
@staticmethod
|
||||
def encode_image(image_path):
|
||||
"""Encode the image at the given path to a base64 string.
|
||||
|
||||
Args:
|
||||
image_path (str): Path of the image to encode.
|
||||
|
||||
Returns:
|
||||
str: The base64 encoded string of the image.
|
||||
"""
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def save_webpages(image_path, webpages) -> Path:
|
||||
"""Save webpages including all code (HTML, CSS, and JavaScript) at once.
|
||||
|
||||
Args:
|
||||
image_path (str): The path of the image file.
|
||||
webpages (str): The generated webpages content.
|
||||
|
||||
Returns:
|
||||
Path: The path of the saved webpages.
|
||||
"""
|
||||
# 在workspace目录下,创建一个名为下webpages的文件夹,用于存储html、css和js文件
|
||||
webpages_path = DEFAULT_WORKSPACE_ROOT / "webpages" / Path(image_path).stem
|
||||
os.makedirs(webpages_path, exist_ok=True)
|
||||
|
||||
index_path = webpages_path / "index.html"
|
||||
|
||||
try:
|
||||
index = webpages.split("```html")[1].split("```")[0]
|
||||
except IndexError:
|
||||
index = "No html code found in the result, please check your image and try again." + "\n" + webpages
|
||||
|
||||
try:
|
||||
if "styles.css" in index:
|
||||
style_path = webpages_path / "styles.css"
|
||||
elif "style.css" in index:
|
||||
style_path = webpages_path / "style.css"
|
||||
else:
|
||||
style_path = None
|
||||
style = webpages.split("```css")[1].split("```")[0] if style_path else ""
|
||||
|
||||
if "scripts.js" in index:
|
||||
js_path = webpages_path / "scripts.js"
|
||||
elif "script.js" in index:
|
||||
js_path = webpages_path / "script.js"
|
||||
else:
|
||||
js_path = None
|
||||
js = webpages.split("```javascript")[1].split("```")[0] if js_path else ""
|
||||
except IndexError:
|
||||
raise ValueError("No css or js code found in the result, please check your image and try again.")
|
||||
|
||||
try:
|
||||
with open(index_path, "w", encoding="utf-8") as f:
|
||||
f.write(index)
|
||||
if style_path:
|
||||
with open(style_path, "w", encoding="utf-8") as f:
|
||||
f.write(style)
|
||||
if js_path:
|
||||
with open(js_path, "w", encoding="utf-8") as f:
|
||||
f.write(js)
|
||||
except FileNotFoundError as e:
|
||||
raise FileNotFoundError(f"Cannot save the webpages to {str(webpages_path)}") from e
|
||||
|
||||
return webpages_path
|
||||
183
metagpt/tools/libs/sd_engine.py
Normal file
183
metagpt/tools/libs/sd_engine.py
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 2023/7/19 16:28
|
||||
# @Author : stellahong (stellahong@deepwisdom.ai)
|
||||
# @Desc :
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
from os.path import join
|
||||
|
||||
import requests
|
||||
from aiohttp import ClientSession
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
#
|
||||
from metagpt.const import SD_OUTPUT_FILE_REPO, SOURCE_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
payload = {
|
||||
"prompt": "",
|
||||
"negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
|
||||
"override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"},
|
||||
"seed": -1,
|
||||
"batch_size": 1,
|
||||
"n_iter": 1,
|
||||
"steps": 20,
|
||||
"cfg_scale": 7,
|
||||
"width": 512,
|
||||
"height": 768,
|
||||
"restore_faces": False,
|
||||
"tiling": False,
|
||||
"do_not_save_samples": False,
|
||||
"do_not_save_grid": False,
|
||||
"enable_hr": False,
|
||||
"hr_scale": 2,
|
||||
"hr_upscaler": "Latent",
|
||||
"hr_second_pass_steps": 0,
|
||||
"hr_resize_x": 0,
|
||||
"hr_resize_y": 0,
|
||||
"hr_upscale_to_x": 0,
|
||||
"hr_upscale_to_y": 0,
|
||||
"truncate_x": 0,
|
||||
"truncate_y": 0,
|
||||
"applied_old_hires_behavior_to": None,
|
||||
"eta": None,
|
||||
"sampler_index": "DPM++ SDE Karras",
|
||||
"alwayson_scripts": {},
|
||||
}
|
||||
|
||||
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
||||
|
||||
|
||||
@register_tool(
|
||||
tool_type=ToolType.STABLE_DIFFUSION.type_name,
|
||||
include_functions=["__init__", "simple_run_t2i", "run_t2i", "construct_payload", "save"],
|
||||
)
|
||||
class SDEngine:
|
||||
"""Generate image using stable diffusion model.
|
||||
|
||||
This class provides methods to interact with a stable diffusion service to generate images based on text inputs.
|
||||
"""
|
||||
|
||||
def __init__(self, sd_url=""):
|
||||
"""Initialize the SDEngine instance with configuration.
|
||||
|
||||
Args:
|
||||
sd_url (str, optional): URL of the stable diffusion service. Defaults to "".
|
||||
"""
|
||||
self.sd_url = sd_url
|
||||
self.sd_t2i_url = f"{self.sd_url}/sdapi/v1/txt2img"
|
||||
# Define default payload settings for SD API
|
||||
self.payload = payload
|
||||
logger.info(self.sd_t2i_url)
|
||||
|
||||
def construct_payload(
|
||||
self,
|
||||
prompt,
|
||||
negtive_prompt=default_negative_prompt,
|
||||
width=512,
|
||||
height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20",
|
||||
):
|
||||
"""Modify and set the API parameters for image generation.
|
||||
|
||||
Args:
|
||||
prompt (str): Text input for image generation.
|
||||
negtive_prompt (str, optional): Text input for negative prompts. Defaults to None.
|
||||
width (int, optional): Width of the generated image in pixels. Defaults to 512.
|
||||
height (int, optional): Height of the generated image in pixels. Defaults to 512.
|
||||
sd_model (str, optional): The model to use for image generation. Defaults to "galaxytimemachinesGTM_photoV20".
|
||||
|
||||
Returns:
|
||||
dict: Updated parameters for the stable diffusion API.
|
||||
"""
|
||||
self.payload["prompt"] = prompt
|
||||
self.payload["negative_prompt"] = negtive_prompt
|
||||
self.payload["width"] = width
|
||||
self.payload["height"] = height
|
||||
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
|
||||
logger.info(f"call sd payload is {self.payload}")
|
||||
return self.payload
|
||||
|
||||
def save(self, imgs, save_name=""):
|
||||
"""Save generated images to the output directory.
|
||||
|
||||
Args:
|
||||
imgs (str): Generated images.
|
||||
save_name (str, optional): Output image name. Default is empty.
|
||||
"""
|
||||
save_dir = SOURCE_ROOT / SD_OUTPUT_FILE_REPO
|
||||
if not save_dir.exists():
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
|
||||
|
||||
def simple_run_t2i(self, payload: dict, auto_save: bool = True):
|
||||
"""Run the stable diffusion API for multiple prompts, calling the stable diffusion API to generate images.
|
||||
|
||||
Args:
|
||||
payload (dict): Dictionary of input parameters for the stable diffusion API.
|
||||
auto_save (bool, optional): Save generated images automatically. Defaults to True.
|
||||
|
||||
Returns:
|
||||
list: The generated images as a result of the API call.
|
||||
"""
|
||||
with requests.Session() as session:
|
||||
logger.debug(self.sd_t2i_url)
|
||||
rsp = session.post(self.sd_t2i_url, json=payload, timeout=600)
|
||||
|
||||
results = rsp.json()["images"]
|
||||
if auto_save:
|
||||
save_name = hashlib.sha256(payload["prompt"][:10].encode()).hexdigest()[:6]
|
||||
self.save(results, save_name=f"output_{save_name}")
|
||||
return results
|
||||
|
||||
async def run_t2i(self, payloads: list):
|
||||
"""Run the stable diffusion API for multiple prompts asynchronously.
|
||||
|
||||
Args:
|
||||
payloads (list): list of payload, each payload is a dictionary of input parameters for the stable diffusion API.
|
||||
"""
|
||||
session = ClientSession()
|
||||
for payload_idx, payload in enumerate(payloads):
|
||||
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
|
||||
self.save(results, save_name=f"output_{payload_idx}")
|
||||
await session.close()
|
||||
|
||||
async def run(self, url, payload, session):
|
||||
"""Perform the HTTP POST request to the SD API.
|
||||
|
||||
Args:
|
||||
url (str): The API URL.
|
||||
payload (dict): The payload for the request.
|
||||
session (ClientSession): The session for making HTTP requests.
|
||||
|
||||
Returns:
|
||||
list: Images generated by the stable diffusion API.
|
||||
"""
|
||||
async with session.post(url, json=payload, timeout=600) as rsp:
|
||||
data = await rsp.read()
|
||||
|
||||
rsp_json = json.loads(data)
|
||||
imgs = rsp_json["images"]
|
||||
|
||||
logger.info(f"callback rsp json is {rsp_json.keys()}")
|
||||
return imgs
|
||||
|
||||
|
||||
def decode_base64_to_image(img, save_name):
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0])))
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
logger.info(save_name)
|
||||
image.save(f"{save_name}.png", pnginfo=pnginfo)
|
||||
return pnginfo, image
|
||||
|
||||
|
||||
def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
|
||||
for idx, _img in enumerate(imgs):
|
||||
save_name = join(save_dir, save_name)
|
||||
decode_base64_to_image(_img, save_name=save_name)
|
||||
22
metagpt/tools/libs/web_scraping.py
Normal file
22
metagpt/tools/libs/web_scraping.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
from metagpt.tools.web_browser_engine_playwright import PlaywrightWrapper
|
||||
|
||||
|
||||
@register_tool(tool_type=ToolType.WEBSCRAPING.type_name)
|
||||
async def scrape_web_playwright(url, *urls):
|
||||
"""
|
||||
Scrape and save the HTML structure and inner text content of a web page using Playwright.
|
||||
|
||||
Args:
|
||||
url (str): The main URL to fetch inner text from.
|
||||
*urls (str): Additional URLs to fetch inner text from.
|
||||
|
||||
Returns:
|
||||
(dict): The inner text content and html structure of the web page, key are : 'inner_text', 'html'.
|
||||
"""
|
||||
# Create a PlaywrightWrapper instance for the Chromium browser
|
||||
web = await PlaywrightWrapper().run(url, *urls)
|
||||
|
||||
# Return the inner text content of the web page
|
||||
return {"inner_text": web.inner_text.strip(), "html": web.html.strip()}
|
||||
81
metagpt/tools/tool_convert.py
Normal file
81
metagpt/tools/tool_convert.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
import inspect
|
||||
|
||||
from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces
|
||||
|
||||
|
||||
def convert_code_to_tool_schema(obj, include: list[str] = []):
|
||||
docstring = inspect.getdoc(obj)
|
||||
assert docstring, "no docstring found for the objects, skip registering"
|
||||
|
||||
if inspect.isclass(obj):
|
||||
schema = {"type": "class", "description": remove_spaces(docstring), "methods": {}}
|
||||
for name, method in inspect.getmembers(obj, inspect.isfunction):
|
||||
if include and name not in include:
|
||||
continue
|
||||
# method_doc = inspect.getdoc(method)
|
||||
method_doc = get_class_method_docstring(obj, name)
|
||||
if method_doc:
|
||||
schema["methods"][name] = docstring_to_schema(method_doc)
|
||||
|
||||
elif inspect.isfunction(obj):
|
||||
schema = {
|
||||
"type": "function",
|
||||
**docstring_to_schema(docstring),
|
||||
}
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def docstring_to_schema(docstring: str):
|
||||
if docstring is None:
|
||||
return {}
|
||||
|
||||
parser = GoogleDocstringParser(docstring=docstring)
|
||||
|
||||
# 匹配简介部分
|
||||
description = parser.parse_desc()
|
||||
|
||||
# 匹配Args部分
|
||||
params = parser.parse_params()
|
||||
parameter_schema = {"properties": {}, "required": []}
|
||||
for param in params:
|
||||
param_name, param_type, param_desc = param
|
||||
# check required or optional
|
||||
is_optional, param_type = parser.check_and_parse_optional(param_type)
|
||||
if not is_optional:
|
||||
parameter_schema["required"].append(param_name)
|
||||
# type and desc
|
||||
param_dict = {"type": param_type, "description": remove_spaces(param_desc)}
|
||||
# match Default for optional args
|
||||
has_default_val, default_val = parser.check_and_parse_default_value(param_desc)
|
||||
if has_default_val:
|
||||
param_dict["default"] = default_val
|
||||
# match Enum
|
||||
has_enum, enum_vals = parser.check_and_parse_enum(param_desc)
|
||||
if has_enum:
|
||||
param_dict["enum"] = enum_vals
|
||||
# add to parameter schema
|
||||
parameter_schema["properties"].update({param_name: param_dict})
|
||||
|
||||
# 匹配Returns部分
|
||||
returns = parser.parse_returns()
|
||||
|
||||
# 构建YAML字典
|
||||
schema = {
|
||||
"description": description,
|
||||
"parameters": parameter_schema,
|
||||
}
|
||||
if returns:
|
||||
schema["returns"] = [{"type": ret[0], "description": remove_spaces(ret[1])} for ret in returns]
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def get_class_method_docstring(cls, method_name):
|
||||
"""Retrieve a method's docstring, searching the class hierarchy if necessary."""
|
||||
for base_class in cls.__mro__:
|
||||
if method_name in base_class.__dict__:
|
||||
method = base_class.__dict__[method_name]
|
||||
if method.__doc__:
|
||||
return method.__doc__
|
||||
return None # No docstring found in the class hierarchy
|
||||
18
metagpt/tools/tool_data_type.py
Normal file
18
metagpt/tools/tool_data_type.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ToolTypeDef(BaseModel):
|
||||
name: str
|
||||
desc: str = ""
|
||||
usage_prompt: str = ""
|
||||
|
||||
|
||||
class ToolSchema(BaseModel):
|
||||
description: str
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
name: str
|
||||
path: str
|
||||
schemas: dict = {}
|
||||
code: str = ""
|
||||
155
metagpt/tools/tool_registry.py
Normal file
155
metagpt/tools/tool_registry.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/01/12 17:07
|
||||
@Author : garylin2099
|
||||
@File : tool_registry.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from metagpt.const import TOOL_SCHEMA_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema
|
||||
from metagpt.tools.tool_data_type import Tool, ToolSchema, ToolTypeDef
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
|
||||
class ToolRegistry(BaseModel):
|
||||
tools: dict = {}
|
||||
tool_types: dict = {}
|
||||
tools_by_types: dict = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...}
|
||||
|
||||
@field_validator("tool_types", mode="before")
|
||||
@classmethod
|
||||
def init_tool_types(cls, tool_types: ToolType):
|
||||
return {tool_type.type_name: tool_type.value for tool_type in tool_types}
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
tool_name,
|
||||
tool_path,
|
||||
schema_path="",
|
||||
tool_code="",
|
||||
tool_type="other",
|
||||
tool_source_object=None,
|
||||
include_functions=[],
|
||||
verbose=False,
|
||||
):
|
||||
if self.has_tool(tool_name):
|
||||
return
|
||||
|
||||
if tool_type not in self.tool_types:
|
||||
# register new tool type on the fly
|
||||
logger.warning(
|
||||
f"{tool_type} not previously defined, will create a temporary tool type with just a name. This tool type is only effective during this runtime. You may consider add this tool type with more configs permanently at metagpt.tools.tool_type"
|
||||
)
|
||||
temp_tool_type_obj = ToolTypeDef(name=tool_type)
|
||||
self.tool_types[tool_type] = temp_tool_type_obj
|
||||
if verbose:
|
||||
logger.info(f"tool type {tool_type} registered")
|
||||
|
||||
schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml"
|
||||
|
||||
schemas = make_schema(tool_source_object, include_functions, schema_path)
|
||||
|
||||
if not schemas:
|
||||
return
|
||||
|
||||
schemas["tool_path"] = tool_path # corresponding code file path of the tool
|
||||
try:
|
||||
ToolSchema(**schemas) # validation
|
||||
except Exception:
|
||||
pass
|
||||
# logger.warning(
|
||||
# f"{tool_name} schema not conforms to required format, but will be used anyway. Mismatch: {e}"
|
||||
# )
|
||||
|
||||
tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code)
|
||||
self.tools[tool_name] = tool
|
||||
self.tools_by_types[tool_type][tool_name] = tool
|
||||
if verbose:
|
||||
logger.info(f"{tool_name} registered")
|
||||
logger.info(f"schema made at {str(schema_path)}, can be used for checking")
|
||||
|
||||
def has_tool(self, key: str) -> Tool:
|
||||
return key in self.tools
|
||||
|
||||
def get_tool(self, key) -> Tool:
|
||||
return self.tools.get(key)
|
||||
|
||||
def get_tools_by_type(self, key) -> dict[str, Tool]:
|
||||
return self.tools_by_types.get(key, {})
|
||||
|
||||
def has_tool_type(self, key) -> bool:
|
||||
return key in self.tool_types
|
||||
|
||||
def get_tool_type(self, key) -> ToolType:
|
||||
return self.tool_types.get(key)
|
||||
|
||||
def get_tool_types(self) -> dict[str, ToolType]:
|
||||
return self.tool_types
|
||||
|
||||
|
||||
# Registry instance
|
||||
TOOL_REGISTRY = ToolRegistry(tool_types=ToolType)
|
||||
|
||||
|
||||
def register_tool(tool_type: str = "other", schema_path: str = "", **kwargs):
|
||||
"""register a tool to registry"""
|
||||
|
||||
def decorator(cls):
|
||||
# Get the file path where the function / class is defined and the source code
|
||||
file_path = inspect.getfile(cls)
|
||||
if "metagpt" in file_path:
|
||||
file_path = re.search("metagpt.+", file_path).group(0)
|
||||
source_code = inspect.getsource(cls)
|
||||
|
||||
TOOL_REGISTRY.register_tool(
|
||||
tool_name=cls.__name__,
|
||||
tool_path=file_path,
|
||||
schema_path=schema_path,
|
||||
tool_code=source_code,
|
||||
tool_type=tool_type,
|
||||
tool_source_object=cls,
|
||||
**kwargs,
|
||||
)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def make_schema(tool_source_object, include, path):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True) # Create the necessary directories
|
||||
try:
|
||||
schema = convert_code_to_tool_schema(tool_source_object, include=include)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(schema, f, sort_keys=False)
|
||||
# import json
|
||||
# with open(str(path).replace("yml", "json"), "w", encoding="utf-8") as f:
|
||||
# json.dump(schema, f, ensure_ascii=False, indent=4)
|
||||
except Exception as e:
|
||||
schema = {}
|
||||
logger.error(f"Fail to make schema: {e}")
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def validate_tool_names(tools: list[str], return_tool_object=False) -> list[str]:
|
||||
valid_tools = []
|
||||
for tool_name in tools:
|
||||
if not TOOL_REGISTRY.has_tool(tool_name):
|
||||
logger.warning(
|
||||
f"Specified tool {tool_name} not found and was skipped. Check if you have registered it properly"
|
||||
)
|
||||
else:
|
||||
valid_tool = TOOL_REGISTRY.get_tool(tool_name) if return_tool_object else tool_name
|
||||
valid_tools.append(valid_tool)
|
||||
return valid_tools
|
||||
55
metagpt/tools/tool_type.py
Normal file
55
metagpt/tools/tool_type.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
from enum import Enum
|
||||
|
||||
from metagpt.prompts.tool_types import (
|
||||
DATA_PREPROCESS_PROMPT,
|
||||
FEATURE_ENGINEERING_PROMPT,
|
||||
IMAGE2WEBPAGE_PROMPT,
|
||||
MODEL_EVALUATE_PROMPT,
|
||||
MODEL_TRAIN_PROMPT,
|
||||
)
|
||||
from metagpt.tools.tool_data_type import ToolTypeDef
|
||||
|
||||
|
||||
class ToolType(Enum):
|
||||
EDA = ToolTypeDef(name="eda", desc="For performing exploratory data analysis")
|
||||
DATA_PREPROCESS = ToolTypeDef(
|
||||
name="data_preprocess",
|
||||
desc="Only for changing value inplace.",
|
||||
usage_prompt=DATA_PREPROCESS_PROMPT,
|
||||
)
|
||||
FEATURE_ENGINEERING = ToolTypeDef(
|
||||
name="feature_engineering",
|
||||
desc="Only for creating new columns for input data.",
|
||||
usage_prompt=FEATURE_ENGINEERING_PROMPT,
|
||||
)
|
||||
MODEL_TRAIN = ToolTypeDef(
|
||||
name="model_train",
|
||||
desc="Only for training model.",
|
||||
usage_prompt=MODEL_TRAIN_PROMPT,
|
||||
)
|
||||
MODEL_EVALUATE = ToolTypeDef(
|
||||
name="model_evaluate",
|
||||
desc="Only for evaluating model.",
|
||||
usage_prompt=MODEL_EVALUATE_PROMPT,
|
||||
)
|
||||
STABLE_DIFFUSION = ToolTypeDef(
|
||||
name="stable_diffusion",
|
||||
desc="Related to text2image, image2image using stable diffusion model.",
|
||||
)
|
||||
IMAGE2WEBPAGE = ToolTypeDef(
|
||||
name="image2webpage",
|
||||
desc="For converting image into webpage code.",
|
||||
usage_prompt=IMAGE2WEBPAGE_PROMPT,
|
||||
)
|
||||
WEBSCRAPING = ToolTypeDef(
|
||||
name="web_scraping",
|
||||
desc="For scraping data from web pages.",
|
||||
)
|
||||
OTHER = ToolTypeDef(name="other", desc="Any tools not in the defined categories")
|
||||
|
||||
def __missing__(self, key):
|
||||
return self.OTHER
|
||||
|
||||
@property
|
||||
def type_name(self):
|
||||
return self.value.name
|
||||
|
|
@ -361,6 +361,31 @@ def parse_recipient(text):
|
|||
return ""
|
||||
|
||||
|
||||
def create_func_call_config(func_schema: dict) -> dict:
|
||||
"""Create new function call config"""
|
||||
tools = [{"type": "function", "function": func_schema}]
|
||||
tool_choice = {"type": "function", "function": {"name": func_schema["name"]}}
|
||||
return {
|
||||
"tools": tools,
|
||||
"tool_choice": tool_choice,
|
||||
}
|
||||
|
||||
|
||||
def remove_comments(code_str: str) -> str:
|
||||
"""Remove comments from code."""
|
||||
pattern = r"(\".*?\"|\'.*?\')|(\#.*?$)"
|
||||
|
||||
def replace_func(match):
|
||||
if match.group(2) is not None:
|
||||
return ""
|
||||
else:
|
||||
return match.group(1)
|
||||
|
||||
clean_code = re.sub(pattern, replace_func, code_str, flags=re.MULTILINE)
|
||||
clean_code = os.linesep.join([s.rstrip() for s in clean_code.splitlines() if s.strip()])
|
||||
return clean_code
|
||||
|
||||
|
||||
def get_class_name(cls) -> str:
|
||||
"""Return class name"""
|
||||
return f"{cls.__module__}.{cls.__name__}"
|
||||
|
|
@ -469,13 +494,13 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
|
|||
return data
|
||||
|
||||
|
||||
def write_json_file(json_file: str, data: list, encoding=None):
|
||||
def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4):
|
||||
folder_path = Path(json_file).parent
|
||||
if not folder_path.exists():
|
||||
folder_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(json_file, "w", encoding=encoding) as fout:
|
||||
json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)
|
||||
json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python)
|
||||
|
||||
|
||||
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
|
||||
|
|
|
|||
87
metagpt/utils/parse_docstring.py
Normal file
87
metagpt/utils/parse_docstring.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
import re
|
||||
from typing import Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def remove_spaces(text):
|
||||
return re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
|
||||
class DocstringParser(BaseModel):
|
||||
docstring: str
|
||||
|
||||
def parse_desc(self) -> str:
|
||||
"""Parse and return the description from the docstring."""
|
||||
|
||||
def parse_params(self) -> list[Tuple[str, str, str]]:
|
||||
"""Parse and return the parameters from the docstring.
|
||||
|
||||
Returns:
|
||||
list[Tuple[str, str, str]]: A list of input paramter info. Each info is a triple of (param name, param type, param description)
|
||||
"""
|
||||
|
||||
def parse_returns(self) -> list[Tuple[str, str]]:
|
||||
"""Parse and return the output information from the docstring.
|
||||
|
||||
Returns:
|
||||
list[Tuple[str, str]]: A list of output info. Each info is a tuple of (return type, return description)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def check_and_parse_optional(param_type: str) -> Tuple[bool, str]:
|
||||
"""Check if a parameter is optional and return a processed param_type rid of the optionality info if so"""
|
||||
|
||||
@staticmethod
|
||||
def check_and_parse_default_value(param_desc: str) -> Tuple[bool, str]:
|
||||
"""Check if a parameter has a default value and return the default value if so"""
|
||||
|
||||
@staticmethod
|
||||
def check_and_parse_enum(param_desc: str) -> Tuple[bool, str]:
|
||||
"""Check if a parameter description includes an enum and return enum values if so"""
|
||||
|
||||
|
||||
class reSTDocstringParser(DocstringParser):
|
||||
"""A parser for reStructuredText (reST) docstring"""
|
||||
|
||||
|
||||
class GoogleDocstringParser(DocstringParser):
|
||||
"""A parser for Google-stype docstring"""
|
||||
|
||||
docstring: str
|
||||
|
||||
def parse_desc(self) -> str:
|
||||
description_match = re.search(r"^(.*?)(?:Args:|Returns:|Raises:|$)", self.docstring, re.DOTALL)
|
||||
description = remove_spaces(description_match.group(1)) if description_match else ""
|
||||
return description
|
||||
|
||||
def parse_params(self) -> list[Tuple[str, str, str]]:
|
||||
args_match = re.search(r"Args:\s*(.*?)(?:Returns:|Raises:|$)", self.docstring, re.DOTALL)
|
||||
_args = args_match.group(1).strip() if args_match else ""
|
||||
# variable_pattern = re.compile(r"(\w+)\s*\((.*?)\):\s*(.*)")
|
||||
variable_pattern = re.compile(
|
||||
r"(\w+)\s*\((.*?)\):\s*(.*?)(?=\n\s*\w+\s*\(|\Z)", re.DOTALL
|
||||
) # (?=\n\w+\s*\(|\Z) is to assert that what follows is either the start of the next parameter (indicated by a newline, some word characters, and an opening parenthesis) or the end of the string (\Z).
|
||||
params = variable_pattern.findall(_args)
|
||||
return params
|
||||
|
||||
def parse_returns(self) -> list[Tuple[str, str]]:
|
||||
returns_match = re.search(r"Returns:\s*(.*?)(?:Raises:|$)", self.docstring, re.DOTALL)
|
||||
returns = returns_match.group(1).strip() if returns_match else ""
|
||||
return_pattern = re.compile(r"^(.*)\s*:\s*(.*)$")
|
||||
returns = return_pattern.findall(returns)
|
||||
return returns
|
||||
|
||||
@staticmethod
|
||||
def check_and_parse_optional(param_type: str) -> Tuple[bool, str]:
|
||||
return "optional" in param_type, param_type.replace(", optional", "")
|
||||
|
||||
@staticmethod
|
||||
def check_and_parse_default_value(param_desc: str) -> Tuple[bool, str]:
|
||||
default_val = re.search(r"Defaults to (.+?)\.", param_desc)
|
||||
return (True, default_val.group(1)) if default_val else (False, "")
|
||||
|
||||
@staticmethod
|
||||
def check_and_parse_enum(param_desc: str) -> Tuple[bool, str]:
|
||||
enum_val = re.search(r"Enum: \[(.+?)\]", param_desc)
|
||||
return (True, [e.strip() for e in enum_val.group(1).split(",")]) if enum_val else (False, [])
|
||||
58
metagpt/utils/recovery_util.py
Normal file
58
metagpt/utils/recovery_util.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/20/2023 11:07 AM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import nbformat
|
||||
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.utils.common import read_json_file
|
||||
from metagpt.utils.save_code import save_code_file
|
||||
|
||||
|
||||
def load_history(save_dir: str = ""):
|
||||
"""
|
||||
Load plan and code execution history from the specified save directory.
|
||||
|
||||
Args:
|
||||
save_dir (str): The directory from which to load the history.
|
||||
|
||||
Returns:
|
||||
Tuple: A tuple containing the loaded plan and notebook.
|
||||
"""
|
||||
|
||||
plan_path = Path(save_dir) / "plan.json"
|
||||
nb_path = Path(save_dir) / "history_nb" / "code.ipynb"
|
||||
plan = read_json_file(plan_path)
|
||||
nb = nbformat.read(open(nb_path, "r", encoding="utf-8"), as_version=nbformat.NO_CONVERT)
|
||||
return plan, nb
|
||||
|
||||
|
||||
def save_history(role: Role, save_dir: str = ""):
|
||||
"""
|
||||
Save plan and code execution history to the specified directory.
|
||||
|
||||
Args:
|
||||
role (Role): The role containing the plan and execute_code attributes.
|
||||
save_dir (str): The directory to save the history.
|
||||
|
||||
Returns:
|
||||
Path: The path to the saved history directory.
|
||||
"""
|
||||
record_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
save_path = DATA_PATH / "output" / f"{record_time}"
|
||||
|
||||
# overwrite exist trajectory
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
plan = role.planner.plan.dict()
|
||||
|
||||
with open(save_path / "plan.json", "w", encoding="utf-8") as plan_file:
|
||||
json.dump(plan, plan_file, indent=4, ensure_ascii=False)
|
||||
|
||||
save_code_file(name=Path(record_time) / "history_nb", code_context=role.execute_code.nb, file_format="ipynb")
|
||||
return save_path
|
||||
40
metagpt/utils/save_code.py
Normal file
40
metagpt/utils/save_code.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/12/2023 4:14 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import os
|
||||
|
||||
import nbformat
|
||||
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.utils.common import write_json_file
|
||||
|
||||
|
||||
def save_code_file(name: str, code_context: str, file_format: str = "py") -> None:
|
||||
"""
|
||||
Save code files to a specified path.
|
||||
|
||||
Args:
|
||||
- name (str): The name of the folder to save the files.
|
||||
- code_context (str): The code content.
|
||||
- file_format (str, optional): The file format. Supports 'py' (Python file), 'json' (JSON file), and 'ipynb' (Jupyter Notebook file). Default is 'py'.
|
||||
|
||||
|
||||
Returns:
|
||||
- None
|
||||
"""
|
||||
# Create the folder path if it doesn't exist
|
||||
os.makedirs(name=DATA_PATH / "output" / f"{name}", exist_ok=True)
|
||||
|
||||
# Choose to save as a Python file or a JSON file based on the file format
|
||||
file_path = DATA_PATH / "output" / f"{name}/code.{file_format}"
|
||||
if file_format == "py":
|
||||
file_path.write_text(code_context + "\n\n", encoding="utf-8")
|
||||
elif file_format == "json":
|
||||
# Parse the code content as JSON and save
|
||||
data = {"code": code_context}
|
||||
write_json_file(file_path, data, encoding="utf-8", indent=2)
|
||||
elif file_format == "ipynb":
|
||||
nbformat.write(code_context, file_path)
|
||||
else:
|
||||
raise ValueError("Unsupported file format. Please choose 'py', 'json', or 'ipynb'.")
|
||||
|
|
@ -35,7 +35,6 @@ tqdm==4.65.0
|
|||
# webdriver_manager<3.9
|
||||
anthropic==0.8.1
|
||||
typing-inspect==0.8.0
|
||||
typing_extensions==4.9.0
|
||||
libcst==1.0.1
|
||||
qdrant-client==1.7.0
|
||||
# pytest-mock==3.11.1 # test extras require
|
||||
|
|
@ -51,6 +50,13 @@ websocket-client==1.6.2
|
|||
aiofiles==23.2.1
|
||||
gitpython==3.1.40
|
||||
zhipuai==2.0.1
|
||||
rich==13.6.0
|
||||
nbclient==0.9.0
|
||||
nbformat==5.9.2
|
||||
ipython==8.17.2
|
||||
ipykernel==6.27.0
|
||||
scikit_learn==1.3.2
|
||||
typing-extensions==4.9.0
|
||||
socksio~=1.0.0
|
||||
gitignore-parser==0.1.9
|
||||
# connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py
|
||||
|
|
@ -59,4 +65,5 @@ networkx~=3.2.1
|
|||
google-generativeai==0.3.2
|
||||
# playwright==1.40.0 # playwright extras require
|
||||
anytree
|
||||
Pillow
|
||||
ipywidgets==8.1.1
|
||||
Pillow
|
||||
|
|
|
|||
|
|
@ -38,14 +38,14 @@ def rsp_cache():
|
|||
rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache.json" # read repo-provided
|
||||
new_rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache_new.json" # exporting a new copy
|
||||
if os.path.exists(rsp_cache_file_path):
|
||||
with open(rsp_cache_file_path, "r") as f1:
|
||||
with open(rsp_cache_file_path, "r", encoding="utf-8") as f1:
|
||||
rsp_cache_json = json.load(f1)
|
||||
else:
|
||||
rsp_cache_json = {}
|
||||
yield rsp_cache_json
|
||||
with open(rsp_cache_file_path, "w") as f2:
|
||||
with open(rsp_cache_file_path, "w", encoding="utf-8") as f2:
|
||||
json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
|
||||
with open(new_rsp_cache_file_path, "w") as f2:
|
||||
with open(new_rsp_cache_file_path, "w", encoding="utf-8") as f2:
|
||||
json.dump(RSP_CACHE_NEW, f2, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
|
|
@ -64,6 +64,7 @@ def llm_mock(rsp_cache, mocker, request):
|
|||
llm.rsp_cache = rsp_cache
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", llm.aask)
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask_batch", llm.aask_batch)
|
||||
mocker.patch("metagpt.provider.openai_api.OpenAILLM.aask_code", llm.aask_code)
|
||||
yield mocker
|
||||
if hasattr(request.node, "test_outcome") and request.node.test_outcome.passed:
|
||||
if llm.rsp_candidates:
|
||||
|
|
@ -71,7 +72,7 @@ def llm_mock(rsp_cache, mocker, request):
|
|||
cand_key = list(rsp_candidate.keys())[0]
|
||||
cand_value = list(rsp_candidate.values())[0]
|
||||
if cand_key not in llm.rsp_cache:
|
||||
logger.info(f"Added '{cand_key[:100]} ... -> {cand_value[:20]} ...' to response cache")
|
||||
logger.info(f"Added '{cand_key[:100]} ... -> {str(cand_value)[:20]} ...' to response cache")
|
||||
llm.rsp_cache.update(rsp_candidate)
|
||||
RSP_CACHE_NEW.update(rsp_candidate)
|
||||
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
12
tests/metagpt/actions/ci/test_ask_review.py
Normal file
12
tests/metagpt/actions/ci/test_ask_review.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.ask_review import AskReview
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_review(mocker):
|
||||
mock_review_input = "confirm"
|
||||
mocker.patch("builtins.input", return_value=mock_review_input)
|
||||
rsp, confirmed = await AskReview().run()
|
||||
assert rsp == mock_review_input
|
||||
assert confirmed
|
||||
51
tests/metagpt/actions/ci/test_debug_code.py
Normal file
51
tests/metagpt/actions/ci/test_debug_code.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 1/11/2024 8:51 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.ci.debug_code import DebugCode
|
||||
from metagpt.schema import Message
|
||||
|
||||
ErrorStr = """Tested passed:
|
||||
|
||||
Tests failed:
|
||||
assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] # output: [1, 2, 4, 3, 5]
|
||||
"""
|
||||
|
||||
CODE = """
|
||||
def sort_array(arr):
|
||||
# Helper function to count the number of ones in the binary representation
|
||||
def count_ones(n):
|
||||
return bin(n).count('1')
|
||||
|
||||
# Sort the array using a custom key function
|
||||
# The key function returns a tuple (number of ones, value) for each element
|
||||
# This ensures that if two elements have the same number of ones, they are sorted by their value
|
||||
sorted_arr = sorted(arr, key=lambda x: (count_ones(x), x))
|
||||
|
||||
return sorted_arr
|
||||
```
|
||||
"""
|
||||
|
||||
DebugContext = '''Solve the problem in Python:
|
||||
def sort_array(arr):
|
||||
"""
|
||||
In this Kata, you have to sort an array of non-negative integers according to
|
||||
number of ones in their binary representation in ascending order.
|
||||
For similar number of ones, sort based on decimal value.
|
||||
|
||||
It must be implemented like this:
|
||||
>>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5]
|
||||
>>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]
|
||||
>>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4]
|
||||
"""
|
||||
'''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debug_code():
|
||||
debug_context = Message(content=DebugContext)
|
||||
new_code = await DebugCode().run(context=debug_context, code=CODE, runtime_result=ErrorStr)
|
||||
assert "def sort_array(arr)" in new_code["code"]
|
||||
116
tests/metagpt/actions/ci/test_execute_nb_code.py
Normal file
116
tests/metagpt/actions/ci/test_execute_nb_code.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode, truncate
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_running():
|
||||
executor = ExecuteNbCode()
|
||||
output, is_success = await executor.run("print('hello world!')")
|
||||
assert is_success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_split_code_running():
|
||||
executor = ExecuteNbCode()
|
||||
_ = await executor.run("x=1\ny=2")
|
||||
_ = await executor.run("z=x+y")
|
||||
output, is_success = await executor.run("assert z==3")
|
||||
assert is_success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_error():
|
||||
executor = ExecuteNbCode()
|
||||
output, is_success = await executor.run("z=1/0")
|
||||
assert not is_success
|
||||
|
||||
|
||||
PLOT_CODE = """
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 生成随机数据
|
||||
random_data = np.random.randn(1000) # 生成1000个符合标准正态分布的随机数
|
||||
|
||||
# 绘制直方图
|
||||
plt.hist(random_data, bins=30, density=True, alpha=0.7, color='blue', edgecolor='black')
|
||||
|
||||
# 添加标题和标签
|
||||
plt.title('Histogram of Random Data')
|
||||
plt.xlabel('Value')
|
||||
plt.ylabel('Frequency')
|
||||
|
||||
# 显示图形
|
||||
plt.show()
|
||||
plt.close()
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plotting_code():
|
||||
executor = ExecuteNbCode()
|
||||
output, is_success = await executor.run(PLOT_CODE)
|
||||
assert is_success
|
||||
|
||||
|
||||
def test_truncate():
|
||||
# 代码执行成功
|
||||
output, is_success = truncate("hello world", 5, True)
|
||||
assert "Truncated to show only first 5 characters\nhello" in output
|
||||
assert is_success
|
||||
# 代码执行失败
|
||||
output, is_success = truncate("hello world", 5, False)
|
||||
assert "Truncated to show only last 5 characters\nworld" in output
|
||||
assert not is_success
|
||||
# 异步
|
||||
output, is_success = truncate("<coroutine object", 5, True)
|
||||
assert not is_success
|
||||
assert "await" in output
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_timeout():
|
||||
executor = ExecuteNbCode(timeout=1)
|
||||
code = "import time; time.sleep(2)"
|
||||
message, success = await executor.run(code)
|
||||
assert not success
|
||||
assert message.startswith("Cell execution timed out")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_code_text():
|
||||
executor = ExecuteNbCode()
|
||||
message, success = await executor.run(code='print("This is a code!")', language="python")
|
||||
assert success
|
||||
assert message == "This is a code!\n"
|
||||
message, success = await executor.run(code="# This is a code!", language="markdown")
|
||||
assert success
|
||||
assert message == "# This is a code!"
|
||||
mix_text = "# Title!\n ```python\n print('This is a code!')```"
|
||||
message, success = await executor.run(code=mix_text, language="markdown")
|
||||
assert success
|
||||
assert message == mix_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_terminate():
|
||||
executor = ExecuteNbCode()
|
||||
await executor.run(code='print("This is a code!")', language="python")
|
||||
is_kernel_alive = await executor.nb_client.km.is_alive()
|
||||
assert is_kernel_alive
|
||||
await executor.terminate()
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
assert executor.nb_client.km is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset():
|
||||
executor = ExecuteNbCode()
|
||||
await executor.run(code='print("This is a code!")', language="python")
|
||||
is_kernel_alive = await executor.nb_client.km.is_alive()
|
||||
assert is_kernel_alive
|
||||
await executor.reset()
|
||||
assert executor.nb_client.km is None
|
||||
46
tests/metagpt/actions/ci/test_ml_action.py
Normal file
46
tests/metagpt/actions/ci/test_ml_action.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.ml_action import WriteCodeWithToolsML
|
||||
from metagpt.schema import Plan, Task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_with_tools():
|
||||
write_code_ml = WriteCodeWithToolsML()
|
||||
|
||||
task_map = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="随机生成一个pandas DataFrame数据集",
|
||||
task_type="other",
|
||||
dependent_task_ids=[],
|
||||
code="""
|
||||
import pandas as pd
|
||||
df = pd.DataFrame({
|
||||
'a': [1, 2, 3, 4, 5],
|
||||
'b': [1.1, 2.2, 3.3, 4.4, np.nan],
|
||||
'c': ['aa', 'bb', 'cc', 'dd', 'ee'],
|
||||
'd': [1, 2, 3, 4, 5]
|
||||
})
|
||||
""",
|
||||
is_finished=True,
|
||||
),
|
||||
"2": Task(
|
||||
task_id="2",
|
||||
instruction="对数据集进行数据清洗",
|
||||
task_type="data_preprocess",
|
||||
dependent_task_ids=["1"],
|
||||
),
|
||||
}
|
||||
plan = Plan(
|
||||
goal="构造数据集并进行数据清洗",
|
||||
tasks=list(task_map.values()),
|
||||
task_map=task_map,
|
||||
current_task_id="2",
|
||||
)
|
||||
column_info = ""
|
||||
|
||||
_, code_with_ml = await write_code_ml.run([], plan, column_info)
|
||||
code_with_ml = code_with_ml["code"]
|
||||
assert len(code_with_ml) > 0
|
||||
print(code_with_ml)
|
||||
324
tests/metagpt/actions/ci/test_write_analysis_code.py
Normal file
324
tests/metagpt/actions/ci/test_write_analysis_code.py
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.ci.write_analysis_code import (
|
||||
WriteCodeWithoutTools,
|
||||
WriteCodeWithTools,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message, Plan, Task
|
||||
from metagpt.strategy.planner import STRUCTURAL_CONTEXT
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_by_list_plan():
|
||||
write_code = WriteCodeWithoutTools()
|
||||
execute_code = ExecuteNbCode()
|
||||
messages = []
|
||||
plan = ["随机生成一个pandas DataFrame时间序列", "绘制这个时间序列的直方图", "回顾已完成的任务", "求均值", "总结"]
|
||||
for task in plan:
|
||||
print(f"\n任务: {task}\n\n")
|
||||
messages.append(Message(task, role="assistant"))
|
||||
code = await write_code.run(messages)
|
||||
if task.startswith(("回顾", "总结")):
|
||||
assert code["language"] == "markdown"
|
||||
else:
|
||||
assert code["language"] == "python"
|
||||
messages.append(Message(code["code"], role="assistant"))
|
||||
assert len(code) > 0
|
||||
output, _ = await execute_code.run(**code)
|
||||
print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n")
|
||||
messages.append(output)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_recommendation():
|
||||
task = "clean and preprocess the data"
|
||||
available_tools = {
|
||||
"FillMissingValue": "Filling missing values",
|
||||
"SplitBins": "Bin continuous data into intervals and return the bin identifier encoded as an integer value",
|
||||
}
|
||||
write_code = WriteCodeWithTools()
|
||||
tools = await write_code._recommend_tool(task, available_tools)
|
||||
|
||||
assert len(tools) == 1
|
||||
assert "FillMissingValue" in tools
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_with_tools():
|
||||
write_code = WriteCodeWithTools()
|
||||
|
||||
requirement = "构造数据集并进行数据清洗"
|
||||
task_map = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="随机生成一个pandas DataFrame数据集",
|
||||
task_type="other",
|
||||
dependent_task_ids=[],
|
||||
code="""
|
||||
import pandas as pd
|
||||
df = pd.DataFrame({
|
||||
'a': [1, 2, 3, 4, 5],
|
||||
'b': [1.1, 2.2, 3.3, 4.4, np.nan],
|
||||
'c': ['aa', 'bb', 'cc', 'dd', 'ee'],
|
||||
'd': [1, 2, 3, 4, 5]
|
||||
})
|
||||
""",
|
||||
is_finished=True,
|
||||
),
|
||||
"2": Task(
|
||||
task_id="2",
|
||||
instruction="对数据集进行数据清洗",
|
||||
task_type="data_preprocess",
|
||||
dependent_task_ids=["1"],
|
||||
),
|
||||
}
|
||||
plan = Plan(
|
||||
goal="构造数据集并进行数据清洗",
|
||||
tasks=list(task_map.values()),
|
||||
task_map=task_map,
|
||||
current_task_id="2",
|
||||
)
|
||||
|
||||
context = STRUCTURAL_CONTEXT.format(
|
||||
user_requirement=requirement,
|
||||
context=plan.context,
|
||||
tasks=list(task_map.values()),
|
||||
current_task=plan.current_task.model_dump_json(),
|
||||
)
|
||||
context_msg = [Message(content=context, role="user")]
|
||||
|
||||
code = await write_code.run(context_msg, plan)
|
||||
code = code["code"]
|
||||
assert len(code) > 0
|
||||
print(code)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_to_correct_error():
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
read a dataset test.csv and print its head
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "import pandas and load the dataset from 'test.csv'.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": [
|
||||
"1"
|
||||
],
|
||||
"instruction": "Print the head of the dataset to display the first few rows.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "1", "dependent_task_ids": [], "instruction": "import pandas and load the dataset from 'test.csv'.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
wrong_code = """import pandas as pd\ndata = pd.read_excel('test.csv')\ndata""" # use read_excel to read a csv
|
||||
error = """
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 2, in <module>
|
||||
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 478, in read_excel
|
||||
io = ExcelFile(io, storage_options=storage_options, engine=engine)
|
||||
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 1500, in __init__
|
||||
raise ValueError(
|
||||
ValueError: Excel file format cannot be determined, you must specify an engine manually.
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
Message(content=wrong_code, role="assistant"),
|
||||
Message(content=error, role="user"),
|
||||
]
|
||||
new_code = await WriteCodeWithoutTools().run(context=context)
|
||||
new_code = new_code["code"]
|
||||
print(new_code)
|
||||
assert "read_csv" in new_code # should correct read_excel to read_csv
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_simple():
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
read a dataset test.csv and print its head
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "import pandas and load the dataset from 'test.csv'.",
|
||||
"task_type": "",
|
||||
"code": "import pandas as pd\ndata = pd.read_csv('test.csv')",
|
||||
"result": "",
|
||||
"is_finished": true
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": [
|
||||
"1"
|
||||
],
|
||||
"instruction": "Print the head of the dataset to display the first few rows.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Print the head of the dataset to display the first few rows.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
]
|
||||
code = await WriteCodeWithoutTools().run(context=context)
|
||||
code = code["code"]
|
||||
print(code)
|
||||
assert "pandas" not in code and "read_csv" not in code # should reuse import and read statement from previous one
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_long():
|
||||
"""test code reuse for long context"""
|
||||
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
Run data analysis on sklearn Iris dataset, include a plot
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "Load the Iris dataset from sklearn.",
|
||||
"task_type": "",
|
||||
"code": "from sklearn.datasets import load_iris\niris_data = load_iris()\niris_data['data'][0:5], iris_data['target'][0:5]",
|
||||
"result": "(array([[5.1, 3.5, 1.4, 0.2],\n [4.9, 3. , 1.4, 0.2],\n [4.7, 3.2, 1.3, 0.2],\n [4.6, 3.1, 1.5, 0.2],\n [5. , 3.6, 1.4, 0.2]]),\n array([0, 0, 0, 0, 0]))",
|
||||
"is_finished": true
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": [
|
||||
"1"
|
||||
],
|
||||
"instruction": "Perform exploratory data analysis on the Iris dataset.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "3",
|
||||
"dependent_task_ids": [
|
||||
"2"
|
||||
],
|
||||
"instruction": "Create a plot visualizing the Iris dataset features.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Perform exploratory data analysis on the Iris dataset.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
]
|
||||
trials_num = 5
|
||||
trials = [WriteCodeWithoutTools().run(context=context, temperature=0.0) for _ in range(trials_num)]
|
||||
trial_results = await asyncio.gather(*trials)
|
||||
print(*trial_results, sep="\n\n***\n\n")
|
||||
success = [
|
||||
"load_iris" not in result["code"] and "iris_data" in result["code"] for result in trial_results
|
||||
] # should reuse iris_data from previous tasks
|
||||
success_rate = sum(success) / trials_num
|
||||
logger.info(f"success rate: {success_rate :.2f}")
|
||||
assert success_rate >= 0.8
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_long_for_wine():
|
||||
"""test code reuse for long context"""
|
||||
|
||||
structural_context = """
|
||||
## User Requirement
|
||||
Run data analysis on sklearn Wisconsin Breast Cancer dataset, include a plot, train a model to predict targets (20% as validation), and show validation accuracy
|
||||
## Current Plan
|
||||
[
|
||||
{
|
||||
"task_id": "1",
|
||||
"dependent_task_ids": [],
|
||||
"instruction": "Load the sklearn Wine recognition dataset and perform exploratory data analysis."
|
||||
"task_type": "",
|
||||
"code": "from sklearn.datasets import load_wine\n# Load the Wine recognition dataset\nwine_data = load_wine()\n# Perform exploratory data analysis\nwine_data.keys()",
|
||||
"result": "Truncated to show only the last 1000 characters\ndict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names'])",
|
||||
"is_finished": true
|
||||
},
|
||||
{
|
||||
"task_id": "2",
|
||||
"dependent_task_ids": ["1"],
|
||||
"instruction": "Create a plot to visualize some aspect of the wine dataset."
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "3",
|
||||
"dependent_task_ids": ["1"],
|
||||
"instruction": "Split the dataset into training and validation sets with a 20% validation size.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "4",
|
||||
"dependent_task_ids": ["3"],
|
||||
"instruction": "Train a model on the training set to predict wine class.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
},
|
||||
{
|
||||
"task_id": "5",
|
||||
"dependent_task_ids": ["4"],
|
||||
"instruction": "Evaluate the model on the validation set and report the accuracy.",
|
||||
"task_type": "",
|
||||
"code": "",
|
||||
"result": "",
|
||||
"is_finished": false
|
||||
}
|
||||
]
|
||||
## Current Task
|
||||
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Create a plot to visualize some aspect of the Wine dataset.", "task_type": "", "code": "", "result": "", "is_finished": false}
|
||||
"""
|
||||
context = [
|
||||
Message(content=structural_context, role="user"),
|
||||
]
|
||||
trials_num = 5
|
||||
trials = [WriteCodeWithoutTools().run(context=context, temperature=0.0) for _ in range(trials_num)]
|
||||
trial_results = await asyncio.gather(*trials)
|
||||
print(*trial_results, sep="\n\n***\n\n")
|
||||
success = [
|
||||
"load_wine" not in result["code"] and "wine_data" in result["code"] for result in trial_results
|
||||
] # should reuse iris_data from previous tasks
|
||||
success_rate = sum(success) / trials_num
|
||||
logger.info(f"success rate: {success_rate :.2f}")
|
||||
assert success_rate >= 0.8
|
||||
34
tests/metagpt/actions/ci/test_write_plan.py
Normal file
34
tests/metagpt/actions/ci/test_write_plan.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.write_plan import (
|
||||
Plan,
|
||||
Task,
|
||||
WritePlan,
|
||||
precheck_update_plan_from_rsp,
|
||||
)
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_precheck_update_plan_from_rsp():
|
||||
plan = Plan(goal="")
|
||||
plan.add_tasks([Task(task_id="1")])
|
||||
rsp = '[{"task_id": "2"}]'
|
||||
success, _ = precheck_update_plan_from_rsp(rsp, plan)
|
||||
assert success
|
||||
assert len(plan.tasks) == 1 and plan.tasks[0].task_id == "1" # precheck should not change the original one
|
||||
|
||||
invalid_rsp = "wrong"
|
||||
success, _ = precheck_update_plan_from_rsp(invalid_rsp, plan)
|
||||
assert not success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("use_tools", [(False), (True)])
|
||||
async def test_write_plan(use_tools):
|
||||
rsp = await WritePlan().run(
|
||||
context=[Message("run analysis on sklearn iris dataset", role="user")], use_tools=use_tools
|
||||
)
|
||||
|
||||
assert "task_id" in rsp
|
||||
assert "instruction" in rsp
|
||||
assert "json" not in rsp # the output should be the content inside ```json ```
|
||||
|
|
@ -1,49 +1,25 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
from PIL import Image
|
||||
|
||||
from metagpt.const import TEST_DATA_PATH
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider import OpenAILLM
|
||||
from metagpt.schema import UserMessage
|
||||
from tests.metagpt.provider.mock_llm_config import (
|
||||
mock_llm_config,
|
||||
mock_llm_config_proxy,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code():
|
||||
llm = LLM()
|
||||
msg = [{"role": "user", "content": "Write a python hello world code."}]
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
|
||||
logger.info(rsp)
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_str():
|
||||
llm = LLM()
|
||||
msg = "Write a python hello world code."
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_message():
|
||||
llm = LLM()
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_speech():
|
||||
llm = LLM()
|
||||
|
|
@ -63,16 +39,41 @@ async def test_speech_to_text():
|
|||
assert "你好" == resp.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gen_image():
|
||||
llm = LLM()
|
||||
model = "dall-e-3"
|
||||
prompt = 'a logo with word "MetaGPT"'
|
||||
images: list[Image] = await llm.gen_image(model=model, prompt=prompt)
|
||||
assert images[0].size == (1024, 1024)
|
||||
@pytest.fixture
|
||||
def tool_calls_rsp():
|
||||
function_rsps = [
|
||||
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"}', name="execute"),
|
||||
]
|
||||
tool_calls = [
|
||||
ChatCompletionMessageToolCall(type="function", id=f"call_{i}", function=f) for i, f in enumerate(function_rsps)
|
||||
]
|
||||
messages = [ChatCompletionMessage(content=None, role="assistant", tool_calls=[t]) for t in tool_calls]
|
||||
# 添加一个纯文本响应
|
||||
messages.append(
|
||||
ChatCompletionMessage(content="Completed a python code for hello world!", role="assistant", tool_calls=None)
|
||||
)
|
||||
# 添加 openai tool calls respond bug, code 出现在ChatCompletionMessage.content中
|
||||
messages.extend(
|
||||
[
|
||||
ChatCompletionMessage(content="```python\nprint('hello world')```", role="assistant", tool_calls=None),
|
||||
]
|
||||
)
|
||||
choices = [
|
||||
Choice(finish_reason="tool_calls", logprobs=None, index=i, message=msg) for i, msg in enumerate(messages)
|
||||
]
|
||||
return [
|
||||
ChatCompletion(id=str(i), choices=[c], created=i, model="gpt-4", object="chat.completion")
|
||||
for i, c in enumerate(choices)
|
||||
]
|
||||
|
||||
images: list[Image] = await llm.gen_image(model=model, prompt=prompt, resp_format="b64_json")
|
||||
assert images[0].size == (1024, 1024)
|
||||
|
||||
@pytest.fixture
|
||||
def json_decode_error():
|
||||
function_rsp = Function(arguments='{\n"language": \'python\',\n"code": "print(\'hello world\')"}', name="execute")
|
||||
tool_calls = [ChatCompletionMessageToolCall(type="function", id=f"call_{0}", function=function_rsp)]
|
||||
message = ChatCompletionMessage(content=None, role="assistant", tool_calls=tool_calls)
|
||||
choices = [Choice(finish_reason="tool_calls", logprobs=None, index=0, message=message)]
|
||||
return ChatCompletion(id="0", choices=choices, created=0, model="gpt-4", object="chat.completion")
|
||||
|
||||
|
||||
class TestOpenAI:
|
||||
|
|
@ -87,3 +88,36 @@ class TestOpenAI:
|
|||
instance = OpenAILLM(mock_llm_config_proxy)
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
|
||||
def test_get_choice_function_arguments_for_aask_code(self, tool_calls_rsp):
|
||||
instance = OpenAILLM(mock_llm_config_proxy)
|
||||
for i, rsp in enumerate(tool_calls_rsp):
|
||||
code = instance.get_choice_function_arguments(rsp)
|
||||
logger.info(f"\ntest get function call arguments {i}: {code}")
|
||||
assert "code" in code
|
||||
assert "language" in code
|
||||
assert "hello world" in code["code"]
|
||||
logger.info(f'code is : {code["code"]}')
|
||||
|
||||
if "Completed a python code for hello world!" == code["code"]:
|
||||
code["language"] == "markdown"
|
||||
else:
|
||||
code["language"] == "python"
|
||||
|
||||
def test_aask_code_json_decode_error(self, json_decode_error):
|
||||
instance = OpenAILLM(mock_llm_config)
|
||||
with pytest.raises(json.decoder.JSONDecodeError) as e:
|
||||
instance.get_choice_function_arguments(json_decode_error)
|
||||
assert "JSONDecodeError" in str(e)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gen_image():
|
||||
llm = LLM()
|
||||
model = "dall-e-3"
|
||||
prompt = 'a logo with word "MetaGPT"'
|
||||
images: list[Image] = await llm.gen_image(model=model, prompt=prompt)
|
||||
assert images[0].size == (1024, 1024)
|
||||
|
||||
images: list[Image] = await llm.gen_image(model=model, prompt=prompt, resp_format="b64_json")
|
||||
assert images[0].size == (1024, 1024)
|
||||
|
|
|
|||
19
tests/metagpt/roles/ci/test_code_interpreter.py
Normal file
19
tests/metagpt/roles/ci/test_code_interpreter.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.ci.code_interpreter import CodeInterpreter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("auto_run", [(True), (False)])
|
||||
async def test_code_interpreter(mocker, auto_run):
|
||||
mocker.patch("metagpt.actions.ci.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
|
||||
mocker.patch("builtins.input", return_value="confirm")
|
||||
|
||||
requirement = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
tools = []
|
||||
|
||||
ci = CodeInterpreter(auto_run=auto_run, use_tools=True, tools=tools)
|
||||
rsp = await ci.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
90
tests/metagpt/roles/ci/test_ml_engineer.py
Normal file
90
tests/metagpt/roles/ci/test_ml_engineer.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.ci.ml_engineer import MLEngineer
|
||||
from metagpt.schema import Message, Plan, Task
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
from tests.metagpt.actions.ci.test_debug_code import CODE, DebugContext, ErrorStr
|
||||
|
||||
|
||||
def test_mle_init():
|
||||
ci = MLEngineer(goal="test", auto_run=True, use_tools=True, tools=["tool1", "tool2"])
|
||||
assert ci.tools == []
|
||||
|
||||
|
||||
MockPlan = Plan(
|
||||
goal="This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: 'tests/data/ml_datasets/titanic/split_train.csv', eval data path: 'tests/data/ml_datasets/titanic/split_eval.csv'.",
|
||||
context="",
|
||||
tasks=[
|
||||
Task(
|
||||
task_id="1",
|
||||
dependent_task_ids=[],
|
||||
instruction="Perform exploratory data analysis on the train dataset to understand the features and target variable.",
|
||||
task_type="eda",
|
||||
code="",
|
||||
result="",
|
||||
is_success=False,
|
||||
is_finished=False,
|
||||
)
|
||||
],
|
||||
task_map={
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
dependent_task_ids=[],
|
||||
instruction="Perform exploratory data analysis on the train dataset to understand the features and target variable.",
|
||||
task_type="eda",
|
||||
code="",
|
||||
result="",
|
||||
is_success=False,
|
||||
is_finished=False,
|
||||
)
|
||||
},
|
||||
current_task_id="1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mle_write_code(mocker):
|
||||
data_path = "tests/data/ml_datasets/titanic"
|
||||
|
||||
mle = MLEngineer(auto_run=True, use_tools=True)
|
||||
mle.planner.plan = MockPlan
|
||||
|
||||
code, _ = await mle._write_code()
|
||||
assert data_path in code["code"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mle_update_data_columns(mocker):
|
||||
mle = MLEngineer(auto_run=True, use_tools=True)
|
||||
mle.planner.plan = MockPlan
|
||||
|
||||
# manually update task type to test update
|
||||
mle.planner.plan.current_task.task_type = ToolType.DATA_PREPROCESS.value
|
||||
|
||||
result = await mle._update_data_columns()
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mle_debug_code(mocker):
|
||||
mle = MLEngineer(auto_run=True, use_tools=True)
|
||||
mle.working_memory.add(Message(content=ErrorStr, cause_by=ExecuteNbCode))
|
||||
mle.latest_code = CODE
|
||||
mle.debug_context = DebugContext
|
||||
code, _ = await mle._write_code()
|
||||
assert len(code) > 0
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_ml_engineer():
|
||||
data_path = "tests/data/ml_datasets/titanic"
|
||||
requirement = f"This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'."
|
||||
tools = ["FillMissingValue", "CatCross", "dummy_tool"]
|
||||
|
||||
mle = MLEngineer(auto_run=True, use_tools=True, tools=tools)
|
||||
rsp = await mle.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
@ -25,7 +25,9 @@ from metagpt.schema import (
|
|||
Document,
|
||||
Message,
|
||||
MessageQueue,
|
||||
Plan,
|
||||
SystemMessage,
|
||||
Task,
|
||||
UserMessage,
|
||||
)
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
|
@ -181,5 +183,173 @@ def test_class_view():
|
|||
)
|
||||
|
||||
|
||||
class TestPlan:
|
||||
def test_add_tasks_ordering(self):
|
||||
plan = Plan(goal="")
|
||||
|
||||
tasks = [
|
||||
Task(task_id="1", dependent_task_ids=["2", "3"], instruction="Third"),
|
||||
Task(task_id="2", instruction="First"),
|
||||
Task(task_id="3", dependent_task_ids=["2"], instruction="Second"),
|
||||
] # 2 -> 3 -> 1
|
||||
plan.add_tasks(tasks)
|
||||
|
||||
assert [task.task_id for task in plan.tasks] == ["2", "3", "1"]
|
||||
|
||||
def test_add_tasks_to_existing_no_common_prefix(self):
|
||||
plan = Plan(goal="")
|
||||
|
||||
tasks = [
|
||||
Task(task_id="1", dependent_task_ids=["2", "3"], instruction="Third"),
|
||||
Task(task_id="2", instruction="First"),
|
||||
Task(task_id="3", dependent_task_ids=["2"], instruction="Second", is_finished=True),
|
||||
] # 2 -> 3 -> 1
|
||||
plan.add_tasks(tasks)
|
||||
|
||||
new_tasks = [Task(task_id="3", instruction="")]
|
||||
plan.add_tasks(new_tasks)
|
||||
|
||||
assert [task.task_id for task in plan.tasks] == ["3"]
|
||||
assert not plan.tasks[0].is_finished # must be the new unfinished task
|
||||
|
||||
def test_add_tasks_to_existing_with_common_prefix(self):
|
||||
plan = Plan(goal="")
|
||||
|
||||
tasks = [
|
||||
Task(task_id="1", dependent_task_ids=["2", "3"], instruction="Third"),
|
||||
Task(task_id="2", instruction="First"),
|
||||
Task(task_id="3", dependent_task_ids=["2"], instruction="Second"),
|
||||
] # 2 -> 3 -> 1
|
||||
plan.add_tasks(tasks)
|
||||
plan.finish_current_task() # finish 2
|
||||
plan.finish_current_task() # finish 3
|
||||
|
||||
new_tasks = [
|
||||
Task(task_id="4", dependent_task_ids=["3"], instruction="Third"),
|
||||
Task(task_id="2", instruction="First"),
|
||||
Task(task_id="3", dependent_task_ids=["2"], instruction="Second"),
|
||||
] # 2 -> 3 -> 4, so the common prefix is 2 -> 3, and these two should be obtained from the existing tasks
|
||||
plan.add_tasks(new_tasks)
|
||||
|
||||
assert [task.task_id for task in plan.tasks] == ["2", "3", "4"]
|
||||
assert (
|
||||
plan.tasks[0].is_finished and plan.tasks[1].is_finished
|
||||
) # "2" and "3" should be the original finished one
|
||||
assert plan.current_task_id == "4"
|
||||
|
||||
def test_current_task(self):
|
||||
plan = Plan(goal="")
|
||||
tasks = [
|
||||
Task(task_id="1", dependent_task_ids=["2"], instruction="Second"),
|
||||
Task(task_id="2", instruction="First"),
|
||||
]
|
||||
plan.add_tasks(tasks)
|
||||
assert plan.current_task.task_id == "2"
|
||||
|
||||
def test_finish_task(self):
|
||||
plan = Plan(goal="")
|
||||
tasks = [
|
||||
Task(task_id="1", instruction="First"),
|
||||
Task(task_id="2", dependent_task_ids=["1"], instruction="Second"),
|
||||
]
|
||||
plan.add_tasks(tasks)
|
||||
plan.finish_current_task()
|
||||
assert plan.current_task.task_id == "2"
|
||||
|
||||
def test_finished_tasks(self):
|
||||
plan = Plan(goal="")
|
||||
tasks = [
|
||||
Task(task_id="1", instruction="First"),
|
||||
Task(task_id="2", dependent_task_ids=["1"], instruction="Second"),
|
||||
]
|
||||
plan.add_tasks(tasks)
|
||||
plan.finish_current_task()
|
||||
finished_tasks = plan.get_finished_tasks()
|
||||
assert len(finished_tasks) == 1
|
||||
assert finished_tasks[0].task_id == "1"
|
||||
|
||||
def test_reset_task_existing(self):
|
||||
plan = Plan(goal="")
|
||||
task = Task(task_id="1", instruction="Do something", code="print('Hello')", result="Hello", finished=True)
|
||||
plan.add_tasks([task])
|
||||
plan.reset_task("1")
|
||||
reset_task = plan.task_map["1"]
|
||||
assert reset_task.code == ""
|
||||
assert reset_task.result == ""
|
||||
assert not reset_task.is_finished
|
||||
|
||||
def test_reset_task_non_existing(self):
|
||||
plan = Plan(goal="")
|
||||
task = Task(task_id="1", instruction="Do something", code="print('Hello')", result="Hello", finished=True)
|
||||
plan.add_tasks([task])
|
||||
plan.reset_task("2") # Task with ID 2 does not exist
|
||||
assert "1" in plan.task_map
|
||||
assert "2" not in plan.task_map
|
||||
|
||||
def test_replace_task_with_dependents(self):
|
||||
plan = Plan(goal="")
|
||||
tasks = [
|
||||
Task(task_id="1", instruction="First Task", finished=True),
|
||||
Task(task_id="2", instruction="Second Task", dependent_task_ids=["1"], finished=True),
|
||||
]
|
||||
plan.add_tasks(tasks)
|
||||
new_task = Task(task_id="1", instruction="Updated First Task")
|
||||
plan.replace_task(new_task)
|
||||
assert plan.task_map["1"].instruction == "Updated First Task"
|
||||
assert not plan.task_map["2"].is_finished # Dependent task should be reset
|
||||
assert plan.task_map["2"].code == ""
|
||||
assert plan.task_map["2"].result == ""
|
||||
|
||||
def test_replace_task_non_existing(self):
|
||||
plan = Plan(goal="")
|
||||
task = Task(task_id="1", instruction="First Task")
|
||||
plan.add_tasks([task])
|
||||
new_task = Task(task_id="2", instruction="New Task")
|
||||
with pytest.raises(AssertionError):
|
||||
plan.replace_task(new_task) # Task with ID 2 does not exist in plan
|
||||
assert "1" in plan.task_map
|
||||
assert "2" not in plan.task_map
|
||||
|
||||
def test_append_task_with_valid_dependencies(self):
|
||||
plan = Plan(goal="Test")
|
||||
existing_task = [Task(task_id="1")]
|
||||
plan.add_tasks(existing_task)
|
||||
new_task = Task(task_id="2", dependent_task_ids=["1"])
|
||||
plan.append_task(new_task)
|
||||
assert plan.tasks[-1].task_id == "2"
|
||||
assert plan.task_map["2"] == new_task
|
||||
|
||||
def test_append_task_with_invalid_dependencies(self):
|
||||
new_task = Task(task_id="2", dependent_task_ids=["3"])
|
||||
plan = Plan(goal="Test")
|
||||
with pytest.raises(AssertionError):
|
||||
plan.append_task(new_task)
|
||||
|
||||
def test_append_task_without_dependencies(self):
|
||||
plan = Plan(goal="Test")
|
||||
existing_task = [Task(task_id="1")]
|
||||
plan.add_tasks(existing_task)
|
||||
|
||||
new_task = Task(task_id="2")
|
||||
plan.append_task(new_task)
|
||||
|
||||
assert len(plan.tasks) == 2
|
||||
assert plan.current_task_id == "1"
|
||||
|
||||
def test_append_task_updates_current_task(self):
|
||||
finished_task = Task(task_id="1", is_finished=True)
|
||||
new_task = Task(task_id="2")
|
||||
plan = Plan(goal="Test", tasks=[finished_task])
|
||||
plan.append_task(new_task)
|
||||
assert plan.current_task_id == "2"
|
||||
|
||||
def test_update_current_task(self):
|
||||
task1 = Task(task_id="1", is_finished=True)
|
||||
task2 = Task(task_id="2")
|
||||
plan = Plan(goal="Test", tasks=[task1, task2])
|
||||
plan._update_current_task()
|
||||
assert plan.current_task_id == "2"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
6
tests/metagpt/tools/libs/__init__.py
Normal file
6
tests/metagpt/tools/libs/__init__.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2024/1/11 16:14
|
||||
# @Author : lidanyang
|
||||
# @File : __init__.py
|
||||
# @Desc :
|
||||
111
tests/metagpt/tools/libs/test_data_preprocess.py
Normal file
111
tests/metagpt/tools/libs/test_data_preprocess.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import numpy.testing as npt
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from metagpt.tools.libs.data_preprocess import (
|
||||
FillMissingValue,
|
||||
LabelEncode,
|
||||
MaxAbsScale,
|
||||
MinMaxScale,
|
||||
OneHotEncode,
|
||||
OrdinalEncode,
|
||||
RobustScale,
|
||||
StandardScale,
|
||||
get_column_info,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_datasets():
|
||||
return pd.DataFrame(
|
||||
{
|
||||
"num1": [1, 2, np.nan, 4, 5],
|
||||
"cat1": ["A", "B", np.nan, "D", "A"],
|
||||
"date1": [
|
||||
datetime(2020, 1, 1),
|
||||
datetime(2020, 1, 2),
|
||||
datetime(2020, 1, 3),
|
||||
datetime(2020, 1, 4),
|
||||
datetime(2020, 1, 5),
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_fill_missing_value(mock_datasets):
|
||||
fm = FillMissingValue(features=["num1"], strategy="mean")
|
||||
transformed = fm.fit_transform(mock_datasets.copy())
|
||||
|
||||
assert transformed["num1"].isnull().sum() == 0
|
||||
|
||||
|
||||
def test_min_max_scale(mock_datasets):
|
||||
mms = MinMaxScale(features=["num1"])
|
||||
transformed = mms.fit_transform(mock_datasets.copy())
|
||||
|
||||
npt.assert_allclose(transformed["num1"].min(), 0)
|
||||
npt.assert_allclose(transformed["num1"].max(), 1)
|
||||
|
||||
|
||||
def test_standard_scale(mock_datasets):
|
||||
ss = StandardScale(features=["num1"])
|
||||
transformed = ss.fit_transform(mock_datasets.copy())
|
||||
|
||||
assert int(transformed["num1"].mean()) == 0
|
||||
assert int(transformed["num1"].std()) == 1
|
||||
|
||||
|
||||
def test_max_abs_scale(mock_datasets):
|
||||
mas = MaxAbsScale(features=["num1"])
|
||||
transformed = mas.fit_transform(mock_datasets.copy())
|
||||
|
||||
npt.assert_allclose(transformed["num1"].abs().max(), 1)
|
||||
|
||||
|
||||
def test_robust_scale(mock_datasets):
|
||||
rs = RobustScale(features=["num1"])
|
||||
transformed = rs.fit_transform(mock_datasets.copy())
|
||||
|
||||
assert int(transformed["num1"].median()) == 0
|
||||
|
||||
|
||||
def test_ordinal_encode(mock_datasets):
|
||||
oe = OrdinalEncode(features=["cat1"])
|
||||
transformed = oe.fit_transform(mock_datasets.copy())
|
||||
|
||||
assert transformed["cat1"].max() == 2
|
||||
|
||||
|
||||
def test_one_hot_encode(mock_datasets):
|
||||
ohe = OneHotEncode(features=["cat1"])
|
||||
transformed = ohe.fit_transform(mock_datasets.copy())
|
||||
|
||||
assert transformed["cat1_A"].max() == 1
|
||||
|
||||
|
||||
def test_label_encode(mock_datasets):
|
||||
le = LabelEncode(features=["cat1"])
|
||||
transformed = le.fit_transform(mock_datasets.copy())
|
||||
|
||||
assert transformed["cat1"].max() == 3
|
||||
|
||||
# test transform with unseen data
|
||||
test = mock_datasets.copy()
|
||||
test["cat1"] = ["A", "B", "C", "D", "E"]
|
||||
transformed = le.transform(test)
|
||||
assert transformed["cat1"].max() == 4
|
||||
|
||||
|
||||
def test_get_column_info(mock_datasets):
|
||||
df = mock_datasets
|
||||
column_info = get_column_info(df)
|
||||
|
||||
assert column_info == {
|
||||
"Category": ["cat1"],
|
||||
"Numeric": ["num1"],
|
||||
"Datetime": ["date1"],
|
||||
"Others": [],
|
||||
}
|
||||
175
tests/metagpt/tools/libs/test_feature_engineering.py
Normal file
175
tests/metagpt/tools/libs/test_feature_engineering.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from sklearn.datasets import fetch_california_housing, load_breast_cancer, load_iris
|
||||
|
||||
from metagpt.tools.libs.feature_engineering import (
|
||||
CatCount,
|
||||
CatCross,
|
||||
ExtractTimeComps,
|
||||
GeneralSelection,
|
||||
GroupStat,
|
||||
KFoldTargetMeanEncoder,
|
||||
PolynomialExpansion,
|
||||
SplitBins,
|
||||
TargetMeanEncoder,
|
||||
TreeBasedSelection,
|
||||
VarianceBasedSelection,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset():
|
||||
return pd.DataFrame(
|
||||
{
|
||||
"num1": [1, 2, np.nan, 4, 5, 6, 7, 3],
|
||||
"num2": [1, 3, 2, 1, np.nan, 5, 6, 4],
|
||||
"num3": [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
|
||||
"cat1": ["A", "B", np.nan, "D", "E", "C", "B", "A"],
|
||||
"cat2": ["A", "A", "A", "A", "A", "A", "A", "A"],
|
||||
"date1": [
|
||||
"2020-01-01",
|
||||
"2020-01-02",
|
||||
"2020-01-03",
|
||||
"2020-01-04",
|
||||
"2020-01-05",
|
||||
"2020-01-06",
|
||||
"2020-01-07",
|
||||
"2020-01-08",
|
||||
],
|
||||
"label": [0, 1, 0, 1, 0, 1, 0, 1],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def load_sklearn_data(data_name):
|
||||
if data_name == "iris":
|
||||
data = load_iris()
|
||||
elif data_name == "breast_cancer":
|
||||
data = load_breast_cancer()
|
||||
elif data_name == "housing":
|
||||
data = fetch_california_housing()
|
||||
else:
|
||||
raise ValueError("data_name not supported")
|
||||
|
||||
X, y, feature_names = data.data, data.target, data.feature_names
|
||||
data = pd.DataFrame(X, columns=feature_names)
|
||||
data["label"] = y
|
||||
return data
|
||||
|
||||
|
||||
def test_polynomial_expansion(mock_dataset):
|
||||
pe = PolynomialExpansion(cols=["num1", "num2", "label"], degree=2, label_col="label")
|
||||
transformed = pe.fit_transform(mock_dataset)
|
||||
|
||||
assert len(transformed.columns) == len(mock_dataset.columns) + 3
|
||||
|
||||
# when too many columns
|
||||
data = load_sklearn_data("breast_cancer")
|
||||
cols = [c for c in data.columns if c != "label"]
|
||||
pe = PolynomialExpansion(cols=cols, degree=2, label_col="label")
|
||||
transformed = pe.fit_transform(data)
|
||||
|
||||
assert len(transformed.columns) == len(data.columns) + 55
|
||||
|
||||
|
||||
def test_cat_count(mock_dataset):
|
||||
cc = CatCount(col="cat1")
|
||||
transformed = cc.fit_transform(mock_dataset)
|
||||
|
||||
assert "cat1_cnt" in transformed.columns
|
||||
assert transformed["cat1_cnt"][0] == 2
|
||||
|
||||
|
||||
def test_target_mean_encoder(mock_dataset):
|
||||
tme = TargetMeanEncoder(col="cat1", label="label")
|
||||
transformed = tme.fit_transform(mock_dataset)
|
||||
|
||||
assert "cat1_target_mean" in transformed.columns
|
||||
assert transformed["cat1_target_mean"][0] == 0.5
|
||||
|
||||
|
||||
def test_kfold_target_mean_encoder(mock_dataset):
|
||||
kfme = KFoldTargetMeanEncoder(col="cat1", label="label")
|
||||
transformed = kfme.fit_transform(mock_dataset)
|
||||
|
||||
assert "cat1_kf_target_mean" in transformed.columns
|
||||
|
||||
|
||||
def test_cat_cross(mock_dataset):
|
||||
cc = CatCross(cols=["cat1", "cat2"])
|
||||
transformed = cc.fit_transform(mock_dataset)
|
||||
|
||||
assert "cat1_cat2" in transformed.columns
|
||||
|
||||
cc = CatCross(cols=["cat1", "cat2"], max_cat_num=3)
|
||||
transformed = cc.fit_transform(mock_dataset)
|
||||
|
||||
assert "cat1_cat2" not in transformed.columns
|
||||
|
||||
|
||||
def test_group_stat(mock_dataset):
|
||||
gs = GroupStat(group_col="cat1", agg_col="num1", agg_funcs=["mean", "sum"])
|
||||
transformed = gs.fit_transform(mock_dataset)
|
||||
|
||||
assert "num1_mean_by_cat1" in transformed.columns
|
||||
assert "num1_sum_by_cat1" in transformed.columns
|
||||
|
||||
|
||||
def test_split_bins(mock_dataset):
|
||||
sb = SplitBins(cols=["num1"])
|
||||
transformed = sb.fit_transform(mock_dataset)
|
||||
|
||||
assert transformed["num1"].nunique() <= 5
|
||||
assert all(0 <= x < 5 for x in transformed["num1"])
|
||||
|
||||
|
||||
def test_extract_time_comps(mock_dataset):
|
||||
time_comps = ["year", "month", "day", "hour", "dayofweek", "is_weekend"]
|
||||
etc = ExtractTimeComps(time_col="date1", time_comps=time_comps)
|
||||
transformed = etc.fit_transform(mock_dataset.copy())
|
||||
|
||||
for comp in time_comps:
|
||||
assert comp in transformed.columns
|
||||
assert transformed["year"][0] == 2020
|
||||
assert transformed["month"][0] == 1
|
||||
assert transformed["day"][0] == 1
|
||||
assert transformed["hour"][0] == 0
|
||||
assert transformed["dayofweek"][0] == 3
|
||||
assert transformed["is_weekend"][0] == 0
|
||||
|
||||
|
||||
def test_general_selection(mock_dataset):
|
||||
gs = GeneralSelection(label_col="label")
|
||||
transformed = gs.fit_transform(mock_dataset.copy())
|
||||
|
||||
assert "num3" not in transformed.columns
|
||||
assert "cat2" not in transformed.columns
|
||||
|
||||
|
||||
@pytest.mark.skip # skip because TreeBasedSelection needs lgb as dependency
|
||||
def test_tree_based_selection(mock_dataset):
|
||||
# regression
|
||||
data = load_sklearn_data("housing")
|
||||
tbs = TreeBasedSelection(label_col="label", task_type="reg")
|
||||
transformed = tbs.fit_transform(data)
|
||||
assert len(transformed.columns) > 1
|
||||
|
||||
# classification
|
||||
data = load_sklearn_data("breast_cancer")
|
||||
tbs = TreeBasedSelection(label_col="label", task_type="cls")
|
||||
transformed = tbs.fit_transform(data)
|
||||
assert len(transformed.columns) > 1
|
||||
|
||||
# multi-classification
|
||||
data = load_sklearn_data("iris")
|
||||
tbs = TreeBasedSelection(label_col="label", task_type="mcls")
|
||||
transformed = tbs.fit_transform(data)
|
||||
assert len(transformed.columns) > 1
|
||||
|
||||
|
||||
def test_variance_based_selection(mock_dataset):
|
||||
vbs = VarianceBasedSelection(label_col="label")
|
||||
transformed = vbs.fit_transform(mock_dataset.copy())
|
||||
|
||||
assert "num3" not in transformed.columns
|
||||
40
tests/metagpt/tools/libs/test_gpt_v_generator.py
Normal file
40
tests/metagpt/tools/libs/test_gpt_v_generator.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/01/15
|
||||
@Author : mannaandpoem
|
||||
@File : test_gpt_v_generator.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt import logs
|
||||
from metagpt.tools.libs.gpt_v_generator import GPTvGenerator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_webpages(mocker):
|
||||
mock_data = """```html\n<html>\n<script src="scripts.js"></script>
|
||||
<link rel="stylesheet" href="styles.css(">\n</html>\n```\n
|
||||
```css\n.class { ... }\n```\n
|
||||
```javascript\nfunction() { ... }\n```\n"""
|
||||
mocker.patch("metagpt.tools.libs.gpt_v_generator.GPTvGenerator.generate_webpages", return_value=mock_data)
|
||||
return mocker
|
||||
|
||||
|
||||
def test_vision_generate_webpages(mock_webpages):
|
||||
image_path = "image.png"
|
||||
generator = GPTvGenerator()
|
||||
rsp = generator.generate_webpages(image_path=image_path)
|
||||
logs.logger.info(rsp)
|
||||
assert "html" in rsp
|
||||
assert "css" in rsp
|
||||
assert "javascript" in rsp
|
||||
|
||||
|
||||
def test_save_webpages(mock_webpages):
|
||||
image_path = "image.png"
|
||||
generator = GPTvGenerator()
|
||||
webpages = generator.generate_webpages(image_path)
|
||||
webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages)
|
||||
logs.logger.info(webpages_dir)
|
||||
assert webpages_dir.exists()
|
||||
61
tests/metagpt/tools/libs/test_sd_engine.py
Normal file
61
tests/metagpt/tools/libs/test_sd_engine.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 1/10/2024 10:07 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from metagpt.tools.libs.sd_engine import SDEngine
|
||||
|
||||
|
||||
def generate_mock_image_data():
|
||||
# 创建一个简单的图片对象
|
||||
image = Image.new("RGB", (100, 100), color="white")
|
||||
draw = ImageDraw.Draw(image)
|
||||
draw.text((10, 10), "Mock Image", fill="black")
|
||||
|
||||
# 将图片转换为二进制数据
|
||||
with io.BytesIO() as buffer:
|
||||
image.save(buffer, format="PNG")
|
||||
image_binary = buffer.getvalue()
|
||||
|
||||
# 对图片二进制数据进行 base64 编码
|
||||
image_base64 = base64.b64encode(image_binary).decode("utf-8")
|
||||
|
||||
return image_base64
|
||||
|
||||
|
||||
def test_sd_tools(mocker):
|
||||
mock_response = mocker.MagicMock()
|
||||
mock_response.json.return_value = {"images": [generate_mock_image_data()]}
|
||||
mocker.patch("requests.Session.post", return_value=mock_response)
|
||||
|
||||
engine = SDEngine(sd_url="http://example_localhost:7860")
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
engine.simple_run_t2i(engine.payload)
|
||||
|
||||
|
||||
def test_sd_construct_payload():
|
||||
engine = SDEngine(sd_url="http://example_localhost:7860")
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
assert "negative_prompt" in engine.payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sd_asyn_t2i(mocker):
|
||||
mock_post = mocker.patch("aiohttp.ClientSession.post")
|
||||
mock_response = mocker.AsyncMock()
|
||||
mock_response.read.return_value = json.dumps({"images": [generate_mock_image_data()]})
|
||||
mock_post.return_value.__aenter__.return_value = mock_response
|
||||
|
||||
engine = SDEngine(sd_url="http://example_localhost:7860")
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
await engine.run_t2i([engine.payload])
|
||||
assert "negative_prompt" in engine.payload
|
||||
23
tests/metagpt/tools/libs/test_web_scraping.py
Normal file
23
tests/metagpt/tools/libs/test_web_scraping.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.tools.libs.web_scraping import scrape_web_playwright
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scrape_web_playwright():
|
||||
test_url = "https://www.deepwisdom.ai"
|
||||
|
||||
result = await scrape_web_playwright(test_url)
|
||||
|
||||
# Assert that the result is a dictionary
|
||||
assert isinstance(result, dict)
|
||||
|
||||
# Assert that the result contains 'inner_text' and 'html' keys
|
||||
assert "inner_text" in result
|
||||
assert "html" in result
|
||||
|
||||
# Assert startswith and endswith
|
||||
assert not result["inner_text"].startswith(" ")
|
||||
assert not result["inner_text"].endswith(" ")
|
||||
assert not result["html"].startswith(" ")
|
||||
assert not result["html"].endswith(" ")
|
||||
154
tests/metagpt/tools/test_tool_convert.py
Normal file
154
tests/metagpt/tools/test_tool_convert.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
import pandas as pd
|
||||
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema, docstring_to_schema
|
||||
|
||||
|
||||
def test_docstring_to_schema():
|
||||
docstring = """
|
||||
Some test desc.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be
|
||||
used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
|
||||
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
|
||||
Defaults to None.
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
expected = {
|
||||
"description": "Some test desc.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
},
|
||||
"required": ["features"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
}
|
||||
schema = docstring_to_schema(docstring)
|
||||
assert schema == expected
|
||||
|
||||
|
||||
class DummyClass:
|
||||
"""
|
||||
Completing missing values with simple strategies.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list, strategy: str = "mean", fill_value=None):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only
|
||||
be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
|
||||
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
|
||||
Defaults to None.
|
||||
"""
|
||||
pass
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the FillMissingValue model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
pass
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def dummy_fn(df: pd.DataFrame) -> dict:
|
||||
"""
|
||||
Analyzes a DataFrame and categorizes its columns based on data types.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The DataFrame to be analyzed.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others').
|
||||
Each key corresponds to a list of column names belonging to that category.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def test_convert_code_to_tool_schema_class():
|
||||
expected = {
|
||||
"type": "class",
|
||||
"description": "Completing missing values with simple strategies.",
|
||||
"methods": {
|
||||
"__init__": {
|
||||
"description": "Initialize self.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
},
|
||||
"required": ["features"],
|
||||
},
|
||||
},
|
||||
"fit": {
|
||||
"description": "Fit the FillMissingValue model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
},
|
||||
"transform": {
|
||||
"description": "Transform the input DataFrame with the fitted model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
},
|
||||
},
|
||||
}
|
||||
schema = convert_code_to_tool_schema(DummyClass)
|
||||
assert schema == expected
|
||||
|
||||
|
||||
def test_convert_code_to_tool_schema_function():
|
||||
expected = {
|
||||
"type": "function",
|
||||
"description": "Analyzes a DataFrame and categorizes its columns based on data types.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
}
|
||||
schema = convert_code_to_tool_schema(dummy_fn)
|
||||
assert schema == expected
|
||||
102
tests/metagpt/tools/test_tool_registry.py
Normal file
102
tests/metagpt/tools/test_tool_registry.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.tools.tool_registry import ToolRegistry
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_registry():
|
||||
return ToolRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_registry_full():
|
||||
return ToolRegistry(tool_types=ToolType)
|
||||
|
||||
|
||||
# Test Initialization
|
||||
def test_initialization(tool_registry):
|
||||
assert isinstance(tool_registry, ToolRegistry)
|
||||
assert tool_registry.tools == {}
|
||||
assert tool_registry.tool_types == {}
|
||||
assert tool_registry.tools_by_types == {}
|
||||
|
||||
|
||||
# Test Initialization with tool types
|
||||
def test_initialize_with_tool_types(tool_registry_full):
|
||||
assert isinstance(tool_registry_full, ToolRegistry)
|
||||
assert tool_registry_full.tools == {}
|
||||
assert tool_registry_full.tools_by_types == {}
|
||||
assert "data_preprocess" in tool_registry_full.tool_types
|
||||
|
||||
|
||||
class TestClassTool:
|
||||
"""test class"""
|
||||
|
||||
def test_class_fn(self):
|
||||
"""test class fn"""
|
||||
pass
|
||||
|
||||
|
||||
def test_fn():
|
||||
"""test function"""
|
||||
pass
|
||||
|
||||
|
||||
# Test Tool Registration Class
|
||||
def test_register_tool_class(tool_registry):
|
||||
tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool)
|
||||
assert "TestClassTool" in tool_registry.tools
|
||||
|
||||
|
||||
# Test Tool Registration Function
|
||||
def test_register_tool_fn(tool_registry):
|
||||
tool_registry.register_tool("test_fn", "/path/to/tool", tool_source_object=test_fn)
|
||||
assert "test_fn" in tool_registry.tools
|
||||
|
||||
|
||||
# Test Tool Existence Checks
|
||||
def test_has_tool(tool_registry):
|
||||
tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool)
|
||||
assert tool_registry.has_tool("TestClassTool")
|
||||
assert not tool_registry.has_tool("NonexistentTool")
|
||||
|
||||
|
||||
# Test Tool Retrieval
|
||||
def test_get_tool(tool_registry):
|
||||
tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool)
|
||||
tool = tool_registry.get_tool("TestClassTool")
|
||||
assert tool is not None
|
||||
assert tool.name == "TestClassTool"
|
||||
assert tool.path == "/path/to/tool"
|
||||
assert "description" in tool.schemas
|
||||
|
||||
|
||||
# Similar tests for has_tool_type, get_tool_type, get_tools_by_type
|
||||
def test_has_tool_type(tool_registry_full):
|
||||
assert tool_registry_full.has_tool_type("data_preprocess")
|
||||
assert not tool_registry_full.has_tool_type("NonexistentType")
|
||||
|
||||
|
||||
def test_get_tool_type(tool_registry_full):
|
||||
retrieved_type = tool_registry_full.get_tool_type("data_preprocess")
|
||||
assert retrieved_type is not None
|
||||
assert retrieved_type.name == "data_preprocess"
|
||||
|
||||
|
||||
def test_get_tools_by_type(tool_registry):
|
||||
tool_type_name = "TestType"
|
||||
tool_name = "TestTool"
|
||||
tool_path = "/path/to/tool"
|
||||
|
||||
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name, tool_source_object=TestClassTool)
|
||||
|
||||
tools_by_type = tool_registry.get_tools_by_type(tool_type_name)
|
||||
assert tools_by_type is not None
|
||||
assert tool_name in tools_by_type
|
||||
|
||||
|
||||
# Test case for when the tool type does not exist
|
||||
def test_get_tools_by_nonexistent_type(tool_registry):
|
||||
tools_by_type = tool_registry.get_tools_by_type("NonexistentType")
|
||||
assert not tools_by_type
|
||||
44
tests/metagpt/utils/test_save_code.py
Normal file
44
tests/metagpt/utils/test_save_code.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/12/2023 4:17 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import nbformat
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.ci.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.utils.common import read_json_file
|
||||
from metagpt.utils.save_code import DATA_PATH, save_code_file
|
||||
|
||||
|
||||
def test_save_code_file_python():
|
||||
save_code_file("example", "print('Hello, World!')")
|
||||
file_path = DATA_PATH / "output" / "example" / "code.py"
|
||||
assert file_path.exists(), f"File does not exist: {file_path}"
|
||||
content = file_path.read_text()
|
||||
assert "print('Hello, World!')" in content, "File content does not match"
|
||||
|
||||
|
||||
def test_save_code_file_json():
|
||||
save_code_file("example_json", "print('Hello, JSON!')", file_format="json")
|
||||
file_path = DATA_PATH / "output" / "example_json" / "code.json"
|
||||
data = read_json_file(file_path)
|
||||
assert "code" in data, "JSON key 'code' is missing"
|
||||
assert data["code"] == "print('Hello, JSON!')", "JSON content does not match"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_code_file_notebook():
|
||||
code = "print('Hello, World!')"
|
||||
executor = ExecuteNbCode()
|
||||
await executor.run(code)
|
||||
# Save as a Notebook file
|
||||
save_code_file("example_nb", executor.nb, file_format="ipynb")
|
||||
file_path = DATA_PATH / "output" / "example_nb" / "code.ipynb"
|
||||
assert file_path.exists(), f"Notebook file does not exist: {file_path}"
|
||||
|
||||
# Additional checks specific to notebook format
|
||||
notebook = nbformat.read(file_path, as_version=4)
|
||||
assert len(notebook.cells) > 0, "Notebook should have at least one cell"
|
||||
first_cell_source = notebook.cells[0].source
|
||||
assert "print('Hello, World!')" in first_cell_source, "Notebook cell content does not match"
|
||||
|
|
@ -1,13 +1,22 @@
|
|||
import json
|
||||
from typing import Optional, Union
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
OriginalLLM = OpenAILLM if config.llm.api_type == LLMType.OPENAI else AzureOpenAILLM
|
||||
|
||||
|
||||
class MockLLM(OpenAILLM):
|
||||
class MockLLM(OriginalLLM):
|
||||
def __init__(self, allow_open_api_call):
|
||||
super().__init__(config.get_openai_llm())
|
||||
original_llm_config = (
|
||||
config.get_openai_llm() if config.llm.api_type == LLMType.OPENAI else config.get_azure_llm()
|
||||
)
|
||||
super().__init__(original_llm_config)
|
||||
self.allow_open_api_call = allow_open_api_call
|
||||
self.rsp_cache: dict = {}
|
||||
self.rsp_candidates: list[dict] = [] # a test can have multiple calls with the same llm, thus a list
|
||||
|
|
@ -62,6 +71,14 @@ class MockLLM(OpenAILLM):
|
|||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
async def original_aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
"""
|
||||
A copy of metagpt.provider.openai_api.OpenAILLM.aask_code, we can't use super().aask because it will be mocked.
|
||||
Since openai_api.OpenAILLM.aask_code is different from base_llm.BaseLLM.aask_code, we use the former.
|
||||
"""
|
||||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
msg: str,
|
||||
|
|
@ -83,6 +100,12 @@ class MockLLM(OpenAILLM):
|
|||
rsp = await self._mock_rsp(msg_key, self.original_aask_batch, msgs, timeout)
|
||||
return rsp
|
||||
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
messages = self._process_message(messages)
|
||||
msg_key = json.dumps(messages, ensure_ascii=False)
|
||||
rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs)
|
||||
return rsp
|
||||
|
||||
async def _mock_rsp(self, msg_key, ask_func, *args, **kwargs):
|
||||
if msg_key not in self.rsp_cache:
|
||||
if not self.allow_open_api_call:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue