mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-02 14:45:17 +02:00
Merge branch 'main' into reasoning
This commit is contained in:
commit
08587f392f
339 changed files with 22997 additions and 1437 deletions
|
|
@ -10,7 +10,7 @@ from metagpt.utils.read_document import read_docx
|
|||
from metagpt.utils.singleton import Singleton
|
||||
from metagpt.utils.token_counter import (
|
||||
TOKEN_COSTS,
|
||||
count_input_tokens,
|
||||
count_message_tokens,
|
||||
count_output_tokens,
|
||||
)
|
||||
|
||||
|
|
@ -19,6 +19,8 @@ __all__ = [
|
|||
"read_docx",
|
||||
"Singleton",
|
||||
"TOKEN_COSTS",
|
||||
"count_input_tokens",
|
||||
"new_transaction_id",
|
||||
"count_message_tokens",
|
||||
"count_string_tokens",
|
||||
"count_output_tokens",
|
||||
]
|
||||
|
|
|
|||
312
metagpt/utils/a11y_tree.py
Normal file
312
metagpt/utils/a11y_tree.py
Normal file
|
|
@ -0,0 +1,312 @@
|
|||
"""See https://github.com/web-arena-x/webarena
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from playwright.async_api import BrowserContext, Page
|
||||
|
||||
|
||||
async def get_accessibility_tree(page: Page):
    """Fetch the page's full accessibility tree over CDP, dropping duplicate nodes.

    Args:
        page: Page whose CDP session is used to send "Accessibility.getFullAXTree".

    Returns:
        list: Entries of the response's "nodes" list in response order, keeping
        only the first entry seen for each "nodeId".
    """
    session = await get_page_cdp_session(page)
    response = await session.send("Accessibility.getFullAXTree")

    # dicts keep insertion order, so setdefault retains the first node per id
    # and the order in which ids first appear.
    first_by_id = {}
    for ax_node in response["nodes"]:
        first_by_id.setdefault(ax_node["nodeId"], ax_node)
    return list(first_by_id.values())
|
||||
|
||||
|
||||
async def execute_step(step: str, page: Page, browser_ctx: BrowserContext, accessibility_tree: list):
    """Parse one textual browser action and perform it.

    The action name is the text before the first "[" or, when the step has no
    "[", its first whitespace-separated word. Handled names: None, click,
    hover, type, press, scroll, goto, new_tab, go_back, go_forward, tab_focus,
    close_tab and stop.

    Args:
        step: Action text such as "click [12]" or "type [3] [hello] [1]".
        page: Page the action is applied to.
        browser_ctx: Browser context used to open, select and enumerate tabs.
        accessibility_tree: Node list used to map element ids to backend node ids.

    Returns:
        "" for the "None" action, the extracted answer (possibly "") for
        "stop", and otherwise the page that is active after the action, once
        its "domcontentloaded" state has been reached.

    Raises:
        ValueError: If the step is empty, names an unknown action, or its
            arguments do not match the expected pattern.
    """
    step = step.strip()
    if not step:
        # Previously surfaced as an IndexError from step.split()[0].
        raise ValueError("Invalid action: empty step")

    func = step.split("[")[0].strip() if "[" in step else step.split()[0].strip()
    if func == "None":
        return ""
    elif func == "click":
        match = re.search(r"click ?\[(\d+)\]", step)
        if not match:
            raise ValueError(f"Invalid click action {step}")
        element_id = match.group(1)
        await click_element(page, get_backend_node_id(element_id, accessibility_tree))
    elif func == "hover":
        match = re.search(r"hover ?\[(\d+)\]", step)
        if not match:
            raise ValueError(f"Invalid hover action {step}")
        element_id = match.group(1)
        await hover_element(page, get_backend_node_id(element_id, accessibility_tree))
    elif func == "type":
        # Default the trailing "press enter" flag to 1 when it is omitted.
        if not step.endswith(("[0]", "[1]")):
            step += " [1]"

        match = re.search(r"type ?\[(\d+)\] ?\[(.+)\] ?\[(\d+)\]", step)
        if not match:
            raise ValueError(f"Invalid type action {step}")
        element_id, text, enter_flag = match.group(1), match.group(2), match.group(3)
        if enter_flag == "1":
            text += "\n"
        # Click first so the typed text goes to the target element.
        await click_element(page, get_backend_node_id(element_id, accessibility_tree))
        await type_text(page, text)
    elif func == "press":
        match = re.search(r"press ?\[(.+)\]", step)
        if not match:
            raise ValueError(f"Invalid press action {step}")
        key = match.group(1)
        await key_press(page, key)
    elif func == "scroll":
        # Direction is "up" or "down"; the brackets are optional.
        match = re.search(r"scroll ?\[?(up|down)\]?", step)
        if not match:
            raise ValueError(f"Invalid scroll action {step}")
        direction = match.group(1)
        await scroll_page(page, direction)
    elif func == "goto":
        match = re.search(r"goto ?\[(.+)\]", step)
        if not match:
            raise ValueError(f"Invalid goto action {step}")
        url = match.group(1)
        await page.goto(url)
    elif func == "new_tab":
        page = await browser_ctx.new_page()
    elif func == "go_back":
        await page.go_back()
    elif func == "go_forward":
        await page.go_forward()
    elif func == "tab_focus":
        match = re.search(r"tab_focus ?\[(\d+)\]", step)
        if not match:
            raise ValueError(f"Invalid tab_focus action {step}")
        page_number = int(match.group(1))
        page = browser_ctx.pages[page_number]
        await page.bring_to_front()
    elif func == "close_tab":
        await page.close()
        # Fall back to the last remaining tab, or open a fresh one if none is left.
        if len(browser_ctx.pages) > 0:
            page = browser_ctx.pages[-1]
        else:
            page = await browser_ctx.new_page()
    elif func == "stop":
        match = re.search(r'stop\(?"(.+)?"\)', step)
        answer = match.group(1) if match else ""
        return answer
    else:
        raise ValueError(f"Invalid action {step}")

    await page.wait_for_load_state("domcontentloaded")
    return page
|
||||
|
||||
|
||||
async def type_text(page: Page, text: str):
    """Type ``text`` through the page's keyboard."""
    keyboard = page.keyboard
    await keyboard.type(text)
|
||||
|
||||
|
||||
async def click_element(page: Page, backend_node_id: int):
    """Scroll the element towards the viewport centre, then click its centre.

    Args:
        page: Page that owns the element.
        backend_node_id: Backend node id passed to ``get_bounding_rect``.
    """
    session = await get_page_cdp_session(page)

    async def locate_center():
        # Bounding rect -> centre point of the element.
        rect_response = await get_bounding_rect(session, backend_node_id)
        return await get_element_center(rect_response["result"]["value"])

    # Move to the location of the element.
    cx, cy = await locate_center()
    await page.evaluate(f"window.scrollTo({cx}- window.innerWidth/2,{cy} - window.innerHeight/2);")

    # The scroll changes the element's position, so measure again before clicking.
    cx, cy = await locate_center()
    await page.mouse.click(cx, cy)
|
||||
|
||||
|
||||
async def hover_element(page: Page, backend_node_id: int) -> None:
    """Move the mouse pointer to the centre of the element's bounding rect.

    Args:
        page: Page that owns the element.
        backend_node_id: Backend node id passed to ``get_bounding_rect``.
    """
    session = await get_page_cdp_session(page)
    rect_response = await get_bounding_rect(session, backend_node_id)
    center_x, center_y = await get_element_center(rect_response["result"]["value"])
    await page.mouse.move(center_x, center_y)
|
||||
|
||||
|
||||
async def scroll_page(page: Page, direction: str) -> None:
    """Scroll the page by one window height in the given direction.

    Args:
        page: Page to scroll.
        direction: "up" or "down"; any other value does nothing.
    """
    # Scripts adapted from natbot.
    scripts = {
        "up": "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;",
        "down": "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop + window.innerHeight;",
    }
    script = scripts.get(direction)
    if script is not None:
        await page.evaluate(script)
|
||||
|
||||
|
||||
async def key_press(page: Page, key: str) -> None:
    """Press a key, replacing "Meta" with "Control" when the platform is not a Mac.

    Args:
        page: Page whose keyboard receives the key press.
        key: Key or key combination, e.g. "Enter" or "Meta+a".
    """
    if "Meta" in key:
        # Only query the browser when the substitution could apply.
        platform_name = await page.evaluate("navigator.platform")
        if "Mac" not in platform_name:
            key = key.replace("Meta", "Control")
    await page.keyboard.press(key)
|
||||
|
||||
|
||||
async def get_element_outer_html(page: Page, backend_node_id: int):
    """Return the outer HTML of a node, looked up by backend node id over CDP.

    Args:
        page: Page whose CDP session is used.
        backend_node_id: Backend node id; converted with ``int()`` before sending.

    Returns:
        The "outerHTML" field of the "DOM.getOuterHTML" response.

    Raises:
        ValueError: "Element not found" if the id conversion, the CDP call or
            the response lookup fails; the original error is chained.
    """
    session = await get_page_cdp_session(page)
    try:
        params = {"backendNodeId": int(backend_node_id)}
        response = await session.send("DOM.getOuterHTML", params)
        return response["outerHTML"]
    except Exception as err:
        raise ValueError("Element not found") from err
|
||||
|
||||
|
||||
async def get_element_center(node_info):
    """Return the (x, y) centre of a rectangle.

    Args:
        node_info: Mapping with "x", "y", "width" and "height" entries.

    Returns:
        tuple: ``(x + width / 2, y + height / 2)``.
    """
    left, top = node_info["x"], node_info["y"]
    half_width, half_height = node_info["width"] / 2, node_info["height"] / 2
    return left + half_width, top + half_height
|
||||
|
||||
|
||||
def extract_step(response: str, action_splitter: str = "```") -> str:
    """Return the first piece of text enclosed between two action splitters.

    Args:
        response: Text to search.
        action_splitter: Literal delimiter that opens and closes the action.
            It is escaped before use, so regex metacharacters in it (for
            example "$$" or "|||") are matched verbatim.

    Returns:
        The enclosed text with surrounding whitespace stripped.

    Raises:
        ValueError: If no delimited section is found.
    """
    # Find the first occurrence of the action; DOTALL lets it span several lines.
    delimiter = re.escape(action_splitter)
    match = re.search(rf"{delimiter}(.*?){delimiter}", response, re.DOTALL)
    if match:
        return match.group(1).strip()
    raise ValueError(f'Cannot find the answer phrase "{response}"')
|
||||
|
||||
|
||||
async def get_bounding_rect(cdp_session, backend_node_id: str):
    """Ask the browser for a node's bounding client rect via CDP.

    The node is resolved with "DOM.resolveNode" and a JavaScript function is
    then run on it with "Runtime.callFunctionOn". Nodes whose ``nodeType`` is 3
    are measured through a ``Range``; other nodes use ``getBoundingClientRect``
    directly.

    Args:
        cdp_session: Object exposing an awaitable ``send(method, params)``.
        backend_node_id: Backend node id; converted with ``int()`` before sending.

    Returns:
        The raw "Runtime.callFunctionOn" response.

    Raises:
        ValueError: "Element not found" if any step fails; the original error
            is chained.
    """
    measure_rect_js = """
                    function() {
                        if (this.nodeType == 3) {
                            var range = document.createRange();
                            range.selectNode(this);
                            var rect = range.getBoundingClientRect().toJSON();
                            range.detach();
                            return rect;
                        } else {
                            return this.getBoundingClientRect().toJSON();
                        }
                    }
                """
    try:
        resolved = await cdp_session.send("DOM.resolveNode", {"backendNodeId": int(backend_node_id)})
        call_params = {
            "objectId": resolved["object"]["objectId"],
            "functionDeclaration": measure_rect_js,
            "returnByValue": True,
        }
        return await cdp_session.send("Runtime.callFunctionOn", call_params)
    except Exception as err:
        raise ValueError("Element not found") from err
|
||||
|
||||
|
||||
# Property names that are left out of the rendered text of a node.
IGNORED_ACTREE_PROPERTIES = (
    "focusable",
    "editable",
    "readonly",
    "level",
    "settable",
    "multiline",
    "invalid",
)


def parse_accessibility_tree(accessibility_tree):
    """Render an accessibility tree as indented text.

    Each kept node becomes one line ``[nodeId] role 'name' prop: value ...``,
    indented with one tab per kept ancestor. Nodes without a readable role or
    name, and unnamed structural nodes (see ``skipped_when_empty``), are left
    out; their children are rendered at the skipped node's own depth.

    Args:
        accessibility_tree: List of node dicts; the first entry is the root.

    Returns:
        tuple: ``(tree_str, obs_nodes_info)`` where ``obs_nodes_info`` maps each
        kept node id to its "backend_id", "union_bound" and "text". The first
        two are ``None`` when the node does not carry them. An empty input
        yields ``("", {})``.
    """
    if not accessibility_tree:
        return "", {}

    node_id_to_idx = {node["nodeId"]: idx for idx, node in enumerate(accessibility_tree)}
    obs_nodes_info = {}

    # Roles dropped when the node has neither a name nor any reported property.
    skipped_when_empty = (
        "generic",
        "img",
        "list",
        "strong",
        "paragraph",
        "banner",
        "navigation",
        "Section",
        "LabelText",
        "Legend",
        "listitem",
    )

    def dfs(idx: int, obs_node_id: str, depth: int) -> str:
        tree_str = ""
        node = accessibility_tree[idx]
        indent = "\t" * depth
        valid_node = True
        try:
            role = node["role"]["value"]
            name = node["name"]["value"]
            node_str = f"[{obs_node_id}] {role} {repr(name)}"

            properties = []
            for prop in node.get("properties", []):
                try:
                    if prop["name"] in IGNORED_ACTREE_PROPERTIES:
                        continue
                    properties.append(f'{prop["name"]}: {prop["value"]["value"]}')
                except KeyError:
                    # Property without a name or value payload: nothing to show.
                    continue

            if properties:
                node_str += " " + " ".join(properties)

            if not node_str.strip():
                valid_node = False

            # Unnamed nodes: drop bare structural ones, and list items even
            # when they still carry properties.
            if not name.strip():
                if not properties:
                    if role in skipped_when_empty:
                        valid_node = False
                elif role == "listitem":
                    valid_node = False

            if valid_node:
                tree_str += f"{indent}{node_str}"
                # .get() keeps nodes that lack these optional fields. A plain
                # lookup raised after the line was already emitted, which left
                # this mapping empty and flattened the indentation.
                obs_nodes_info[obs_node_id] = {
                    "backend_id": node.get("backendDOMNodeId"),
                    "union_bound": node.get("union_bound"),
                    "text": node_str,
                }
        except Exception:
            # Missing or malformed role/name: skip the node, keep its children.
            valid_node = False

        for child_node_id in node.get("childIds", []):
            if child_node_id not in node_id_to_idx:
                continue
            # Skipped nodes do not add an indentation level, to save tokens.
            child_depth = depth + 1 if valid_node else depth
            child_str = dfs(node_id_to_idx[child_node_id], child_node_id, child_depth)
            if child_str.strip():
                if tree_str.strip():
                    tree_str += "\n"
                tree_str += child_str

        return tree_str

    tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
    return tree_str, obs_nodes_info
