FIx Format and Some bugs in android_assistant.py

This commit is contained in:
didi 2024-03-04 16:47:27 +08:00
parent f58012611c
commit 138bb6e63d
16 changed files with 510 additions and 177 deletions

View file

@ -24,13 +24,16 @@ import re
import sys
import traceback
import typing
from io import BytesIO
from pathlib import Path
from typing import Any, List, Tuple, Union, Callable
from typing import Any, Callable, List, Tuple, Union
import aiofiles
import loguru
import requests
from PIL import Image
from pydantic_core import to_jsonable_python
from tenacity import RetryCallState, _utils
from tenacity import RetryCallState, RetryError, _utils
from metagpt.const import MESSAGE_ROUTE_TO_ALL
from metagpt.logs import logger
@ -214,7 +217,7 @@ class OutputParser:
if start_index != -1 and end_index != -1:
# Extract the structure part
structure_text = text[start_index: end_index + 1]
structure_text = text[start_index : end_index + 1]
try:
# Attempt to convert the text to a Python data type using ast.literal_eval
@ -358,6 +361,31 @@ def parse_recipient(text):
return ""
def create_func_call_config(func_schema: dict) -> dict:
"""Create new function call config"""
tools = [{"type": "function", "function": func_schema}]
tool_choice = {"type": "function", "function": {"name": func_schema["name"]}}
return {
"tools": tools,
"tool_choice": tool_choice,
}
def remove_comments(code_str: str) -> str:
"""Remove comments from code."""
pattern = r"(\".*?\"|\'.*?\')|(\#.*?$)"
def replace_func(match):
if match.group(2) is not None:
return ""
else:
return match.group(1)
clean_code = re.sub(pattern, replace_func, code_str, flags=re.MULTILINE)
clean_code = os.linesep.join([s.rstrip() for s in clean_code.splitlines() if s.strip()])
return clean_code
def get_class_name(cls) -> str:
"""Return class name"""
return f"{cls.__module__}.{cls.__name__}"
@ -466,13 +494,13 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
return data
def write_json_file(json_file: str, data: list, encoding=None):
def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4):
folder_path = Path(json_file).parent
if not folder_path.exists():
folder_path.mkdir(parents=True, exist_ok=True)
with open(json_file, "w", encoding=encoding) as fout:
json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)
json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python)
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
@ -538,7 +566,7 @@ def role_raise_decorator(func):
self.rc.memory.delete(self.latest_observed_msg)
# raise again to make it captured outside
raise Exception(format_trackback_info(limit=None))
except Exception:
except Exception as e:
if self.latest_observed_msg:
logger.warning(
"There is a exception in role's execution, in order to resume, "
@ -547,6 +575,12 @@ def role_raise_decorator(func):
# remove role newest observed msg to make it observed again
self.rc.memory.delete(self.latest_observed_msg)
# raise again to make it captured outside
if isinstance(e, RetryError):
last_error = e.last_attempt._exception
name = any_to_str(last_error)
if re.match(r"^openai\.", name) or re.match(r"^httpx\.", name):
raise last_error
raise Exception(format_trackback_info(limit=None))
return wrapper
@ -606,6 +640,39 @@ def is_coroutine_func(func: Callable) -> bool:
return inspect.iscoroutinefunction(func)
def encode_image(image_path: Path, encoding: str = "utf-8") -> str:
with open(str(image_path), "rb") as image_file:
return base64.b64encode(image_file.read()).decode(encoding)
def load_mc_skills_code(skill_names: list[str] = None, skills_dir: Path = None) -> list[str]:
"""load mincraft skill from js files"""
if not skills_dir:
skills_dir = Path(__file__).parent.absolute()
if skill_names is None:
skill_names = [skill[:-3] for skill in os.listdir(f"{skills_dir}") if skill.endswith(".js")]
skills = [skills_dir.joinpath(f"{skill_name}.js").read_text() for skill_name in skill_names]
return skills
def encode_image(image_path_or_pil: Union[Path, Image], encoding: str = "utf-8") -> str:
"""encode image from file or PIL.Image into base64"""
if isinstance(image_path_or_pil, Image.Image):
buffer = BytesIO()
image_path_or_pil.save(buffer, format="JPEG")
bytes_data = buffer.getvalue()
else:
if not image_path_or_pil.exists():
raise FileNotFoundError(f"{image_path_or_pil} not exists")
with open(str(image_path_or_pil), "rb") as image_file:
bytes_data = image_file.read()
return base64.b64encode(bytes_data).decode(encoding)
def decode_image(img_url_or_b64: str) -> Image:
"""decode image from url or base64 into PIL.Image"""
if img_url_or_b64.startswith("http"):
# image http(s) url
resp = requests.get(img_url_or_b64)
img = Image.open(BytesIO(resp.content))
else:
# image b64_json
b64_data = re.sub("^data:image/.+;base64,", "", img_url_or_b64)
img_data = BytesIO(base64.b64decode(b64_data))
img = Image.open(img_data)
return img