|
||||
|
||||
|
||||
async def get_page_cdp_session(page):
    """Return the CDP session cached on ``page``, creating it on first use.

    The session is created with ``page.context.new_cdp_session(page)`` and
    stored as ``page.cdp_session`` so later calls reuse it.
    """
    try:
        return page.cdp_session
    except AttributeError:
        # No session cached yet; create one below.
        pass

    session = await page.context.new_cdp_session(page)
    page.cdp_session = session
    return session
|
||||
|
||||
|
||||
def get_backend_node_id(element_id, accessibility_tree):
    """Look up the backend DOM node id of an accessibility-tree node.

    Args:
        element_id: Node id to find; compared as a string against "nodeId".
        accessibility_tree: Iterable of node dicts.

    Returns:
        The first matching node's "backendDOMNodeId", or ``None`` if that node
        has no such key.

    Raises:
        ValueError: If no node has the requested id.
    """
    wanted = str(element_id)
    found = next((node for node in accessibility_tree if node["nodeId"] == wanted), None)
    if found is None:
        raise ValueError(f"Element {wanted} not found")
    return found.get("backendDOMNodeId")
|
||||
|
|
@ -15,6 +15,8 @@ import ast
|
|||
import base64
|
||||
import contextlib
|
||||
import csv
|
||||
import functools
|
||||
import hashlib
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
|
|
@ -23,13 +25,19 @@ import os
|
|||
import platform
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from asyncio import iscoroutinefunction
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Literal, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
import chardet
|
||||
import loguru
|
||||
import requests
|
||||
|
|
@ -37,9 +45,10 @@ from PIL import Image
|
|||
from pydantic_core import to_jsonable_python
|
||||
from tenacity import RetryCallState, RetryError, _utils
|
||||
|
||||
from metagpt.const import MESSAGE_ROUTE_TO_ALL
|
||||
from metagpt.const import MARKDOWN_TITLE_PREFIX, MESSAGE_ROUTE_TO_ALL
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.json_to_markdown import json_to_markdown
|
||||
|
||||
|
||||
def check_cmd_exists(command) -> int:
|
||||
|
|
@ -65,7 +74,7 @@ class OutputParser:
|
|||
@classmethod
|
||||
def parse_blocks(cls, text: str):
|
||||
# 首先根据"##"将文本分割成不同的block
|
||||
blocks = text.split("##")
|
||||
blocks = text.split(MARKDOWN_TITLE_PREFIX)
|
||||
|
||||
# 创建一个字典,用于存储每个block的标题和内容
|
||||
block_dict = {}
|
||||
|
|
@ -271,10 +280,10 @@ class CodeParser:
|
|||
return block_dict
|
||||
|
||||
@classmethod
|
||||
def parse_code(cls, block: str, text: str, lang: str = "") -> str:
|
||||
def parse_code(cls, text: str, lang: str = "", block: Optional[str] = None) -> str:
|
||||
if block:
|
||||
text = cls.parse_block(block, text)
|
||||
pattern = rf"```{lang}.*?\s+(.*?)```"
|
||||
pattern = rf"```{lang}.*?\s+(.*?)\n```"
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
if match:
|
||||
code = match.group(1)
|
||||
|
|
@ -287,7 +296,7 @@ class CodeParser:
|
|||
|
||||
@classmethod
|
||||
def parse_str(cls, block: str, text: str, lang: str = ""):
|
||||
code = cls.parse_code(block, text, lang)
|
||||
code = cls.parse_code(block=block, text=text, lang=lang)
|
||||
code = code.split("=")[-1]
|
||||
code = code.strip().strip("'").strip('"')
|
||||
return code
|
||||
|
|
@ -295,7 +304,7 @@ class CodeParser:
|
|||
@classmethod
|
||||
def parse_file_list(cls, block: str, text: str, lang: str = "") -> list[str]:
|
||||
# Regular expression pattern to find the tasks list.
|
||||
code = cls.parse_code(block, text, lang)
|
||||
code = cls.parse_code(block=block, text=text, lang=lang)
|
||||
# print(code)
|
||||
pattern = r"\s*(.*=.*)?(\[.*\])"
|
||||
|
||||
|
|
@ -560,7 +569,7 @@ def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> Callable
|
|||
return log_it
|
||||
|
||||
|
||||
def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
|
||||
def read_json_file(json_file: str, encoding: str = "utf-8") -> list[Any]:
|
||||
if not Path(json_file).exists():
|
||||
raise FileNotFoundError(f"json_file: {json_file} not exist, return []")
|
||||
|
||||
|
|
@ -572,13 +581,32 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
|
|||
return data
|
||||
|
||||
|
||||
def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4):
|
||||
def handle_unknown_serialization(x: Any) -> str:
    """Raise a descriptive ``TypeError`` for a value that cannot be serialized.

    Intended as a debugging fallback for `to_jsonable_python`: it never
    returns, it only reports what kind of object ``x`` is.

    Args:
        x: The object that could not be serialized.

    Raises:
        TypeError: Always; the message names the method, function, class,
            module, or the instance's class.
    """
    if inspect.ismethod(x):
        tip = f"Cannot serialize method '{x.__func__.__name__}' of class '{x.__self__.__class__.__name__}'"
    elif inspect.isfunction(x):
        tip = f"Cannot serialize function '{x.__name__}'"
    elif inspect.isclass(x) or inspect.ismodule(x):
        # Must be tested before the generic case: every object has __class__,
        # so a hasattr(x, "__class__") test placed first made this unreachable.
        tip = f"Cannot serialize class or module '{x.__name__}'"
    else:
        tip = f"Cannot serialize instance of '{x.__class__.__name__}'"

    raise TypeError(tip)
|
||||
|
||||
|
||||
def write_json_file(json_file: str, data: Any, encoding: str = "utf-8", indent: int = 4, use_fallback: bool = False):
    """Write ``data`` to ``json_file`` as JSON, creating parent folders as needed.

    Values the ``json`` module cannot encode are converted through
    `to_jsonable_python`.

    Args:
        json_file: Destination path.
        data: Object to serialize.
        encoding: Text encoding of the output file.
        indent: Indentation width passed to ``json.dump``.
        use_fallback: When True, `handle_unknown_serialization` is passed to
            `to_jsonable_python` as its ``fallback``; otherwise ``None`` is.
    """
    Path(json_file).parent.mkdir(parents=True, exist_ok=True)

    custom_default = partial(to_jsonable_python, fallback=handle_unknown_serialization if use_fallback else None)

    with open(json_file, "w", encoding=encoding) as fout:
        # Dump exactly once; a second dump would append a duplicate document
        # and make the file invalid JSON.
        json.dump(data, fout, ensure_ascii=False, indent=indent, default=custom_default)
|
||||
|
||||
|
||||
def read_jsonl_file(jsonl_file: str, encoding="utf-8") -> list[dict]:
|
||||
|
|
@ -670,7 +698,7 @@ def role_raise_decorator(func):
|
|||
raise Exception(format_trackback_info(limit=None))
|
||||
except Exception as e:
|
||||
if self.latest_observed_msg:
|
||||
logger.warning(
|
||||
logger.exception(
|
||||
"There is a exception in role's execution, in order to resume, "
|
||||
"we delete the newest role communication message in the role's memory."
|
||||
)
|
||||
|
|
@ -683,7 +711,7 @@ def role_raise_decorator(func):
|
|||
if re.match(r"^openai\.", name) or re.match(r"^httpx\.", name):
|
||||
raise last_error
|
||||
|
||||
raise Exception(format_trackback_info(limit=None))
|
||||
raise Exception(format_trackback_info(limit=None)) from e
|
||||
|
||||
return wrapper
|
||||
|
||||
|
|
@ -691,6 +719,8 @@ def role_raise_decorator(func):
|
|||
@handle_exception
|
||||
async def aread(filename: str | Path, encoding="utf-8") -> str:
|
||||
"""Read file asynchronously."""
|
||||
if not filename or not Path(filename).exists():
|
||||
return ""
|
||||
try:
|
||||
async with aiofiles.open(str(filename), mode="r", encoding=encoding) as reader:
|
||||
content = await reader.read()
|
||||
|
|
@ -810,13 +840,15 @@ def load_mc_skills_code(skill_names: list[str] = None, skills_dir: Path = None)
|
|||
return skills
|
||||
|
||||
|
||||
def encode_image(image_path_or_pil: Union[Path, Image], encoding: str = "utf-8") -> str:
|
||||
def encode_image(image_path_or_pil: Union[Path, Image, str], encoding: str = "utf-8") -> str:
|
||||
"""encode image from file or PIL.Image into base64"""
|
||||
if isinstance(image_path_or_pil, Image.Image):
|
||||
buffer = BytesIO()
|
||||
image_path_or_pil.save(buffer, format="JPEG")
|
||||
bytes_data = buffer.getvalue()
|
||||
else:
|
||||
if isinstance(image_path_or_pil, str):
|
||||
image_path_or_pil = Path(image_path_or_pil)
|
||||
if not image_path_or_pil.exists():
|
||||
raise FileNotFoundError(f"{image_path_or_pil} not exists")
|
||||
with open(str(image_path_or_pil), "rb") as image_file:
|
||||
|
|
@ -838,6 +870,21 @@ def decode_image(img_url_or_b64: str) -> Image:
|
|||
return img
|
||||
|
||||
|
||||
def extract_image_paths(content: str) -> list[str]:
    """Return every whitespace-free token in ``content`` that looks like an image path.

    A token qualifies when it contains one of the extensions png, jpg, jpeg,
    gif, bmp or tiff (all-lowercase or all-uppercase) after a dot.

    Args:
        content: Free text that may mention image paths.

    Returns:
        list[str]: Matched paths in order of appearance; empty if none.
    """
    # A path cannot contain whitespace, so in running text it must be set off
    # by spaces, like "xxx /an/absolute/path.jpg xxx".
    pattern = r"[^\s]+\.(?:png|jpe?g|gif|bmp|tiff|PNG|JPE?G|GIF|BMP|TIFF)"
    image_paths = re.findall(pattern, content)
    return image_paths
|
||||
|
||||
|
||||
def extract_and_encode_images(content: str) -> list[str]:
    """Encode every image path mentioned in ``content`` that exists on disk.

    Args:
        content: Text scanned with ``extract_image_paths``.

    Returns:
        list[str]: ``encode_image`` results for the existing paths, in order
        of appearance.
    """
    return [encode_image(candidate) for candidate in extract_image_paths(content) if os.path.exists(candidate)]
|
||||
|
||||
|
||||
def log_and_reraise(retry_state: RetryCallState):
|
||||
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
|
||||
logger.warning(
|
||||
|
|
@ -849,24 +896,335 @@ See FAQ 5.8
|
|||
raise retry_state.outcome.exception()
|
||||
|
||||
|
||||
def get_markdown_codeblock_type(filename: str) -> str:
|
||||
async def get_mime_type(filename: str | Path, force_read: bool = False) -> str:
    """Determine a file's MIME type from its name and, if needed, its content.

    The type is first guessed from the file name (with ".yml"/".yaml" mapped
    to "text/yaml" when the guess fails). Unless ``force_read`` is set, a
    successful guess is returned immediately. Otherwise the ``file
    --mime-type`` command is run on the file.

    Args:
        filename: Path of the file, as ``str`` or ``Path``.
        force_read: Run the ``file`` command even if the name-based guess succeeded.

    Returns:
        str: The MIME type. If the command writes to stderr, the name-based
        guess is returned (it may be ``None``); if running the command raises,
        "unknown" is returned.
    """
    # Accept plain strings as the signature promises; .name/.suffix need a Path.
    filename = Path(filename)
    guess_mime_type, _ = mimetypes.guess_type(filename.name)
    if not guess_mime_type:
        ext_mappings = {".yml": "text/yaml", ".yaml": "text/yaml"}
        guess_mime_type = ext_mappings.get(filename.suffix)
    if not force_read and guess_mime_type:
        return guess_mime_type

    import shlex

    from metagpt.tools.libs.shell import shell_execute  # avoid circular import

    # Types `file` tends to report as text/plain, for which the name-based
    # guess is more specific.
    text_set = {
        "application/json",
        "application/vnd.chipnuts.karaoke-mmd",
        "application/javascript",
        "application/xml",
        "application/x-sh",
        "application/sql",
        "text/yaml",
    }

    try:
        # shlex.quote keeps names with quotes or spaces from breaking the command.
        stdout, stderr, _ = await shell_execute(f"file --mime-type {shlex.quote(str(filename))}")
        if stderr:
            logger.debug(f"file:{filename}, error:{stderr}")
            return guess_mime_type
        # The MIME type is the last space-separated token of the output.
        ix = stdout.rfind(" ")
        mime_type = stdout[ix:].strip()
        if mime_type == "text/plain" and guess_mime_type in text_set:
            return guess_mime_type
        return mime_type
    except Exception as e:
        logger.debug(f"file:{filename}, error:{e}")
        return "unknown"
|
||||
|
||||
|
||||
def get_markdown_codeblock_type(filename: str = None, mime_type: str = None) -> str:
    """Return the markdown code-block type for a file name or MIME type.

    Args:
        filename: File name used to guess the MIME type when ``mime_type`` is
            not given.
        mime_type: MIME type to map directly; takes precedence over ``filename``.

    Returns:
        str: The code-block language, or "text" for an unmapped MIME type.

    Raises:
        ValueError: If neither ``filename`` nor ``mime_type`` is provided.
    """
    if not filename and not mime_type:
        raise ValueError("Either filename or mime_type must be valid.")

    # Guess only when the caller did not supply a MIME type, so an explicit
    # mime_type is never overwritten.
    if not mime_type:
        mime_type, _ = mimetypes.guess_type(filename)
    mappings = {
        "text/x-shellscript": "bash",
        "text/x-c++src": "cpp",
        "text/css": "css",
        "text/html": "html",
        "text/x-java": "java",
        "text/x-python": "python",
        "text/x-ruby": "ruby",
        "text/x-c": "cpp",
        "text/yaml": "yaml",
        "application/javascript": "javascript",
        "application/json": "json",
        "application/sql": "sql",
        "application/vnd.chipnuts.karaoke-mmd": "mermaid",
        "application/x-sh": "bash",
        "application/xml": "xml",
    }
    return mappings.get(mime_type, "text")
|
||||
|
||||
|
||||
def get_project_srcs_path(workdir: str | Path) -> Path:
    """Return the source folder inside a project working directory.

    The folder name is the content of ``<workdir>/.src_workspace`` when that
    file exists; otherwise it is the working directory's own name.

    Args:
        workdir: Project working directory, as ``str`` or ``Path``.

    Returns:
        Path: ``workdir / <source folder name>``.
    """
    # Normalize once: the "/" operator is undefined between two plain strings.
    workdir = Path(workdir)
    src_workdir_path = workdir / ".src_workspace"
    if src_workdir_path.exists():
        with open(src_workdir_path, "r") as file:
            src_name = file.read()
    else:
        src_name = workdir.name
    return workdir / src_name
|
||||
|
||||
|
||||
async def init_python_folder(workdir: str | Path):
    """Create an empty ``__init__.py`` in ``workdir`` if it has none.

    Nothing happens when ``workdir`` is falsy, does not exist, or already
    contains ``__init__.py``.

    Args:
        workdir: Folder that should contain ``__init__.py``.
    """
    if not workdir:
        return
    folder = Path(workdir)
    marker = folder / "__init__.py"
    if not folder.exists() or marker.exists():
        return
    # Opening in append mode creates the file; utime then sets its timestamps.
    async with aiofiles.open(marker, "a"):
        os.utime(marker, None)
|
||||
|
||||
|
||||
def get_markdown_code_block_type(filename: str) -> str:
    """Return the markdown code-block language for a file, based on its suffix.

    Args:
        filename: File name or path; only its final suffix is examined
            (case-sensitive).

    Returns:
        str: The language identifier, or "" for an empty name or an unknown suffix.
    """
    if not filename:
        return ""

    # Language -> file suffixes rendered with that language.
    suffixes_by_language = {
        "python": (".py",),
        "javascript": (".js",),
        "java": (".java",),
        "cpp": (".cpp",),
        "c": (".c",),
        "html": (".html",),
        "css": (".css",),
        "xml": (".xml", ".svg"),  # SVG can often be treated as XML
        "json": (".json",),
        "yaml": (".yaml", ".yml"),
        "markdown": (".md",),
        "sql": (".sql",),
        "ruby": (".rb",),
        "php": (".php",),
        "bash": (".sh",),
        "swift": (".swift",),
        "go": (".go",),
        "rust": (".rs",),
        "perl": (".pl",),
        "assembly": (".asm",),
        "r": (".r",),
        "scss": (".scss",),
        "sass": (".sass",),
        "lua": (".lua",),
        "typescript": (".ts",),
        "tsx": (".tsx",),
        "jsx": (".jsx",),
        "ini": (".ini",),
        "toml": (".toml",),
        # Add more languages and their suffixes as needed
    }

    suffix = Path(filename).suffix
    for language, suffixes in suffixes_by_language.items():
        if suffix in suffixes:
            return language
    return ""
|
||||
|
||||
|
||||
def to_markdown_code_block(val: str, type_: str = "") -> str:
    """Wrap a string in a fenced Markdown code block.

    Triple backticks inside ``val`` are backslash-escaped so they cannot close
    the fence early.

    Args:
        val (str): Text to wrap.
        type_ (str, optional): Language identifier placed after the opening
            fence. Defaults to an empty string.

    Returns:
        str: A newline, the opening fence with ``type_``, the escaped text, the
        closing fence and a trailing newline. A falsy ``val`` yields "".
    """
    if not val:
        return ""
    fence = "```"
    escaped = val.replace(fence, "\\`\\`\\`")
    return "\n" + fence + f"{type_}\n{escaped}\n" + fence + "\n"
|
||||
|
||||
|
||||
async def save_json_to_markdown(content: str, output_filename: str | Path):
    """Convert a JSON string to Markdown and write it to a file.

    Args:
        content (str): JSON text to convert.
        output_filename (str or Path): Destination of the Markdown output.

    Returns:
        None

    Notes:
        - Errors while parsing ``content`` (``json.JSONDecodeError`` or any
          other exception) are logged as warnings and nothing is written.
        - The conversion and the write happen outside that guard.

    Examples:
        >>> await save_json_to_markdown('{"key": "value"}', Path("/path/to/output.md"))
    """
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError as decode_error:
        logger.warning(f"Failed to decode JSON content: {decode_error}")
        return
    except Exception as unexpected_error:
        logger.warning(f"An unexpected error occurred: {unexpected_error}")
        return

    markdown = json_to_markdown(parsed)
    await awrite(filename=output_filename, data=markdown)
|
||||
|
||||
|
||||
def tool2name(cls, methods: List[str], entry) -> Dict[str, Any]:
    """Map "ClassName.method" keys to ``entry`` for each given method name.

    Args:
        cls: Class whose ``__name__`` prefixes every key.
        methods (List[str]): Method names.
        entry (Any): Value stored under every key.

    Returns:
        Dict[str, Any]: One "ClassName.method" key per method name. When this
        gives fewer than two keys, the bare class name is added as a key too.

    Example:
        >>> class MyClass:
        ...     pass
        >>> tool2name(MyClass, ['method1', 'method2'], 'some_entry')
        {'MyClass.method1': 'some_entry', 'MyClass.method2': 'some_entry'}
        >>> tool2name(MyClass, ['method1'], 'some_entry')
        {'MyClass.method1': 'some_entry', 'MyClass': 'some_entry'}
    """
    prefix = cls.__name__
    qualified_names = (f"{prefix}.{method_name}" for method_name in methods)
    result = dict.fromkeys(qualified_names, entry)
    if len(result) < 2:
        result[prefix] = entry
    return result
|
||||
|
||||
|
||||
def new_transaction_id(postfix_len=8) -> str:
    """Build a transaction id from the current local time and a random UUID.

    Args:
        postfix_len (int): Number of hex characters of the UUID to append.
            Default is 8.

    Returns:
        str: ``YYYYmmddHHMMSS`` followed by the letter "T" and the first
        ``postfix_len`` hex characters of a fresh ``uuid4``.
    """
    random_hex = uuid.uuid4().hex
    return f"{datetime.now():%Y%m%d%H%M%S}T{random_hex[0:postfix_len]}"
|
||||
|
||||
|
||||
def log_time(method):
    """Decorator that logs when ``method`` starts and how long it took.

    Works for both plain functions and coroutine functions. The start time
    and, after a successful call, the wall-clock and CPU durations are logged
    at info level.

    Args:
        method: Function or coroutine function to wrap.

    Returns:
        A wrapper with the same call signature that returns ``method``'s result.
    """

    def before_call():
        start_time, cpu_start_time = time.perf_counter(), time.process_time()
        # %M is minutes; %m would repeat the month in the time part.
        logger.info(f"[{method.__name__}] started at: " f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        return start_time, cpu_start_time

    def after_call(start_time, cpu_start_time):
        end_time, cpu_end_time = time.perf_counter(), time.process_time()
        logger.info(
            f"[{method.__name__}] ended. "
            f"Time elapsed: {end_time - start_time:.4} sec, CPU elapsed: {cpu_end_time - cpu_start_time:.4} sec"
        )

    @functools.wraps(method)
    def timeit_wrapper(*args, **kwargs):
        start_time, cpu_start_time = before_call()
        result = method(*args, **kwargs)
        after_call(start_time, cpu_start_time)
        return result

    @functools.wraps(method)
    async def timeit_wrapper_async(*args, **kwargs):
        start_time, cpu_start_time = before_call()
        result = await method(*args, **kwargs)
        after_call(start_time, cpu_start_time)
        return result

    # Pick the wrapper that matches how the wrapped callable must be invoked.
    return timeit_wrapper_async if iscoroutinefunction(method) else timeit_wrapper
|
||||
|
||||
|
||||
async def check_http_endpoint(url: str, timeout: int = 3) -> bool:
    """Report whether a GET request to ``url`` answers with status 200.

    Args:
        url (str): Endpoint to probe.
        timeout (int, optional): Timeout in seconds passed to the request.
            Defaults to 3.

    Returns:
        bool: True when the response status is 200. False for any other
        status, or when the request raises (the error is printed).
    """
    async with aiohttp.ClientSession() as http:
        try:
            async with http.get(url, timeout=timeout) as reply:
                is_ok = reply.status == 200
                return is_ok
        except Exception as err:
            print(f"Error accessing the endpoint {url}: {err}")
            return False
|
||||
|
||||
|
||||
def rectify_pathname(path: Union[str, Path], default_filename: str) -> Path:
    """Turn ``path`` into an output file path whose parent folder exists.

    If ``path`` is an existing directory, ``default_filename`` is appended to
    it. Otherwise ``path`` is treated as a file path: its parent directory is
    created if missing and ``path`` is returned as given.

    Args:
        path (Union[str, Path]): Directory or file path.
        default_filename (str): File name used when ``path`` is a directory.

    Returns:
        Path: The resulting output path.
    """
    target = Path(path)
    if not target.is_dir():
        target.parent.mkdir(parents=True, exist_ok=True)
        return target
    target.mkdir(parents=True, exist_ok=True)
    return target / default_filename
|
||||
|
||||
|
||||
def generate_fingerprint(text: str) -> str:
    """Return the SHA-256 hex digest of ``text`` encoded as UTF-8.

    Args:
        text (str): Text to fingerprint.

    Returns:
        str: 64-character lowercase hexadecimal digest.
    """
    return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def download_model(file_url: str, target_folder: Path) -> Path:
|
||||
file_name = file_url.split("/")[-1]
|
||||
file_path = target_folder.joinpath(f"{file_name}")
|
||||
|
|
|
|||
|
|
@ -6,12 +6,19 @@
|
|||
@File : file.py
|
||||
@Describe : General file operations.
|
||||
"""
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import aiofiles
|
||||
from fsspec.implementations.memory import MemoryFileSystem as _MemoryFileSystem
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils import read_docx
|
||||
from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.repo_to_markdown import is_text_file
|
||||
|
||||
|
||||
class File:
|
||||
|
|
@ -68,3 +75,127 @@ class File:
|
|||
content = b"".join(chunks)
|
||||
logger.debug(f"Successfully read file, the path of file: {file_path}")
|
||||
return content
|
||||
|
||||
@staticmethod
async def is_textual_file(filename: Union[str, Path]) -> bool:
    """Tell whether a file can be treated as text-bearing.

    The file qualifies when `is_text_file` reports it as text, or when its
    MIME type is PDF or one of the Microsoft Word document/template types.

    Args:
        filename (Union[str, Path]): The path of the file to inspect.

    Returns:
        bool: True for a textual file, False otherwise.
    """
    word_mime_types = {
        "application/msword",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/vnd.ms-word.document.macroEnabled.12",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.template",
        "application/vnd.ms-word.template.macroEnabled.12",
    }
    is_text, mime_type = await is_text_file(filename)
    # Same checks, same order as before, folded into a single condition.
    if is_text or mime_type == "application/pdf" or mime_type in word_mime_types:
        return True
    return False
|
||||
|
||||
@staticmethod
async def read_text_file(filename: Union[str, Path]) -> Optional[str]:
    """Read the whole content of a file. Using absolute paths as the argument for specifying the file location.

    The reader is chosen from the result of `is_text_file`: plain text,
    PDF, or Microsoft Word. Any other type is not read.

    Args:
        filename (Union[str, Path]): The path of the file to read.

    Returns:
        Optional[str]: The extracted text, or None for an unsupported type.
    """
    word_mime_types = {
        "application/msword",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/vnd.ms-word.document.macroEnabled.12",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.template",
        "application/vnd.ms-word.template.macroEnabled.12",
    }
    is_text, mime_type = await is_text_file(filename)
    # Pick the reader first, then invoke it once at the end.
    if is_text:
        reader = File._read_text
    elif mime_type == "application/pdf":
        reader = File._read_pdf
    elif mime_type in word_mime_types:
        reader = File._read_docx
    else:
        return None
    return await reader(filename)
|
||||
|
||||
@staticmethod
async def _read_text(path: Union[str, Path]) -> str:
    """Return the content of `path` as produced by `aread`."""
    content = await aread(path)
    return content
|
||||
|
||||
@staticmethod
async def _read_pdf(path: Union[str, Path]) -> str:
    """Extract the text of a PDF file.

    OmniParse is tried first. When it yields nothing (None or empty text),
    the llama_index `PDFReader` is used and the text of every loaded
    document is joined with newlines.

    Args:
        path (Union[str, Path]): The path of the PDF file.

    Returns:
        str: The extracted text.
    """
    parsed = await File._omniparse_read_file(path)
    if not parsed:
        # Imported lazily so llama_index is only loaded on the fallback path.
        from llama_index.readers.file import PDFReader

        documents = PDFReader().load_data(file=Path(path))
        return "\n".join(doc.text for doc in documents)
    return parsed
|
||||
|
||||
@staticmethod
async def _read_docx(path: Union[str, Path]) -> str:
    """Extract the text of a Word document.

    OmniParse is tried first. When it yields nothing, the items returned
    by `read_docx` are joined with newlines.

    Args:
        path (Union[str, Path]): The path of the document.

    Returns:
        str: The extracted text.
    """
    parsed = await File._omniparse_read_file(path)
    if not parsed:
        return "\n".join(read_docx(str(path)))
    return parsed
|
||||
|
||||
@staticmethod
async def _omniparse_read_file(path: Union[str, Path], auto_save_image: bool = False) -> Optional[str]:
    """Parse a document through an OmniParse service, if one is configured.

    The base URL and timeout are looked up through `get_env_default` first
    and fall back to the values of `File._read_omniparse_config()`. Without
    a base URL, or when the endpoint check fails, nothing is parsed.

    Args:
        path (Union[str, Path]): The path of the document to parse.
        auto_save_image (bool): When True and the parse result contains
            images, each image is base64-decoded and written into a sibling
            directory named `<file name with "." replaced by "_">_images`,
            and a Markdown image link per image is appended to the text.

    Returns:
        Optional[str]: The parsed text (plus image links when requested), or
        None when OmniParse is not configured, not reachable, or parsing
        raised an error.
    """
    from metagpt.tools.libs import get_env_default
    from metagpt.utils.omniparse_client import OmniParseClient

    env_base_url = await get_env_default(key="base_url", app_name="OmniParse", default_value="")
    env_timeout = await get_env_default(key="timeout", app_name="OmniParse", default_value="")
    conf_base_url, conf_timeout = await File._read_omniparse_config()

    # Values from get_env_default take precedence over the config file.
    base_url = env_base_url or conf_base_url
    if not base_url:
        return None
    api_key = await get_env_default(key="api_key", app_name="OmniParse", default_value="")
    timeout = env_timeout or conf_timeout or 600
    try:
        timeout = int(timeout)
    except ValueError:
        # A non-numeric timeout setting falls back to the default.
        timeout = 600

    try:
        if not await check_http_endpoint(url=base_url):
            logger.warning(f"{base_url}: NOT AVAILABLE")
            return None
        client = OmniParseClient(api_key=api_key, base_url=base_url, max_timeout=timeout)
        file_data = await aread_bin(filename=path)
        ret = await client.parse_document(file_input=file_data, bytes_filename=str(path))
    except Exception as e:
        # Best effort: any failure makes the caller fall back to local readers.
        logger.exception(f"{path}: {e}")
        return None
    if not ret.images or not auto_save_image:
        return ret.text

    result = [ret.text]
    img_dir = Path(path).parent / (Path(path).name.replace(".", "_") + "_images")
    img_dir.mkdir(parents=True, exist_ok=True)
    for i in ret.images:
        byte_data = base64.b64decode(i.image)
        filename = img_dir / i.image_name
        await awrite_bin(filename=filename, data=byte_data)
        # Reference the saved image from the text as a Markdown image link.
        result.append(f"![{i.image_name}]({str(filename)})")
    return "\n".join(result)
|
||||
|
||||
@staticmethod
async def _read_omniparse_config() -> Tuple[str, int]:
    """Return the OmniParse base URL and timeout from the global config.

    Returns:
        Tuple[str, int]: `(base_url, timeout)` when `config.omniparse` is set
        and has a base URL, otherwise `("", 0)`.
    """
    if not (config.omniparse and config.omniparse.base_url):
        return "", 0
    return config.omniparse.base_url, config.omniparse.timeout
|
||||
|
||||
|
||||
class MemoryFileSystem(_MemoryFileSystem):
    """In-memory file system whose paths are converted with `str()` first."""

    @classmethod
    def _strip_protocol(cls, path):
        """Convert `path` to a string, then defer to the parent implementation."""
        path_text = str(path)
        return super()._strip_protocol(path_text)
|
||||
|
|
|
|||
|
|
@ -198,8 +198,9 @@ class FileRepository:
|
|||
:type dependencies: List[str], optional
|
||||
"""
|
||||
|
||||
await self.save(filename=doc.filename, content=doc.content, dependencies=dependencies)
|
||||
doc = await self.save(filename=doc.filename, content=doc.content, dependencies=dependencies)
|
||||
logger.debug(f"File Saved: {str(doc.filename)}")
|
||||
return doc
|
||||
|
||||
async def save_pdf(self, doc: Document, with_suffix: str = ".md", dependencies: List[str] = None):
|
||||
"""Save a Document instance as a PDF file.
|
||||
|
|
@ -216,8 +217,9 @@ class FileRepository:
|
|||
"""
|
||||
m = json.loads(doc.content)
|
||||
filename = Path(doc.filename).with_suffix(with_suffix) if with_suffix is not None else Path(doc.filename)
|
||||
await self.save(filename=str(filename), content=json_to_markdown(m), dependencies=dependencies)
|
||||
doc = await self.save(filename=str(filename), content=json_to_markdown(m), dependencies=dependencies)
|
||||
logger.debug(f"File Saved: {str(filename)}")
|
||||
return doc
|
||||
|
||||
async def delete(self, filename: Path | str):
|
||||
"""Delete a file from the file repository.
|
||||
|
|
|
|||
|
|
@ -8,16 +8,30 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import shutil
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from subprocess import TimeoutExpired
|
||||
from typing import Dict, List, Optional, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from git.repo import Repo
|
||||
from git.repo.fun import is_git_dir
|
||||
from github import Auth, BadCredentialsException, Github
|
||||
from github.GithubObject import NotSet
|
||||
from github.Issue import Issue
|
||||
from github.Label import Label
|
||||
from github.Milestone import Milestone
|
||||
from github.NamedUser import NamedUser
|
||||
from github.PullRequest import PullRequest
|
||||
from gitignore_parser import parse_gitignore
|
||||
from pydantic import BaseModel
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.libs.shell import shell_execute
|
||||
from metagpt.utils.dependency_file import DependencyFile
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
|
@ -32,6 +46,18 @@ class ChangeType(Enum):
|
|||
UNTRACTED = "U" # File is untracked (not added to version control)
|
||||
|
||||
|
||||
class RateLimitError(Exception):
    """Raised when the remote API reports that no requests remain in the rate limit."""

    def __init__(self, message="Rate limit exceeded"):
        """Store `message` on the instance and pass it to `Exception`."""
        super().__init__(message)
        self.message = message
|
||||
|
||||
|
||||
class GitBranch(BaseModel):
    """Branch information returned by `GitRepository.push`."""

    # Branch that received the push (the new branch, or the base if none was created).
    head: str
    # Branch that was active when the push started.
    base: str
    # Repository name derived from the remote URL.
    repo_name: str
|
||||
|
||||
|
||||
class GitRepository:
|
||||
"""A class representing a Git repository.
|
||||
|
||||
|
|
@ -52,7 +78,7 @@ class GitRepository:
|
|||
self._dependency = None
|
||||
self._gitignore_rules = None
|
||||
if local_path:
|
||||
self.open(local_path=local_path, auto_init=auto_init)
|
||||
self.open(local_path=Path(local_path), auto_init=auto_init)
|
||||
|
||||
def open(self, local_path: Path, auto_init=False):
|
||||
"""Open an existing Git repository or initialize a new one if auto_init is True.
|
||||
|
|
@ -68,7 +94,7 @@ class GitRepository:
|
|||
if not auto_init:
|
||||
return
|
||||
local_path.mkdir(parents=True, exist_ok=True)
|
||||
return self._init(local_path)
|
||||
self._init(local_path)
|
||||
|
||||
def _init(self, local_path: Path):
|
||||
"""Initialize a new Git repository at the specified path.
|
||||
|
|
@ -130,6 +156,8 @@ class GitRepository:
|
|||
:param local_path: The local path to check.
|
||||
:return: True if the directory is a Git repository, False otherwise.
|
||||
"""
|
||||
if not local_path:
|
||||
return False
|
||||
git_dir = Path(local_path) / ".git"
|
||||
if git_dir.exists() and is_git_dir(git_dir):
|
||||
return True
|
||||
|
|
@ -160,15 +188,114 @@ class GitRepository:
|
|||
return None
|
||||
return Path(self._repository.working_dir)
|
||||
|
||||
@property
def current_branch(self) -> str:
    """Name of the branch that is currently active.

    Returns:
        str: The name of the repository's active branch.
    """
    active_branch = self._repository.active_branch
    return active_branch.name
|
||||
|
||||
@property
def remote_url(self) -> str:
    """URL of the `origin` remote.

    Returns:
        str: The URL, or an empty string when looking it up raises
        AttributeError.
    """
    try:
        origin = self._repository.remotes.origin
        return origin.url
    except AttributeError:
        return ""
|
||||
|
||||
@property
def repo_name(self) -> str:
    """Repository name (`user/repo`) derived from the remote URL.

    Only standard HTTPS and SSH style URLs are recognised:
        HTTPS example: https://github.com/username/repo_name.git
        SSH example:   git@github.com:username/repo_name.git

    Returns:
        str: The part of the URL after the host with a trailing `.git`
        removed, or an empty string when there is no remote URL or its
        format is not recognised.
    """
    url = self.remote_url
    if not url:
        return ""
    if url.startswith("https://"):
        # ["https:", "", "<host>", "<user>/<repo>.git"] -> keep the last part.
        path = url.split("/", maxsplit=3)[-1]
    elif url.startswith("git@"):
        path = url.split(":")[-1]
    else:
        return ""
    # Strip only the trailing ".git"; a plain replace would also mangle
    # names such as "my.github.io".
    return path.removesuffix(".git")
|
||||
|
||||
def new_branch(self, branch_name: str) -> str:
    """Create a branch with the given name and check it out.

    Args:
        branch_name (str): The name of the branch to create.

    Returns:
        str: The name of the newly created branch. If `branch_name` is
        empty, nothing is created and the name of the current active branch
        is returned.
    """
    if branch_name:
        head = self._repository.create_head(branch_name)
        head.checkout()
        return head.name
    return self.current_branch
|
||||
|
||||
def archive(self, comments="Archive"):
    """Stage and commit the changed files of the repository.

    The changed file names are logged first. Nothing is staged or committed
    when there are no changed files.

    :param comments: Comments for the archive commit.
    """
    logger.info(f"Archive: {list(self.changed_files.keys())}")
    if self.changed_files:
        self.add_change(self.changed_files)
        self.commit(comments)
|
||||
|
||||
async def push(
    self, new_branch: str, comments="Archive", access_token: Optional[str] = None, auth: Optional[Auth] = None
) -> GitBranch:
    """
    Commits pending changes and pushes to the remote repository.

    Args:
        new_branch (str): The name of a new branch to create and push. If empty, the current branch is used.
        comments (str, optional): Commit message used when archiving pending changes. Defaults to "Archive".
        access_token (str, optional): Access token for authentication. Defaults to None. Visit `https://pygithub.readthedocs.io/en/latest/examples/Authentication.html`, `https://github.com/PyGithub/PyGithub/blob/main/doc/examples/Authentication.rst`.
        auth (Auth, optional): Optional authentication object; its `token` is used when `access_token` is not given. Defaults to None.

    Returns:
        GitBranch: `base` is the branch that was active on entry, `head` is the pushed branch.

    Raises:
        ValueError: If neither `auth` nor `access_token` is provided.
        BadCredentialsException: If the `git push` command times out (15 seconds).

    Note:
        The token is embedded into the push URL and is masked as `<TOKEN>` in everything that is logged.
        NOTE(review): the push URL is built by stripping an `https://` prefix from the remote URL,
        so this presumably expects an HTTPS remote; confirm the behaviour for SSH remotes.
    """
    if not auth and not access_token:
        raise ValueError('`access_token` is invalid. Visit: "https://github.com/settings/tokens"')
    from metagpt.context import Context

    base = self.current_branch
    head = base if not new_branch else self.new_branch(new_branch)
    self.archive(comments)  # will skip committing if no changes
    ctx = Context()
    env = ctx.new_environ()
    proxy = ["-c", f"http.proxy={ctx.config.proxy}"] if ctx.config.proxy else []
    token = access_token or auth.token
    remote_url = f"https://{token}@" + self.remote_url.removeprefix("https://")
    command = ["git"] + proxy + ["push", remote_url]
    logger.info(" ".join(command).replace(token, "<TOKEN>"))
    try:
        stdout, stderr, return_code = await shell_execute(
            command=command, cwd=str(self.workdir), env=env, timeout=15
        )
    except TimeoutExpired as e:
        info = str(e).replace(token, "<TOKEN>")
        raise BadCredentialsException(status=401, message=info) from e
    info = f"{stdout}\n{stderr}\nexit: {return_code}\n"
    info = info.replace(token, "<TOKEN>")
    # Report through the logger, like `clone_from` does, instead of stdout.
    logger.info(info)

    return GitBranch(base=base, head=head, repo_name=self.repo_name)
|
||||
|
||||
def new_file_repository(self, relative_path: Path | str = ".") -> FileRepository:
|
||||
"""Create a new instance of FileRepository associated with this Git repository.
|
||||
|
||||
|
|
@ -248,6 +375,8 @@ class GitRepository:
|
|||
if not directory_path.exists():
|
||||
return []
|
||||
for file_path in directory_path.iterdir():
|
||||
if not file_path.is_relative_to(root_relative_path):
|
||||
continue
|
||||
if file_path.is_file():
|
||||
rpath = file_path.relative_to(root_relative_path)
|
||||
files.append(str(rpath))
|
||||
|
|
@ -283,3 +412,222 @@ class GitRepository:
|
|||
continue
|
||||
files.append(filename)
|
||||
return files
|
||||
|
||||
@classmethod
@retry(wait=wait_random_exponential(min=1, max=15), stop=stop_after_attempt(3))
async def clone_from(cls, url: str | Path, output_dir: str | Path = None) -> "GitRepository":
    """Clone a repository with `git clone` and open the result.

    The call is retried (up to 3 attempts with a random exponential wait)
    when it raises.

    Args:
        url (str | Path): The URL or path of the repository to clone.
        output_dir (str | Path, optional): The directory to clone into.
            Defaults to a new `workspace/downloads/<uuid>` directory resolved
            relative to this file.

    Returns:
        GitRepository: The cloned repository, located in a sub-directory of
        the output directory named after the stem of `url`.

    Raises:
        ValueError: If the target directory is not a git repository after
            the clone command has run; the message holds the command output.
    """
    from metagpt.context import Context

    default_dir = Path(__file__).parent / f"../../workspace/downloads/{uuid.uuid4().hex}"
    download_root = Path(output_dir or default_dir).resolve()
    download_root.mkdir(parents=True, exist_ok=True)
    repo_dir = download_root / Path(url).stem
    # Remove a previous clone of the same repository before cloning again.
    if repo_dir.exists():
        shutil.rmtree(repo_dir, ignore_errors=True)
    ctx = Context()
    env = ctx.new_environ()
    proxy = ["-c", f"http.proxy={ctx.config.proxy}"] if ctx.config.proxy else []
    command = ["git", "clone"] + proxy + [str(url)]
    logger.info(" ".join(command))

    stdout, stderr, return_code = await shell_execute(
        command=command, cwd=str(download_root), env=env, timeout=600
    )
    info = f"{stdout}\n{stderr}\nexit: {return_code}\n"
    logger.info(info)
    to_path = download_root / Path(url).stem
    if not cls.is_git_dir(to_path):
        raise ValueError(info)
    logger.info(f"git clone to {to_path}")
    return GitRepository(local_path=to_path, auto_init=False)
|
||||
|
||||
async def checkout(self, commit_id: str):
    """Check out `commit_id` in the underlying repository and log the action."""
    git_cmd = self._repository.git
    git_cmd.checkout(commit_id)
    logger.info(f"git checkout {commit_id}")
|
||||
|
||||
def log(self) -> str:
    """Return git log"""
    git_cmd = self._repository.git
    return git_cmd.log()
|
||||
|
||||
@staticmethod
async def create_pull(
    base: str,
    head: str,
    base_repo_name: str,
    head_repo_name: Optional[str] = None,
    *,
    title: Optional[str] = None,
    body: Optional[str] = None,
    maintainer_can_modify: Optional[bool] = None,
    draft: Optional[bool] = None,
    issue: Optional[Issue] = None,
    access_token: Optional[str] = None,
    auth: Optional[Auth] = None,
) -> Union[PullRequest, str]:
    """
    Creates a pull request in the specified repository.

    Args:
        base (str): The name of the base branch.
        head (str): The name of the head branch.
        base_repo_name (str): The full repository name (user/repo) where the pull request will be created.
        head_repo_name (Optional[str], optional): The full repository name (user/repo) where the pull request will merge from. Defaults to None.
        title (Optional[str], optional): The title of the pull request. Defaults to None.
        body (Optional[str], optional): The body of the pull request. Defaults to None.
        maintainer_can_modify (Optional[bool], optional): Whether maintainers can modify the pull request. Defaults to None, which leaves the value unset; an explicit False is forwarded.
        draft (Optional[bool], optional): Whether the pull request is a draft. Defaults to None, which leaves the value unset; an explicit False is forwarded.
        issue (Optional[Issue], optional): The issue linked to the pull request. Defaults to None.
        access_token (Optional[str], optional): The access token for authentication. Defaults to None. Visit `https://pygithub.readthedocs.io/en/latest/examples/Authentication.html`, `https://github.com/PyGithub/PyGithub/blob/main/doc/examples/Authentication.rst`.
        auth (Optional[Auth], optional): The authentication method. Defaults to None. Visit `https://pygithub.readthedocs.io/en/latest/examples/Authentication.html`

    Returns:
        Union[PullRequest, str]: The created pull request object, or, when creating it fails for any reason,
        a GitHub compare URL from which the pull request can be opened manually.

    Raises:
        ValueError: If neither `auth` nor `access_token` is provided.
    """
    title = title or NotSet
    body = body or NotSet
    # Only None means "not provided"; `False or NotSet` would drop an explicit False.
    maintainer_can_modify = NotSet if maintainer_can_modify is None else maintainer_can_modify
    draft = NotSet if draft is None else draft
    issue = issue or NotSet
    if not auth and not access_token:
        raise ValueError('`access_token` is invalid. Visit: "https://github.com/settings/tokens"')
    clone_url = f"https://github.com/{base_repo_name}.git"
    try:
        auth = auth or Auth.Token(access_token)
        g = Github(auth=auth)
        base_repo = g.get_repo(base_repo_name)
        clone_url = base_repo.clone_url
        head_repo = g.get_repo(head_repo_name) if head_repo_name and head_repo_name != base_repo_name else None
        # Cross-repository pull requests address the head as "<user>:<branch>".
        # Keep `head` itself untouched so the fallback URL is not prefixed twice.
        pr_head = head
        if head_repo:
            user = head_repo.full_name.split("/")[0]
            pr_head = f"{user}:{head}"
        pr = base_repo.create_pull(
            base=base,
            head=pr_head,
            title=title,
            body=body,
            maintainer_can_modify=maintainer_can_modify,
            draft=draft,
            issue=issue,
        )
    except Exception as e:
        # Best effort: hand back a compare URL instead of failing the caller.
        logger.warning(f"Pull Request Error: {e}")
        return GitRepository.create_github_pull_url(
            clone_url=clone_url,
            base=base,
            head=head,
            head_repo_name=head_repo_name,
        )
    return pr
|
||||
|
||||
@staticmethod
async def create_issue(
    repo_name: str,
    title: str,
    body: Optional[str] = None,
    assignee: NamedUser | Optional[str] = None,
    milestone: Optional[Milestone] = None,
    labels: list[Label] | Optional[list[str]] = None,
    assignees: Optional[list[str]] | list[NamedUser] = None,
    access_token: Optional[str] = None,
    auth: Optional[Auth] = None,
) -> Issue:
    """
    Creates an issue in the specified repository.

    Args:
        repo_name (str): The full repository name (user/repo) where the issue will be created.
        title (str): The title of the issue.
        body (Optional[str], optional): The body of the issue. Defaults to None.
        assignee (Union[NamedUser, str], optional): The assignee for the issue, either as a NamedUser object or their username. Defaults to None.
        milestone (Optional[Milestone], optional): The milestone to associate with the issue. Defaults to None.
        labels (Union[list[Label], list[str]], optional): The labels to associate with the issue, either as Label objects or their names. Defaults to None.
        assignees (Union[list[str], list[NamedUser]], optional): The list of usernames or NamedUser objects to assign to the issue. Defaults to None.
        access_token (Optional[str], optional): The access token for authentication. Defaults to None. Visit `https://pygithub.readthedocs.io/en/latest/examples/Authentication.html`, `https://github.com/PyGithub/PyGithub/blob/main/doc/examples/Authentication.rst`.
        auth (Optional[Auth], optional): The authentication method. Defaults to None. Visit `https://pygithub.readthedocs.io/en/latest/examples/Authentication.html`

    Returns:
        Issue: The created issue object.

    Raises:
        ValueError: If neither `auth` nor `access_token` is provided.
        RateLimitError: If the `x-ratelimit-remaining` response header is an integer that is zero or less.
    """
    # Falsy optional values are passed on as "not set".
    body = body or NotSet
    assignee = assignee or NotSet
    milestone = milestone or NotSet
    labels = labels or NotSet
    assignees = assignees or NotSet
    if not auth and not access_token:
        raise ValueError('`access_token` is invalid. Visit: "https://github.com/settings/tokens"')
    auth = auth or Auth.Token(access_token)
    client = Github(auth=auth)

    repo = client.get_repo(repo_name)
    remaining = repo.raw_headers.get("x-ratelimit-remaining")
    # Only act on the header when it is present and looks like an integer.
    if remaining and re.match(r"^-?\d+$", remaining) and int(remaining) <= 0:
        raise RateLimitError()
    return repo.create_issue(
        title=title,
        body=body,
        assignee=assignee,
        milestone=milestone,
        labels=labels,
        assignees=assignees,
    )
|
||||
|
||||
@staticmethod
async def get_repos(access_token: Optional[str] = None, auth: Optional[Auth] = None) -> List[str]:
    """
    Fetches a list of public repositories belonging to the authenticated user.

    Args:
        access_token (Optional[str], optional): The access token for authentication. Defaults to None.
            Visit `https://github.com/settings/tokens` for obtaining a personal access token.
        auth (Optional[Auth], optional): The authentication method. Defaults to None.
            Visit `https://pygithub.readthedocs.io/en/latest/examples/Authentication.html` for more information.

    Returns:
        List[str]: A list of full names of the public repositories belonging to the user.

    Raises:
        ValueError: If neither `auth` nor `access_token` is provided.
    """
    # Same up-front credential check as `push`, `create_pull` and `create_issue`.
    if not auth and not access_token:
        raise ValueError('`access_token` is invalid. Visit: "https://github.com/settings/tokens"')
    auth = auth or Auth.Token(access_token)
    git = Github(auth=auth)
    user = git.get_user()
    v = user.get_repos(visibility="public")
    return [i.full_name for i in v]
|
||||
|
||||
@staticmethod
def create_github_pull_url(clone_url: str, base: str, head: str, head_repo_name: Optional[str] = None) -> str:
    """
    Build the GitHub URL for comparing changes between branches or repositories.

    Args:
        clone_url (str): The URL used for cloning the repository, ending with '.git'.
        base (str): The base branch or commit.
        head (str): The head branch or commit.
        head_repo_name (str, optional): The name of the repository for the head branch. If not provided, assumes the same repository.

    Returns:
        str: `<repo url>/compare/<base>...<head repo with "/" replaced by ":">:<head>`;
        the head repository part is empty when `head_repo_name` is not given.
    """
    repo_url = clone_url.removesuffix(".git")
    head_prefix = head_repo_name.replace("/", ":") if head_repo_name else ""
    return f"{repo_url}/compare/{base}...{head_prefix}:{head}"
|
||||
|
||||
@staticmethod
def create_gitlab_merge_request_url(clone_url: str, head: str) -> str:
    """
    Build the GitLab URL for opening a new merge request.

    Args:
        clone_url (str): The URL used for cloning the repository, ending with '.git'.
        head (str): The name of the branch to be merged.

    Returns:
        str: The URL for creating a new merge request for the specified branch;
        the branch name is fully percent-encoded.
    """
    project_url = clone_url.removesuffix(".git")
    source_branch = quote(head, safe="")
    return f"{project_url}/-/merge_requests/new?merge_request%5Bsource_branch%5D={source_branch}"
|
||||
|
|
|
|||
|
|
@ -49,6 +49,10 @@ class GraphKeyword:
|
|||
IS_COMPOSITE_OF = "is_composite_of"
|
||||
IS_AGGREGATE_OF = "is_aggregate_of"
|
||||
HAS_PARTICIPANT = "has_participant"
|
||||
HAS_SUMMARY = "has_summary"
|
||||
HAS_INSTALL = "has_install"
|
||||
HAS_CONFIG = "has_config"
|
||||
HAS_USAGE = "has_usage"
|
||||
|
||||
|
||||
class SPO(BaseModel):
|
||||
|
|
|
|||
32
metagpt/utils/make_sk_kernel.py
Normal file
32
metagpt/utils/make_sk_kernel.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/9/13 12:29
|
||||
@Author : femto Zheng
|
||||
@File : make_sk_kernel.py
|
||||
"""
|
||||
import semantic_kernel as sk
|
||||
from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import (
|
||||
AzureChatCompletion,
|
||||
)
|
||||
from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import (
|
||||
OpenAIChatCompletion,
|
||||
)
|
||||
|
||||
from metagpt.config2 import config
|
||||
|
||||
|
||||
def make_sk_kernel():
    """Create a Semantic Kernel instance wired to the configured chat LLM.

    An Azure LLM from the config is preferred; an OpenAI LLM is used only
    when no Azure LLM is available. If neither is available the kernel is
    returned without a chat service.

    Returns:
        The created `sk.Kernel`.
    """
    kernel = sk.Kernel()
    azure_llm = config.get_azure_llm()
    if azure_llm:
        service = AzureChatCompletion(azure_llm.model, azure_llm.base_url, azure_llm.api_key)
        kernel.add_chat_service("chat_completion", service)
        return kernel

    # The OpenAI config is only consulted when there is no Azure LLM.
    openai_llm = config.get_openai_llm()
    if openai_llm:
        service = OpenAIChatCompletion(openai_llm.model, openai_llm.api_key)
        kernel.add_chat_service("chat_completion", service)
    return kernel
|
||||
|
|
@ -7,23 +7,44 @@
|
|||
"""
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import awrite, check_cmd_exists
|
||||
|
||||
|
||||
async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
|
||||
"""suffix: png/svg/pdf
|
||||
async def mermaid_to_file(
|
||||
engine,
|
||||
mermaid_code,
|
||||
output_file_without_suffix,
|
||||
width=2048,
|
||||
height=2048,
|
||||
config=None,
|
||||
suffixes: Optional[List[str]] = None,
|
||||
) -> int:
|
||||
"""Convert Mermaid code to various file formats.
|
||||
|
||||
:param mermaid_code: mermaid code
|
||||
:param output_file_without_suffix: output filename
|
||||
:param width:
|
||||
:param height:
|
||||
:return: 0 if succeed, -1 if failed
|
||||
Args:
|
||||
engine (str): The engine to use for conversion. Supported engines are "nodejs", "playwright", "pyppeteer", "ink", and "none".
|
||||
mermaid_code (str): The Mermaid code to be converted.
|
||||
output_file_without_suffix (str): The output file name without the suffix.
|
||||
width (int, optional): The width of the output image. Defaults to 2048.
|
||||
height (int, optional): The height of the output image. Defaults to 2048.
|
||||
config (Optional[Config], optional): The configuration to use for the conversion. Defaults to None, which uses the default configuration.
|
||||
suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"].
|
||||
|
||||
Returns:
|
||||
int: 0 if the conversion is successful, -1 if the conversion fails.
|
||||
"""
|
||||
file_head = "%%{init: {'theme': 'default', 'themeVariables': { 'fontFamily': 'Inter' }}}%%\n"
|
||||
if not re.match(r"^%%\{.+", mermaid_code):
|
||||
mermaid_code = file_head + mermaid_code
|
||||
suffixes = suffixes or ["svg"]
|
||||
# Write the Mermaid code to a temporary file
|
||||
config = config if config else Config.default()
|
||||
dir_name = os.path.dirname(output_file_without_suffix)
|
||||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name)
|
||||
|
|
@ -38,7 +59,7 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt
|
|||
)
|
||||
return -1
|
||||
|
||||
for suffix in ["pdf", "svg", "png"]:
|
||||
for suffix in suffixes:
|
||||
output_file = f"{output_file_without_suffix}.{suffix}"
|
||||
# Call the `mmdc` command to convert the Mermaid code to a PNG
|
||||
logger.info(f"Generating {output_file}..")
|
||||
|
|
@ -72,15 +93,15 @@ async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, widt
|
|||
if engine == "playwright":
|
||||
from metagpt.utils.mmdc_playwright import mermaid_to_file
|
||||
|
||||
return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height)
|
||||
return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height, suffixes=suffixes)
|
||||
elif engine == "pyppeteer":
|
||||
from metagpt.utils.mmdc_pyppeteer import mermaid_to_file
|
||||
|
||||
return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height)
|
||||
return await mermaid_to_file(mermaid_code, output_file_without_suffix, width, height, suffixes=suffixes)
|
||||
elif engine == "ink":
|
||||
from metagpt.utils.mmdc_ink import mermaid_to_file
|
||||
|
||||
return await mermaid_to_file(mermaid_code, output_file_without_suffix)
|
||||
return await mermaid_to_file(mermaid_code, output_file_without_suffix, suffixes=suffixes)
|
||||
elif engine == "none":
|
||||
return 0
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -6,21 +6,29 @@
|
|||
@File : mermaid.py
|
||||
"""
|
||||
import base64
|
||||
from typing import List, Optional
|
||||
|
||||
from aiohttp import ClientError, ClientSession
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
async def mermaid_to_file(mermaid_code, output_file_without_suffix):
|
||||
"""suffix: png/svg
|
||||
:param mermaid_code: mermaid code
|
||||
:param output_file_without_suffix: output filename without suffix
|
||||
:return: 0 if succeed, -1 if failed
|
||||
async def mermaid_to_file(mermaid_code, output_file_without_suffix, suffixes: Optional[List[str]] = None):
|
||||
"""Convert Mermaid code to various file formats.
|
||||
|
||||
Args:
|
||||
mermaid_code (str): The Mermaid code to be converted.
|
||||
output_file_without_suffix (str): The output file name without the suffix.
|
||||
width (int, optional): The width of the output image. Defaults to 2048.
|
||||
height (int, optional): The height of the output image. Defaults to 2048.
|
||||
suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"].
|
||||
|
||||
Returns:
|
||||
int: 0 if the conversion is successful, -1 if the conversion fails.
|
||||
"""
|
||||
encoded_string = base64.b64encode(mermaid_code.encode()).decode()
|
||||
|
||||
for suffix in ["svg", "png"]:
|
||||
suffixes = suffixes or ["png"]
|
||||
for suffix in suffixes:
|
||||
output_file = f"{output_file_without_suffix}.{suffix}"
|
||||
path_type = "svg" if suffix == "svg" else "img"
|
||||
url = f"https://mermaid.ink/{path_type}/{encoded_string}"
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from playwright.async_api import async_playwright
|
||||
|
|
@ -14,20 +15,22 @@ from playwright.async_api import async_playwright
|
|||
from metagpt.logs import logger
|
||||
|
||||
|
||||
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
|
||||
"""
|
||||
Converts the given Mermaid code to various output formats and saves them to files.
|
||||
async def mermaid_to_file(
|
||||
mermaid_code, output_file_without_suffix, width=2048, height=2048, suffixes: Optional[List[str]] = None
|
||||
) -> int:
|
||||
"""Convert Mermaid code to various file formats.
|
||||
|
||||
Args:
|
||||
mermaid_code (str): The Mermaid code to convert.
|
||||
output_file_without_suffix (str): The output file name without the file extension.
|
||||
width (int, optional): The width of the output image in pixels. Defaults to 2048.
|
||||
height (int, optional): The height of the output image in pixels. Defaults to 2048.
|
||||
mermaid_code (str): The Mermaid code to be converted.
|
||||
output_file_without_suffix (str): The output file name without the suffix.
|
||||
width (int, optional): The width of the output image. Defaults to 2048.
|
||||
height (int, optional): The height of the output image. Defaults to 2048.
|
||||
suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"].
|
||||
|
||||
Returns:
|
||||
int: Returns 1 if the conversion and saving were successful, -1 otherwise.
|
||||
int: 0 if the conversion is successful, -1 if the conversion fails.
|
||||
"""
|
||||
suffixes = ["png", "svg", "pdf"]
|
||||
suffixes = suffixes or ["png"]
|
||||
__dirname = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
async with async_playwright() as p:
|
||||
|
|
|
|||
|
|
@ -6,28 +6,33 @@
|
|||
@File : mmdc_pyppeteer.py
|
||||
"""
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from pyppeteer import launch
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
|
||||
"""
|
||||
Converts the given Mermaid code to various output formats and saves them to files.
|
||||
async def mermaid_to_file(
|
||||
mermaid_code, output_file_without_suffix, width=2048, height=2048, config=None, suffixes: Optional[List[str]] = None
|
||||
) -> int:
|
||||
"""Convert Mermaid code to various file formats.
|
||||
|
||||
Args:
|
||||
mermaid_code (str): The Mermaid code to convert.
|
||||
output_file_without_suffix (str): The output file name without the file extension.
|
||||
width (int, optional): The width of the output image in pixels. Defaults to 2048.
|
||||
height (int, optional): The height of the output image in pixels. Defaults to 2048.
|
||||
mermaid_code (str): The Mermaid code to be converted.
|
||||
output_file_without_suffix (str): The output file name without the suffix.
|
||||
width (int, optional): The width of the output image. Defaults to 2048.
|
||||
height (int, optional): The height of the output image. Defaults to 2048.
|
||||
config (Optional[Config], optional): The configuration to use for the conversion. Defaults to None, which uses the default configuration.
|
||||
suffixes (Optional[List[str]], optional): The file suffixes to generate. Supports "png", "pdf", and "svg". Defaults to ["png"].
|
||||
|
||||
Returns:
|
||||
int: Returns 1 if the conversion and saving were successful, -1 otherwise.
|
||||
int: 0 if the conversion is successful, -1 if the conversion fails.
|
||||
"""
|
||||
suffixes = ["png", "svg", "pdf"]
|
||||
config = config if config else Config.default()
|
||||
suffixes = suffixes or ["png"]
|
||||
__dirname = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
if config.mermaid.pyppeteer_path:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
|
|
@ -190,7 +189,7 @@ class OmniParseClient:
|
|||
# Do not verify if only byte data is provided
|
||||
return
|
||||
|
||||
file_ext = os.path.splitext(verify_file_path)[1].lower()
|
||||
file_ext = Path(verify_file_path).suffix.lower()
|
||||
if file_ext not in allowed_file_extensions:
|
||||
raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}")
|
||||
|
||||
|
|
@ -219,7 +218,7 @@ class OmniParseClient:
|
|||
Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type).
|
||||
"""
|
||||
if isinstance(file_input, (str, Path)):
|
||||
filename = os.path.basename(str(file_input))
|
||||
filename = Path(file_input).name
|
||||
file_bytes = await aread_bin(file_input)
|
||||
|
||||
if only_bytes:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||
from typing import Generator, Optional
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import htmlmin
|
||||
from bs4 import BeautifulSoup
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
|
|
@ -38,6 +39,22 @@ class WebPage(BaseModel):
|
|||
elif url.startswith(("http://", "https://")):
|
||||
yield urljoin(self.url, url)
|
||||
|
||||
def get_slim_soup(self, keep_links: bool = False):
    """Build a reduced soup of this page's HTML for compact downstream processing.

    The HTML is parsed with ``_get_soup``; afterwards every tag keeps only its
    ``class`` and ``id`` attributes (plus ``href`` when ``keep_links`` is true),
    and all ``svg``/``img``/``video``/``audio`` elements are removed.

    Args:
        keep_links (bool, optional): Keep ``href`` attributes. Defaults to False.

    Returns:
        The modified soup object.
    """
    retained = ("class", "id", "href") if keep_links else ("class", "id")
    soup = _get_soup(self.html)

    for tag in soup.find_all(True):
        # Attributes with a falsy value are left untouched, only non-empty ones are dropped.
        removable = [attr for attr in list(tag.attrs) if tag[attr] and attr not in retained]
        for attr in removable:
            del tag[attr]

    for media in soup.find_all(["svg", "img", "video", "audio"]):
        media.decompose()

    return soup
|
||||
|
||||
|
||||
def get_html_content(page: str, base: str):
|
||||
soup = _get_soup(page)
|
||||
|
|
@ -48,7 +65,12 @@ def get_html_content(page: str, base: str):
|
|||
def _get_soup(page: str):
|
||||
soup = BeautifulSoup(page, "html.parser")
|
||||
# https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup
|
||||
for s in soup(["style", "script", "[document]", "head", "title"]):
|
||||
for s in soup(["style", "script", "[document]", "head", "title", "footer"]):
|
||||
s.extract()
|
||||
|
||||
return soup
|
||||
|
||||
|
||||
def simplify_html(html: str, url: str, keep_links: bool = False):
    """Return a slimmed and minified version of ``html``.

    Args:
        html (str): The raw HTML text.
        url (str): The URL the HTML belongs to; passed to ``WebPage``.
        keep_links (bool, optional): Forwarded to ``WebPage.get_slim_soup``. Defaults to False.

    Returns:
        The output of ``htmlmin.minify`` with comments and empty space removed.
    """
    page = WebPage(inner_text="", html=html, url=url)
    slim_markup = page.get_slim_soup(keep_links).decode()
    return htmlmin.minify(slim_markup, remove_comments=True, remove_empty_space=True)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.const import (
|
||||
CLASS_VIEW_FILE_REPO,
|
||||
|
|
@ -35,6 +36,7 @@ from metagpt.const import (
|
|||
TEST_OUTPUTS_FILE_REPO,
|
||||
VISUAL_GRAPH_REPO_FILE_REPO,
|
||||
)
|
||||
from metagpt.utils.common import get_project_srcs_path
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
|
@ -129,22 +131,33 @@ class ProjectRepo(FileRepository):
|
|||
return self._git_repo.new_file_repository(self._srcs_path)
|
||||
|
||||
def code_files_exists(self) -> bool:
|
||||
git_workdir = self.git_repo.workdir
|
||||
src_workdir = git_workdir / git_workdir.name
|
||||
src_workdir = get_project_srcs_path(self.git_repo.workdir)
|
||||
if not src_workdir.exists():
|
||||
return False
|
||||
code_files = self.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
|
||||
code_files = self.with_src_path(path=src_workdir).srcs.all_files
|
||||
if not code_files:
|
||||
return False
|
||||
return bool(code_files)
|
||||
|
||||
def with_src_path(self, path: str | Path) -> ProjectRepo:
|
||||
try:
|
||||
self._srcs_path = Path(path).relative_to(self.workdir)
|
||||
except ValueError:
|
||||
self._srcs_path = Path(path)
|
||||
path = Path(path)
|
||||
if path.is_relative_to(self.workdir):
|
||||
self._srcs_path = path.relative_to(self.workdir)
|
||||
else:
|
||||
self._srcs_path = path
|
||||
return self
|
||||
|
||||
@property
|
||||
def src_relative_path(self) -> Path | None:
|
||||
return self._srcs_path
|
||||
|
||||
@staticmethod
|
||||
def search_project_path(filename: str | Path) -> Optional[Path]:
|
||||
root = Path(filename).parent if Path(filename).is_file() else Path(filename)
|
||||
root = root.resolve()
|
||||
while str(root) != "/":
|
||||
git_repo = root / ".git"
|
||||
if git_repo.exists():
|
||||
return root
|
||||
root = root.parent
|
||||
return None
|
||||
|
|
|
|||
19
metagpt/utils/proxy_env.py
Normal file
19
metagpt/utils/proxy_env.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
import os
|
||||
|
||||
|
||||
def get_proxy_from_env():
    """Build a proxy configuration dict from the process environment.

    The variables ALL_PROXY, all_proxy, HTTPS_PROXY, https_proxy, HTTP_PROXY and
    http_proxy are inspected in that order; the last one holding a non-empty value
    supplies ``server``. ``NO_PROXY`` (or, if empty/unset, ``no_proxy``) supplies
    ``bypass``.

    Returns:
        dict | None: A dict with the keys that could be filled, or None when
        neither a server nor a bypass list is configured.
    """
    env = os.environ
    candidates = ("ALL_PROXY", "all_proxy", "HTTPS_PROXY", "https_proxy", "HTTP_PROXY", "http_proxy")
    configured = [value for value in map(env.get, candidates) if value]

    settings = {}
    if configured:
        # Later names in ``candidates`` override earlier ones.
        settings["server"] = configured[-1]

    bypass = env.get("NO_PROXY") or env.get("no_proxy")
    if bypass:
        settings["bypass"] = bypass

    return settings or None
|
||||
|
|
@ -4,12 +4,12 @@
|
|||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from typing import Callable, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import regex as re
|
||||
from tenacity import RetryCallState, retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.custom_decoder import CustomDecoder
|
||||
|
||||
|
|
@ -154,7 +154,9 @@ def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType =
|
|||
return output
|
||||
|
||||
|
||||
def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str:
|
||||
def repair_llm_raw_output(
|
||||
output: str, req_keys: list[str], repair_type: RepairType = None, config: Optional[Config] = None
|
||||
) -> str:
|
||||
"""
|
||||
in open-source llm model, it usually can't follow the instruction well, the output may be incomplete,
|
||||
so here we try to repair it and use all repair methods by default.
|
||||
|
|
@ -169,6 +171,7 @@ def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairT
|
|||
target: { xxx }
|
||||
output: { xxx }]
|
||||
"""
|
||||
config = config if config else Config.default()
|
||||
if not config.repair_llm_output:
|
||||
return output
|
||||
|
||||
|
|
@ -256,6 +259,7 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
|
|||
"next_action":"None"
|
||||
}
|
||||
"""
|
||||
config = Config.default()
|
||||
if retry_state.outcome.failed:
|
||||
if retry_state.args:
|
||||
# # can't be used as args=retry_state.args
|
||||
|
|
@ -276,8 +280,12 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
|
|||
return run_and_passon
|
||||
|
||||
|
||||
def repair_stop_after_attempt(retry_state):
    """Stop condition for tenacity that reads the configuration at call time.

    Allows 3 attempts when ``repair_llm_output`` is enabled in the default
    configuration and 0 otherwise.
    """
    max_attempts = 3 if Config.default().repair_llm_output else 0
    stop_condition = stop_after_attempt(max_attempts)
    return stop_condition(retry_state)
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3 if config.repair_llm_output else 0),
|
||||
stop=repair_stop_after_attempt,
|
||||
wait=wait_fixed(1),
|
||||
after=run_after_exp_and_passon_next_retry(logger),
|
||||
)
|
||||
|
|
@ -347,3 +355,44 @@ def extract_state_value_from_output(content: str) -> str:
|
|||
matches = list(set(matches))
|
||||
state = matches[0] if len(matches) > 0 else "-1"
|
||||
return state
|
||||
|
||||
|
||||
def repair_escape_error(commands):
    r"""Repair unknown escape sequences in a raw command string.

    When RoleZero parses a command, the JSON text may contain escapes the JSON
    parser does not understand, for example LaTeX such as ``\( \frac{1}{2} \)``.

    Two transformations are applied in a single pass over the text:

    1. A backslash that is not the last character and is not followed by ``n``,
       a double quote or a space gets one extra backslash in front of it,
       e.g. ``\(`` becomes ``\\(``.
    2. The control characters written ``\a``, ``\b``, ``\f``, ``\r``, ``\t`` and
       ``\v`` in Python source are replaced by two backslashes followed by the
       letter, e.g. a form feed followed by ``rac`` becomes ``\\frac``.

    Args:
        commands (str): The raw command text to repair.

    Returns:
        str: The repaired text.
    """
    escape_repair_map = {
        "\a": "\\\\a",
        "\b": "\\\\b",
        "\f": "\\\\f",
        "\r": "\\\\r",
        "\t": "\\\\t",
        "\v": "\\\\v",
    }
    # A backslash directly followed by one of these characters is kept as it is.
    kept_after_backslash = ("n", '"', " ")
    last_index = len(commands) - 1

    pieces = []
    for index, ch in enumerate(commands):
        if ch == "\\" and index < last_index:
            if commands[index + 1] not in kept_after_backslash:
                pieces.append("\\")
        elif ch in escape_repair_map:
            ch = escape_repair_map[ch]
        pieces.append(ch)
    # Joining once avoids repeated string concatenation inside the loop.
    return "".join(pieces)
|
||||
|
|
|
|||
|
|
@ -5,17 +5,24 @@ This file provides functionality to convert a local repository into a markdown r
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import mimetypes
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Union
|
||||
|
||||
from gitignore_parser import parse_gitignore
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import aread, awrite, get_markdown_codeblock_type, list_files
|
||||
from metagpt.utils.common import (
|
||||
aread,
|
||||
awrite,
|
||||
get_markdown_codeblock_type,
|
||||
get_mime_type,
|
||||
list_files,
|
||||
)
|
||||
from metagpt.utils.tree import tree
|
||||
|
||||
|
||||
async def repo_to_markdown(repo_path: str | Path, output: str | Path = None, gitignore: str | Path = None) -> str:
|
||||
async def repo_to_markdown(repo_path: str | Path, output: str | Path = None) -> str:
|
||||
"""
|
||||
Convert a local repository into a markdown representation.
|
||||
|
||||
|
|
@ -25,56 +32,118 @@ async def repo_to_markdown(repo_path: str | Path, output: str | Path = None, git
|
|||
Args:
|
||||
repo_path (str | Path): The path to the local repository.
|
||||
output (str | Path, optional): The path to save the generated markdown file. Defaults to None.
|
||||
gitignore (str | Path, optional): The path to the .gitignore file. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The markdown representation of the repository.
|
||||
"""
|
||||
repo_path = Path(repo_path)
|
||||
gitignore = Path(gitignore or Path(__file__).parent / "../../.gitignore").resolve()
|
||||
repo_path = Path(repo_path).resolve()
|
||||
gitignore_file = repo_path / ".gitignore"
|
||||
|
||||
markdown = await _write_dir_tree(repo_path=repo_path, gitignore=gitignore)
|
||||
markdown = await _write_dir_tree(repo_path=repo_path, gitignore=gitignore_file)
|
||||
|
||||
gitignore_rules = parse_gitignore(full_path=str(gitignore))
|
||||
gitignore_rules = parse_gitignore(full_path=str(gitignore_file)) if gitignore_file.exists() else None
|
||||
markdown += await _write_files(repo_path=repo_path, gitignore_rules=gitignore_rules)
|
||||
|
||||
if output:
|
||||
await awrite(filename=str(output), data=markdown, encoding="utf-8")
|
||||
output_file = Path(output).resolve()
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
await awrite(filename=str(output_file), data=markdown, encoding="utf-8")
|
||||
logger.info(f"save: {output_file}")
|
||||
return markdown
|
||||
|
||||
|
||||
async def _write_dir_tree(repo_path: Path, gitignore: Path) -> str:
|
||||
try:
|
||||
content = tree(repo_path, gitignore, run_command=True)
|
||||
content = await tree(repo_path, gitignore, run_command=True)
|
||||
except Exception as e:
|
||||
logger.info(f"{e}, using safe mode.")
|
||||
content = tree(repo_path, gitignore, run_command=False)
|
||||
content = await tree(repo_path, gitignore, run_command=False)
|
||||
|
||||
doc = f"## Directory Tree\n```text\n{content}\n```\n---\n\n"
|
||||
return doc
|
||||
|
||||
|
||||
async def _write_files(repo_path, gitignore_rules) -> str:
|
||||
async def _write_files(repo_path, gitignore_rules=None) -> str:
|
||||
filenames = list_files(repo_path)
|
||||
markdown = ""
|
||||
pattern = r"^\..*" # Hidden folders/files
|
||||
for filename in filenames:
|
||||
if gitignore_rules(str(filename)):
|
||||
if gitignore_rules and gitignore_rules(str(filename)):
|
||||
continue
|
||||
ignore = False
|
||||
for i in filename.parts:
|
||||
if re.match(pattern, i):
|
||||
ignore = True
|
||||
break
|
||||
if ignore:
|
||||
continue
|
||||
markdown += await _write_file(filename=filename, repo_path=repo_path)
|
||||
return markdown
|
||||
|
||||
|
||||
async def _write_file(filename: Path, repo_path: Path) -> str:
|
||||
relative_path = filename.relative_to(repo_path)
|
||||
markdown = f"## {relative_path}\n"
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(filename.name)
|
||||
if "text/" not in mime_type:
|
||||
is_text, mime_type = await is_text_file(filename)
|
||||
if not is_text:
|
||||
logger.info(f"Ignore content: {filename}")
|
||||
markdown += "<binary file>\n---\n\n"
|
||||
return ""
|
||||
|
||||
try:
|
||||
relative_path = filename.relative_to(repo_path)
|
||||
markdown = f"## {relative_path}\n"
|
||||
content = await aread(filename, encoding="utf-8")
|
||||
content = content.replace("```", "\\`\\`\\`").replace("---", "\\-\\-\\-")
|
||||
code_block_type = get_markdown_codeblock_type(filename.name)
|
||||
markdown += f"```{code_block_type}\n{content}\n```\n---\n\n"
|
||||
return markdown
|
||||
content = await aread(filename, encoding="utf-8")
|
||||
content = content.replace("```", "\\`\\`\\`").replace("---", "\\-\\-\\-")
|
||||
code_block_type = get_markdown_codeblock_type(filename.name)
|
||||
markdown += f"```{code_block_type}\n{content}\n```\n---\n\n"
|
||||
return markdown
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return ""
|
||||
|
||||
|
||||
async def is_text_file(filename: Union[str, Path]) -> Tuple[bool, str]:
    """Classify a file as text or non-text from its MIME type.

    The MIME type is obtained from ``get_mime_type`` with ``force_read=True``.
    A file counts as text when the type contains ``text/`` or is one of a small
    set of textual ``application/*`` types.

    Args:
        filename (Union[str, Path]): The path to the file.

    Returns:
        Tuple[bool, str]: ``(is_text, mime_type)``.
    """
    textual_application_types = {
        "application/json",
        "application/vnd.chipnuts.karaoke-mmd",
        "application/javascript",
        "application/xml",
        "application/x-sh",
        "application/sql",
    }
    known_non_text_types = {
        "application/zlib",
        "application/octet-stream",
        "image/svg+xml",
        "application/pdf",
        "application/msword",
        "application/vnd.ms-excel",
        "audio/x-wav",
        "application/x-git",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/zip",
        "image/jpeg",
        "audio/mpeg",
        "video/mp2t",
        "inode/x-empty",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "image/png",
        "image/vnd.microsoft.icon",
        "video/mp4",
    }

    mime_type = await get_mime_type(Path(filename), force_read=True)
    if "text/" in mime_type or mime_type in textual_application_types:
        return True, mime_type

    # Types that are in neither set are logged so they can be noticed and classified.
    if mime_type not in known_non_text_types:
        logger.info(mime_type)
    return False, mime_type
|
||||
|
|
|
|||
330
metagpt/utils/report.py
Normal file
330
metagpt/utils/report.py
Normal file
|
|
@ -0,0 +1,330 @@
|
|||
import asyncio
|
||||
import os
|
||||
import typing
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Literal, Optional, Union
|
||||
from urllib.parse import unquote, urlparse, urlunparse
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from aiohttp import ClientSession, UnixConnector
|
||||
from playwright.async_api import Page as AsyncPage
|
||||
from playwright.sync_api import Page as SyncPage
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from metagpt.const import METAGPT_REPORTER_DEFAULT_URL
|
||||
from metagpt.logs import create_llm_stream_queue, get_llm_stream_queue
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from metagpt.roles.role import Role
|
||||
|
||||
try:
|
||||
import requests_unixsocket as requests
|
||||
except ImportError:
|
||||
import requests
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
CURRENT_ROLE: ContextVar["Role"] = ContextVar("role")
|
||||
|
||||
|
||||
class BlockType(str, Enum):
|
||||
"""Enumeration for different types of blocks."""
|
||||
|
||||
TERMINAL = "Terminal"
|
||||
TASK = "Task"
|
||||
BROWSER = "Browser"
|
||||
BROWSER_RT = "Browser-RT"
|
||||
EDITOR = "Editor"
|
||||
GALLERY = "Gallery"
|
||||
NOTEBOOK = "Notebook"
|
||||
DOCS = "Docs"
|
||||
THOUGHT = "Thought"
|
||||
|
||||
|
||||
END_MARKER_NAME = "end_marker"
|
||||
END_MARKER_VALUE = "\x18\x19\x1B\x18\n"
|
||||
|
||||
|
||||
class ResourceReporter(BaseModel):
    """Base class for resource reporting.

    A reporter formats an observation into a JSON payload and POSTs it to
    ``callback_url``. It can be used as a synchronous or asynchronous context
    manager; leaving the context reports an end marker. When ``enable_llm_stream``
    is set, the asynchronous context additionally forwards items from the LLM
    stream queue as ``"content"`` reports.
    """

    block: BlockType = Field(description="The type of block that is reporting the resource")
    uuid: UUID = Field(default_factory=uuid4, description="The unique identifier for the resource")
    enable_llm_stream: bool = Field(False, description="Indicates whether to connect to an LLM stream for reporting")
    callback_url: str = Field(METAGPT_REPORTER_DEFAULT_URL, description="The URL to which the report should be sent")
    _llm_task: Optional[asyncio.Task] = PrivateAttr(None)

    def report(self, value: Any, name: str, extra: Optional[dict] = None):
        """Synchronously report resource observation data.

        Args:
            value: The data to report.
            name: The type name of the data.
            extra: Optional additional data, attached under the ``"extra"`` key when non-empty.
        """
        return self._report(value, name, extra)

    async def async_report(self, value: Any, name: str, extra: Optional[dict] = None):
        """Asynchronously report resource observation data.

        Args:
            value: The data to report.
            name: The type name of the data.
            extra: Optional additional data, attached under the ``"extra"`` key when non-empty.
        """
        return await self._async_report(value, name, extra)

    @classmethod
    def set_report_fn(cls, fn: Callable):
        """Set the synchronous report function.

        Args:
            fn: A callable function used for synchronous reporting. It is called as
                ``fn(self, value, name, extra)``. For example:

                >>> def _report(self, value: Any, name: str, extra=None):
                ...     print(value, name)

        """
        cls._report = fn

    @classmethod
    def set_async_report_fn(cls, fn: Callable):
        """Set the asynchronous report function.

        Args:
            fn: A callable function used for asynchronous reporting. It is awaited as
                ``fn(self, value, name, extra)``. For example:

                >>> async def _report(self, value: Any, name: str, extra=None):
                ...     print(value, name)

        """
        cls._async_report = fn

    def _report(self, value: Any, name: str, extra: Optional[dict] = None):
        """POST the formatted payload with ``requests``; does nothing without a callback URL."""
        if not self.callback_url:
            return

        data = self._format_data(value, name, extra)
        resp = requests.post(self.callback_url, json=data)
        resp.raise_for_status()
        return resp.text

    async def _async_report(self, value: Any, name: str, extra: Optional[dict] = None):
        """POST the formatted payload with aiohttp; does nothing without a callback URL."""
        if not self.callback_url:
            return

        data = self._format_data(value, name, extra)
        url = self.callback_url
        _result = urlparse(url)
        session_kwargs = {}
        if _result.scheme.endswith("+unix"):
            # "<scheme>+unix://<url-quoted socket path>/...": connect through the unix
            # socket and send the request to a placeholder host.
            parsed_list = list(_result)
            parsed_list[0] = parsed_list[0][:-5]  # strip the "+unix" suffix
            parsed_list[1] = "fake.org"
            url = urlunparse(parsed_list)
            session_kwargs["connector"] = UnixConnector(path=unquote(_result.netloc))

        async with ClientSession(**session_kwargs) as client:
            async with client.post(url, json=data) as resp:
                resp.raise_for_status()
                return await resp.text()

    def _format_data(self, value, name, extra):
        """Build the JSON-serializable payload for one report."""
        # NOTE(review): "llm_stream" is not a declared field (the field is named
        # "enable_llm_stream") - confirm which fields are meant to be excluded.
        data = self.model_dump(mode="json", exclude=("callback_url", "llm_stream"))
        if isinstance(value, BaseModel):
            value = value.model_dump(mode="json")
        elif isinstance(value, Path):
            value = str(value)

        if name == "path":
            value = os.path.abspath(value)
        data["value"] = value
        data["name"] = name
        # Prefer the role stored in the context variable, fall back to the environment.
        role = CURRENT_ROLE.get(None)
        if role:
            role_name = role.name
        else:
            role_name = os.environ.get("METAGPT_ROLE")
        data["role"] = role_name
        if extra:
            data["extra"] = extra
        return data

    def __enter__(self):
        """Enter the synchronous streaming callback context."""
        return self

    def __exit__(self, *args, **kwargs):
        """Exit the synchronous streaming callback context."""
        self.report(None, END_MARKER_NAME)

    async def __aenter__(self):
        """Enter the asynchronous streaming callback context."""
        if self.enable_llm_stream:
            queue = create_llm_stream_queue()
            self._llm_task = asyncio.create_task(self._llm_stream_report(queue))
        return self

    async def __aexit__(self, exc_type, exc_value, exc_tb):
        """Exit the asynchronous streaming callback context."""
        if self.enable_llm_stream and self._llm_task is not None:
            cancelled = exc_type is not None and issubclass(exc_type, asyncio.CancelledError)
            if cancelled:
                # Do not wait for the stream when being cancelled, but do not leave the
                # forwarding task blocked on the queue forever either.
                self._llm_task.cancel()
            else:
                # A None item tells the forwarding task to finish.
                await get_llm_stream_queue().put(None)
                await self._llm_task
            self._llm_task = None
        await self.async_report(None, END_MARKER_NAME)

    async def _llm_stream_report(self, queue: asyncio.Queue):
        """Forward queue items as "content" reports until a None item is received."""
        while True:
            data = await queue.get()
            if data is None:
                return
            await self.async_report(data, "content")

    async def wait_llm_stream_report(self):
        """Wait for the LLM stream report to complete."""
        queue = get_llm_stream_queue()
        # Poll until the queue has been drained or no stream task is active.
        while self._llm_task:
            if queue.empty():
                break
            await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
class TerminalReporter(ResourceReporter):
    """Terminal output callback for streaming reporting of command and output.

    The terminal has state, and an agent can open multiple terminals and input different commands into them.
    To correctly display these states, each terminal should have its own unique ID, so in practice, each terminal
    should instantiate its own TerminalReporter object.
    """

    block: Literal[BlockType.TERMINAL] = BlockType.TERMINAL

    def report(self, value: str, name: Literal["cmd", "output"], extra: Optional[dict] = None):
        """Report terminal command or output synchronously.

        ``extra`` is accepted and forwarded so the signature stays compatible with the base class.
        """
        return super().report(value, name, extra)

    async def async_report(self, value: str, name: Literal["cmd", "output"], extra: Optional[dict] = None):
        """Report terminal command or output asynchronously.

        ``extra`` is accepted and forwarded so the signature stays compatible with the base class.
        """
        return await super().async_report(value, name, extra)
|
||||
|
||||
|
||||
class BrowserReporter(ResourceReporter):
    """Browser output callback for streaming reporting of requested URL and page content.

    The browser has state, so in practice, each browser should instantiate its own BrowserReporter object.
    """

    block: Literal[BlockType.BROWSER] = BlockType.BROWSER

    def report(self, value: Union[str, SyncPage], name: Literal["url", "page"], extra: Optional[dict] = None):
        """Report browser URL or page content synchronously.

        For ``name == "page"`` the page object is converted into a dict holding its
        URL, title and the string form of its screenshot before it is reported.
        ``extra`` is forwarded to the base class.
        """
        if name == "page":
            value = {"page_url": value.url, "title": value.title(), "screenshot": str(value.screenshot())}
        return super().report(value, name, extra)

    async def async_report(
        self, value: Union[str, AsyncPage], name: Literal["url", "page"], extra: Optional[dict] = None
    ):
        """Report browser URL or page content asynchronously.

        For ``name == "page"`` the page object is converted into a dict holding its
        URL, title and the string form of its screenshot before it is reported.
        ``extra`` is forwarded to the base class.
        """
        if name == "page":
            value = {"page_url": value.url, "title": await value.title(), "screenshot": str(await value.screenshot())}
        return await super().async_report(value, name, extra)
|
||||
|
||||
|
||||
class ServerReporter(ResourceReporter):
    """Callback for server deployment reporting."""

    block: Literal[BlockType.BROWSER_RT] = BlockType.BROWSER_RT

    def report(self, value: str, name: Literal["local_url"] = "local_url", extra: Optional[dict] = None):
        """Report server deployment synchronously.

        ``extra`` is accepted and forwarded so the signature stays compatible with the base class.
        """
        return super().report(value, name, extra)

    async def async_report(self, value: str, name: Literal["local_url"] = "local_url", extra: Optional[dict] = None):
        """Report server deployment asynchronously.

        ``extra`` is accepted and forwarded so the signature stays compatible with the base class.
        """
        return await super().async_report(value, name, extra)
|
||||
|
||||
|
||||
class ObjectReporter(ResourceReporter):
    """Callback for reporting complete object resources."""

    def report(self, value: dict, name: Literal["object"] = "object", extra: Optional[dict] = None):
        """Report object resource synchronously.

        ``extra`` is accepted and forwarded so the signature stays compatible with the base class.
        """
        return super().report(value, name, extra)

    async def async_report(self, value: dict, name: Literal["object"] = "object", extra: Optional[dict] = None):
        """Report object resource asynchronously.

        ``extra`` is accepted and forwarded so the signature stays compatible with the base class.
        """
        return await super().async_report(value, name, extra)
|
||||
|
||||
|
||||
class TaskReporter(ObjectReporter):
|
||||
"""Reporter for object resources to Task Block."""
|
||||
|
||||
block: Literal[BlockType.TASK] = BlockType.TASK
|
||||
|
||||
|
||||
class ThoughtReporter(ObjectReporter):
    """Reporter for object resources to Thought Block."""

    block: Literal[BlockType.THOUGHT] = BlockType.THOUGHT
|
||||
|
||||
|
||||
class FileReporter(ResourceReporter):
    """Callback for file resources.

    Two usage patterns exist: report the complete file at once (non-streaming), or
    report partial output first so it can be displayed early (streaming).
    """

    def report(
        self,
        value: Union[Path, dict, Any],
        name: Literal["path", "meta", "content"] = "path",
        extra: Optional[dict] = None,
    ):
        """Synchronously report a file path, its meta information or its content."""
        result = super().report(value, name, extra)
        return result

    async def async_report(
        self,
        value: Union[Path, dict, Any],
        name: Literal["path", "meta", "content"] = "path",
        extra: Optional[dict] = None,
    ):
        """Asynchronously report a file path, its meta information or its content."""
        result = await super().async_report(value, name, extra)
        return result
|
||||
|
||||
|
||||
class NotebookReporter(FileReporter):
|
||||
"""Equivalent to FileReporter(block=BlockType.NOTEBOOK)."""
|
||||
|
||||
block: Literal[BlockType.NOTEBOOK] = BlockType.NOTEBOOK
|
||||
|
||||
|
||||
class DocsReporter(FileReporter):
|
||||
"""Equivalent to FileReporter(block=BlockType.DOCS)."""
|
||||
|
||||
block: Literal[BlockType.DOCS] = BlockType.DOCS
|
||||
|
||||
|
||||
class EditorReporter(FileReporter):
|
||||
"""Equivalent to FileReporter(block=BlockType.EDITOR)."""
|
||||
|
||||
block: Literal[BlockType.EDITOR] = BlockType.EDITOR
|
||||
|
||||
|
||||
class GalleryReporter(FileReporter):
    """Image resource callback for reporting complete file paths.

    Since images need to be complete before display, each callback is a complete file path. However, the Gallery
    needs to display the type of image and prompt, so if there is meta information, it should be reported in a
    streaming manner.
    """

    block: Literal[BlockType.GALLERY] = BlockType.GALLERY

    def report(self, value: Union[dict, Path], name: Literal["meta", "path"] = "path", extra: Optional[dict] = None):
        """Report image resource synchronously.

        ``extra`` is accepted and forwarded, matching the signature of ``FileReporter.report``.
        """
        return super().report(value, name, extra)

    async def async_report(
        self, value: Union[dict, Path], name: Literal["meta", "path"] = "path", extra: Optional[dict] = None
    ):
        """Report image resource asynchronously.

        ``extra`` is accepted and forwarded, matching the signature of ``FileReporter.async_report``.
        """
        return await super().async_report(value, name, extra)
|
||||
|
|
@ -12,13 +12,11 @@ ref5: https://ai.google.dev/models/gemini
|
|||
"""
|
||||
import anthropic
|
||||
import tiktoken
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.ahttp_client import apost
|
||||
|
||||
TOKEN_COSTS = {
|
||||
"anthropic/claude-3.5-sonnet": {"prompt": 0.003, "completion": 0.015},
|
||||
"gpt-3.5-turbo": {"prompt": 0.0015, "completion": 0.002},
|
||||
"gpt-3.5-turbo-0301": {"prompt": 0.0015, "completion": 0.002},
|
||||
"gpt-3.5-turbo-0613": {"prompt": 0.0015, "completion": 0.002},
|
||||
|
|
@ -77,6 +75,15 @@ TOKEN_COSTS = {
|
|||
"claude-3-7-sonnet-20250219": {"prompt": 0.003, "completion": 0.015},
|
||||
"yi-34b-chat-0205": {"prompt": 0.0003, "completion": 0.0003},
|
||||
"yi-34b-chat-200k": {"prompt": 0.0017, "completion": 0.0017},
|
||||
"openai/gpt-4": {"prompt": 0.03, "completion": 0.06}, # start, for openrouter
|
||||
"openai/gpt-4-turbo": {"prompt": 0.01, "completion": 0.03},
|
||||
"openai/gpt-4o": {"prompt": 0.005, "completion": 0.015},
|
||||
"openai/gpt-4o-2024-05-13": {"prompt": 0.005, "completion": 0.015},
|
||||
"openai/gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006},
|
||||
"openai/gpt-4o-mini-2024-07-18": {"prompt": 0.00015, "completion": 0.0006},
|
||||
"google/gemini-flash-1.5": {"prompt": 0.00025, "completion": 0.00075},
|
||||
"deepseek/deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
|
||||
"deepseek/deepseek-chat": {"prompt": 0.00014, "completion": 0.00028}, # end, for openrouter
|
||||
"yi-large": {"prompt": 0.0028, "completion": 0.0028},
|
||||
"microsoft/wizardlm-2-8x22b": {"prompt": 0.00108, "completion": 0.00108}, # for openrouter, start
|
||||
"meta-llama/llama-3-70b-instruct": {"prompt": 0.008, "completion": 0.008},
|
||||
|
|
@ -283,6 +290,18 @@ TOKEN_MAX = {
|
|||
"claude-3-haiku-20240307": 200000,
|
||||
"yi-34b-chat-0205": 4000,
|
||||
"yi-34b-chat-200k": 200000,
|
||||
"openai/gpt-4": 8192, # start, for openrouter
|
||||
"openai/gpt-4-turbo": 128000,
|
||||
"openai/gpt-4o": 128000,
|
||||
"openai/gpt-4o-2024-05-13": 128000,
|
||||
"openai/gpt-4o-mini": 128000,
|
||||
"openai/gpt-4o-mini-2024-07-18": 128000,
|
||||
"google/gemini-flash-1.5": 2800000,
|
||||
"deepseek/deepseek-coder": 128000,
|
||||
"deepseek/deepseek-chat": 128000, # end, for openrouter
|
||||
"deepseek-chat": 128000,
|
||||
"deepseek-coder": 128000,
|
||||
"deepseek-ai/DeepSeek-Coder-V2-Instruct": 32000, # siliconflow
|
||||
"yi-large": 16385,
|
||||
"microsoft/wizardlm-2-8x22b": 65536,
|
||||
"meta-llama/llama-3-70b-instruct": 8192,
|
||||
|
|
@ -294,8 +313,6 @@ TOKEN_MAX = {
|
|||
"anthropic/claude-3-opus": 200000,
|
||||
"anthropic/claude-3.5-sonnet": 200000,
|
||||
"google/gemini-pro-1.5": 4000000,
|
||||
"deepseek-chat": 32768,
|
||||
"deepseek-coder": 16385,
|
||||
"doubao-lite-4k-240515": 4000,
|
||||
"doubao-lite-32k-240515": 32000,
|
||||
"doubao-lite-128k-240515": 128000,
|
||||
|
|
@ -387,7 +404,7 @@ SPARK_TOKENS = {
|
|||
}
|
||||
|
||||
|
||||
def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
|
||||
def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
|
||||
"""Return the number of tokens used by a list of messages."""
|
||||
if "claude" in model:
|
||||
# rough estimation for models newer than claude-2.1
|
||||
|
|
@ -438,10 +455,10 @@ def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
|
|||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
elif "gpt-3.5-turbo" == model:
|
||||
logger.info("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.")
|
||||
return count_input_tokens(messages, model="gpt-3.5-turbo-0125")
|
||||
return count_message_tokens(messages, model="gpt-3.5-turbo-0125")
|
||||
elif "gpt-4" == model:
|
||||
logger.info("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
||||
return count_input_tokens(messages, model="gpt-4-0613")
|
||||
return count_message_tokens(messages, model="gpt-4-0613")
|
||||
elif "open-llm-model" == model:
|
||||
"""
|
||||
For self-hosted open_llm api, they include lots of different models. The message tokens calculation is
|
||||
|
|
@ -507,16 +524,4 @@ def get_max_completion_tokens(messages: list[dict], model: str, default: int) ->
|
|||
"""
|
||||
if model not in TOKEN_MAX:
|
||||
return default
|
||||
return TOKEN_MAX[model] - count_input_tokens(messages) - 1
|
||||
|
||||
|
||||
async def get_openrouter_tokens(chunk: ChatCompletionChunk) -> CompletionUsage:
    """Look up token usage for an OpenRouter generation.

    Refs to https://openrouter.ai/docs#querying-cost-and-stats

    Args:
        chunk: A streamed completion chunk; its ``id`` identifies the generation to query.

    Returns:
        A ``CompletionUsage`` built from the ``tokens_prompt`` and ``tokens_completion``
        fields of the response (each defaulting to 0 when absent), with
        ``total_tokens`` set to their sum.
    """
    stats = await apost(url=f"https://openrouter.ai/api/v1/generation?id={chunk.id}", as_json=True)
    prompt_count = stats.get("tokens_prompt", 0)
    completion_count = stats.get("tokens_completion", 0)
    return CompletionUsage(
        prompt_tokens=prompt_count,
        completion_tokens=completion_count,
        total_tokens=prompt_count + completion_count,
    )
|
||||
return TOKEN_MAX[model] - count_message_tokens(messages, model) - 1
|
||||
|
|
|
|||
|
|
@ -27,14 +27,15 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from gitignore_parser import parse_gitignore
|
||||
|
||||
from metagpt.tools.libs.shell import shell_execute
|
||||
|
||||
def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str:
|
||||
|
||||
async def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = False) -> str:
|
||||
"""
|
||||
Recursively traverses the directory structure and prints it out in a tree-like format.
|
||||
|
||||
|
|
@ -80,7 +81,7 @@ def tree(root: str | Path, gitignore: str | Path = None, run_command: bool = Fal
|
|||
"""
|
||||
root = Path(root).resolve()
|
||||
if run_command:
|
||||
return _execute_tree(root, gitignore)
|
||||
return await _execute_tree(root, gitignore)
|
||||
|
||||
git_ignore_rules = parse_gitignore(gitignore) if gitignore else None
|
||||
dir_ = {root.name: _list_children(root=root, git_ignore_rules=git_ignore_rules)}
|
||||
|
|
@ -129,12 +130,7 @@ def _add_line(rows: List[str]) -> List[str]:
|
|||
return rows
|
||||
|
||||
|
||||
def _execute_tree(root: Path, gitignore: str | Path) -> str:
|
||||
async def _execute_tree(root: Path, gitignore: str | Path) -> str:
|
||||
args = ["--gitfile", str(gitignore)] if gitignore else []
|
||||
try:
|
||||
result = subprocess.run(["tree"] + args + [str(root)], capture_output=True, text=True, check=True)
|
||||
if result.returncode != 0:
|
||||
raise ValueError(f"tree exits with code {result.returncode}")
|
||||
return result.stdout
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise e
|
||||
stdout, _, _ = await shell_execute(["tree"] + args + [str(root)])
|
||||
return stdout
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue