Mirror of https://github.com/FoundationAgents/MetaGPT.git (synced 2026-04-28 10:26:32 +02:00)

Commit 24e617b362: Merge main branch
325 changed files with 11290 additions and 3760 deletions
@@ -13,7 +13,7 @@ from metagpt.actions.add_requirement import UserRequirement
 from metagpt.actions.debug_error import DebugError
 from metagpt.actions.design_api import WriteDesign
 from metagpt.actions.design_api_review import DesignReview
-from metagpt.actions.project_management import AssignTasks, WriteTasks
+from metagpt.actions.project_management import WriteTasks
 from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch
 from metagpt.actions.run_code import RunCode
 from metagpt.actions.search_and_summarize import SearchAndSummarize

@@ -38,7 +38,6 @@ class ActionType(Enum):
     RUN_CODE = RunCode
     DEBUG_ERROR = DebugError
     WRITE_TASKS = WriteTasks
-    ASSIGN_TASKS = AssignTasks
     SEARCH_AND_SUMMARIZE = SearchAndSummarize
     COLLECT_LINKS = CollectLinks
     WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
@@ -8,61 +8,45 @@
 from __future__ import annotations

-from typing import Any, Optional, Union
+from typing import Optional, Union

-from pydantic import BaseModel, Field
+from pydantic import ConfigDict, Field, model_validator

 from metagpt.actions.action_node import ActionNode
 from metagpt.llm import LLM
-from metagpt.provider.base_gpt_api import BaseGPTAPI
+from metagpt.provider.base_llm import BaseLLM
 from metagpt.schema import (
     CodeSummarizeContext,
     CodingContext,
     RunCodeContext,
+    SerializationMixin,
     TestingContext,
 )

-action_subclass_registry = {}
-

-class Action(BaseModel):
+class Action(SerializationMixin, is_polymorphic_base=True):
+    model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
+
     name: str = ""
-    llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True)
+    llm: BaseLLM = Field(default_factory=LLM, exclude=True)
     context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
-    prefix = ""  # prefix is prepended as the system_message when calling aask*
-    desc = ""  # for skill manager
+    prefix: str = ""  # prefix is prepended as the system_message when calling aask*
+    desc: str = ""  # for skill manager
     node: ActionNode = Field(default=None, exclude=True)

-    # builtin variables
-    builtin_class_name: str = ""
-
-    class Config:
-        arbitrary_types_allowed = True
-
-    def __init_with_instruction(self, instruction: str):
-        """Initialize action with instruction"""
-        self.node = ActionNode(key=self.name, expected_type=str, instruction=instruction, example="", schema="raw")
-        return self
-
-    def __init__(self, **kwargs: Any):
-        super().__init__(**kwargs)
-
-        # deserialize child classes dynamically for inherited `action`
-        object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
-        self.__fields__["builtin_class_name"].default = self.__class__.__name__
-
-        if "instruction" in kwargs:
-            self.__init_with_instruction(kwargs["instruction"])
-
-    def __init_subclass__(cls, **kwargs: Any) -> None:
-        super().__init_subclass__(**kwargs)
-        action_subclass_registry[cls.__name__] = cls
-
-    def dict(self, *args, **kwargs) -> "DictStrAny":
-        obj_dict = super().dict(*args, **kwargs)
-        if "llm" in obj_dict:
-            obj_dict.pop("llm")
-        return obj_dict
+    @model_validator(mode="before")
+    def set_name_if_empty(cls, values):
+        if "name" not in values or not values["name"]:
+            values["name"] = cls.__name__
+        return values
+
+    @model_validator(mode="before")
+    def _init_with_instruction(cls, values):
+        if "instruction" in values:
+            name = values["name"]
+            i = values["instruction"]
+            values["node"] = ActionNode(key=name, expected_type=str, instruction=i, example="", schema="raw")
+        return values

     def set_prefix(self, prefix):
         """Set prefix for later usage"""
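Note: the rewrite above replaces the hand-rolled `__init__`/`__init_subclass__` registry machinery with pydantic v2 `model_validator(mode="before")` hooks. A minimal, self-contained sketch of that pattern (class names here are illustrative, not from the repo):

```python
from pydantic import BaseModel, ConfigDict, model_validator


class NamedAction(BaseModel):
    """Default the `name` field to the concrete subclass name before validation."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str = ""

    @model_validator(mode="before")
    @classmethod
    def set_name_if_empty(cls, values: dict) -> dict:
        # Runs before field validation; `values` is the raw keyword dict.
        if not values.get("name"):
            values["name"] = cls.__name__
        return values


class WriteSomething(NamedAction):
    pass


assert WriteSomething().name == "WriteSomething"
```

Because the validator runs at construction time for every subclass, no custom `__init__` or class registry is needed.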
@@ -11,12 +11,13 @@ NOTE: You should use typing.List instead of list to do type annotation. Because
 import json
 from typing import Any, Dict, List, Optional, Tuple, Type

-from pydantic import BaseModel, create_model, root_validator, validator
+from pydantic import BaseModel, create_model, model_validator
 from tenacity import retry, stop_after_attempt, wait_random_exponential

-from metagpt.llm import BaseGPTAPI
+from metagpt.config import CONFIG
+from metagpt.llm import BaseLLM
 from metagpt.logs import logger
-from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
+from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
 from metagpt.utils.common import OutputParser, general_after_log

 TAG = "CONTENT"
@@ -59,7 +60,7 @@ class ActionNode:

     # Action Context
     context: str  # all the context, including all necessary info
-    llm: BaseGPTAPI  # LLM with aask interface
+    llm: BaseLLM  # LLM with aask interface
     children: dict[str, "ActionNode"]

     # Action Input
@@ -116,50 +117,48 @@ class ActionNode:
         obj.add_children(nodes)
         return obj

-    def get_children_mapping(self) -> Dict[str, Tuple[Type, Any]]:
+    def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
         """Get the dict of child ActionNodes, indexed by key"""
-        return {k: (v.expected_type, ...) for k, v in self.children.items()}
+        exclude = exclude or []
+        return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude}

     def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
         """get self key: type mapping"""
         return {self.key: (self.expected_type, ...)}

-    def get_mapping(self, mode="children") -> Dict[str, Tuple[Type, Any]]:
+    def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]:
         """get key: type mapping under mode"""
         if mode == "children" or (mode == "auto" and self.children):
-            return self.get_children_mapping()
-        return self.get_self_mapping()
+            return self.get_children_mapping(exclude=exclude)
+        return {} if exclude and self.key in exclude else self.get_self_mapping()

     @classmethod
     def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
         """Dynamically generate a pydantic model, used to check the type correctness of results"""
-        new_class = create_model(class_name, **mapping)
-
-        @validator("*", allow_reuse=True)
-        def check_name(v, field):
-            if field.name not in mapping.keys():
-                raise ValueError(f"Unrecognized block: {field.name}")
-            return v
-
-        @root_validator(pre=True, allow_reuse=True)
-        def check_missing_fields(values):
+        def check_fields(cls, values):
             required_fields = set(mapping.keys())
             missing_fields = required_fields - set(values.keys())
             if missing_fields:
                 raise ValueError(f"Missing fields: {missing_fields}")

             unrecognized_fields = set(values.keys()) - required_fields
             if unrecognized_fields:
                 logger.warning(f"Unrecognized fields: {unrecognized_fields}")
             return values

-        new_class.__validator_check_name = classmethod(check_name)
-        new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields)
+        validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)}
+
+        new_class = create_model(class_name, __validators__=validators, **mapping)
         return new_class

-    def create_children_class(self):
+    def create_children_class(self, exclude=None):
         """Generate a model_class directly from the fields on the object"""
         class_name = f"{self.key}_AN"
-        mapping = self.get_children_mapping()
+        mapping = self.get_children_mapping(exclude=exclude)
         return self.create_model_class(class_name, mapping)

-    def to_dict(self, format_func=None, mode="auto") -> Dict:
+    def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict:
         """Organize the current node and its children into a dict keyed as node: format"""

         # If no format function is provided, use the default formatting
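The `create_model_class` change above moves from pydantic v1's attach-after-the-fact `@validator`/`@root_validator` style to passing validators into `create_model` via `__validators__`, mirroring the new code exactly. A minimal sketch of the mechanism, with invented field names:

```python
from typing import List

from pydantic import create_model, model_validator

# ActionNode-style mapping: field name -> (type, default); `...` marks the field required.
mapping = {"project_name": (str, ...), "file_list": (List[str], ...)}


def check_fields(cls, values: dict) -> dict:
    # Reject LLM output that is missing any required block.
    missing = set(mapping) - set(values)
    if missing:
        raise ValueError(f"Missing fields: {missing}")
    return values


validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)}
DesignAN = create_model("DesignAN", __validators__=validators, **mapping)

obj = DesignAN(project_name="web_2048", file_list=["main.py", "game.py"])
print(obj.model_dump())  # {'project_name': 'web_2048', 'file_list': ['main.py', 'game.py']}
```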
@@ -179,7 +178,10 @@ class ActionNode:
             return node_dict

         # Traverse the child nodes and call to_dict recursively
+        exclude = exclude or []
         for _, child_node in self.children.items():
+            if child_node.key in exclude:
+                continue
             node_dict.update(child_node.to_dict(format_func))

         return node_dict
@@ -200,25 +202,25 @@ class ActionNode:
         else:  # markdown
             return f"[{tag}]\n" + text + f"\n[/{tag}]"

-    def _compile_f(self, schema, mode, tag, format_func, kv_sep) -> str:
-        nodes = self.to_dict(format_func=format_func, mode=mode)
+    def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str:
+        nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude)
         text = self.compile_to(nodes, schema, kv_sep)
         return self.tagging(text, schema, tag)

-    def compile_instruction(self, schema="markdown", mode="children", tag="") -> str:
+    def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str:
         """compile to raw/json/markdown template with all/root/children nodes"""
         format_func = lambda i: f"{i.expected_type} # {i.instruction}"
-        return self._compile_f(schema, mode, tag, format_func, kv_sep=": ")
+        return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude)

-    def compile_example(self, schema="json", mode="children", tag="") -> str:
+    def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str:
         """compile to raw/json/markdown examples with all/root/children nodes"""

         # f-strings cannot be used here: once converted to str, json.dumps adds extra quotes,
         # so the result is no longer a valid example.
         # Bad example: "File list": "['main.py', 'const.py', 'game.py']"; note the value here is a str, not a list.
         format_func = lambda i: i.example
-        return self._compile_f(schema, mode, tag, format_func, kv_sep="\n")
+        return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude)

-    def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE) -> str:
+    def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str:
         """
         mode: all/root/children
         mode="children": compile all child nodes into one unified template, including instruction and example
@@ -234,8 +236,8 @@ class ActionNode:

         # FIXME: json instruction causes formatting problems, e.g.: "Project name": "web_2048  # use underscores in the project name",
         # compile example does not support markdown for now
-        instruction = self.compile_instruction(schema="markdown", mode=mode)
-        example = self.compile_example(schema=schema, tag=TAG, mode=mode)
+        instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude)
+        example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude)
         # nodes = ", ".join(self.to_dict(mode=mode).keys())
         constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT]
         constraint = "\n".join(constraints)
@@ -260,14 +262,17 @@ class ActionNode:
         output_data_mapping: dict,
         system_msgs: Optional[list[str]] = None,
         schema="markdown",  # compatible to original format
+        timeout=CONFIG.timeout,
     ) -> (str, BaseModel):
         """Use ActionOutput to wrap the output of aask"""
-        content = await self.llm.aask(prompt, system_msgs)
+        content = await self.llm.aask(prompt, system_msgs, timeout=timeout)
         logger.debug(f"llm raw output:\n{content}")
         output_class = self.create_model_class(output_class_name, output_data_mapping)

         if schema == "json":
-            parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key=f"[/{TAG}]")
+            parsed_data = llm_output_postprocess(
+                output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
+            )
         else:  # using markdown parser
             parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
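`output_class.schema()` becoming `output_class.model_json_schema()` is the pydantic v2 rename of the JSON-schema accessor. For illustration:

```python
from pydantic import BaseModel


class Task(BaseModel):
    name: str


# pydantic v1: Task.schema()
# pydantic v2: same information, new accessor name
print(Task.model_json_schema())
# {'properties': {'name': {'title': 'Name', 'type': 'string'}}, 'required': ['name'],
#  'title': 'Task', 'type': 'object'}
```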
@@ -276,7 +281,7 @@ class ActionNode:
         return content, instruct_content

     def get(self, key):
-        return self.instruct_content.dict()[key]
+        return self.instruct_content.model_dump()[key]

     def set_recursive(self, name, value):
         setattr(self, name, value)
@@ -289,13 +294,13 @@ class ActionNode:
     def set_context(self, context):
         self.set_recursive("context", context)

-    async def simple_fill(self, schema, mode):
-        prompt = self.compile(context=self.context, schema=schema, mode=mode)
+    async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None):
+        prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)

         if schema != "raw":
-            mapping = self.get_mapping(mode)
+            mapping = self.get_mapping(mode, exclude=exclude)
             class_name = f"{self.key}_AN"
-            content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema)
+            content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout)
             self.content = content
             self.instruct_content = scontent
         else:
@@ -304,7 +309,7 @@ class ActionNode:

         return self

-    async def fill(self, context, llm, schema="json", mode="auto", strgy="simple"):
+    async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]):
         """Fill the node(s) with mode.

         :param context: Everything we should know when filling node.
@@ -320,6 +325,8 @@ class ActionNode:
         :param strgy: simple/complex
          - simple: run only once
          - complex: run each node
+        :param timeout: Timeout for llm invocation.
+        :param exclude: The keys of ActionNode to exclude.
         :return: self
         """
         self.set_llm(llm)
@@ -328,27 +335,15 @@ class ActionNode:
             schema = self.schema

         if strgy == "simple":
-            return await self.simple_fill(schema=schema, mode=mode)
+            return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
         elif strgy == "complex":
             # This implicitly assumes the node has children
             tmp = {}
             for _, i in self.children.items():
-                child = await i.simple_fill(schema=schema, mode=mode)
+                if exclude and i.key in exclude:
+                    continue
+                child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
                 tmp.update(child.instruct_content.dict())
             cls = self.create_children_class()
             self.instruct_content = cls(**tmp)
             return self
-
-
-def action_node_example():
-    node = ActionNode(key="key-0", expected_type=str, instruction="instruction-a", example="example-b")
-
-    logger.info(node.compile(context="123", schema="raw", mode="auto"))
-    logger.info(node.compile(context="123", schema="json", mode="auto"))
-    logger.info(node.compile(context="123", schema="markdown", mode="auto"))
-    logger.info(node.to_dict())
-    logger.info(node)
-
-
-if __name__ == "__main__":
-    action_node_example()
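With `timeout` and `exclude` now threaded through `fill → simple_fill → compile/get_mapping`, a caller can skip selected child nodes in both the generated prompt and the dynamic validation model. A hypothetical usage sketch (node keys and context invented; requires a configured LLM):

```python
import asyncio

from metagpt.actions.action_node import ActionNode
from metagpt.llm import LLM

requirements = ActionNode(key="Requirements", expected_type=str, instruction="List the requirements", example="")
unclear = ActionNode(key="Anything UNCLEAR", expected_type=str, instruction="Note open questions", example="")
root = ActionNode.from_children("Design", [requirements, unclear])


async def main():
    # "Anything UNCLEAR" is omitted from the prompt and from the dynamic output model.
    node = await root.fill(context="Build a 2048 game", llm=LLM(), exclude=["Anything UNCLEAR"], timeout=60)
    print(node.instruct_content.model_dump())


asyncio.run(main())
```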
@@ -10,6 +10,3 @@ from metagpt.actions import Action

 class UserRequirement(Action):
     """User Requirement without any implementation details"""
-
-    async def run(self, *args, **kwargs):
-        raise NotImplementedError
@@ -1,70 +0,0 @@
-import traceback
-from pathlib import Path
-
-from pydantic import Field
-
-from metagpt.actions.write_code import WriteCode
-from metagpt.llm import LLM
-from metagpt.logs import logger
-from metagpt.provider.base_gpt_api import BaseGPTAPI
-from metagpt.schema import Message
-from metagpt.utils.highlight import highlight
-
-CLONE_PROMPT = """
-*context*
-Please convert the function code ```{source_code}``` into the the function format: ```{template_func}```.
-*Please Write code based on the following list and context*
-1. Write code start with ```, and end with ```.
-2. Please implement it in one function if possible, except for import statements. for exmaple:
-```python
-import pandas as pd
-def run(*args) -> pd.DataFrame:
-    ...
-```
-3. Do not use public member functions that do not exist in your design.
-4. The output function name, input parameters and return value must be the same as ```{template_func}```.
-5. Make sure the results before and after the code conversion are required to be exactly the same.
-6. Don't repeat my context in your replies.
-7. Return full results, for example, if the return value has df.head(), please return df.
-8. If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ...
-"""
-
-
-class CloneFunction(WriteCode):
-    name: str = "CloneFunction"
-    context: list[Message] = []
-    llm: BaseGPTAPI = Field(default_factory=LLM)
-
-    def _save(self, code_path, code):
-        if isinstance(code_path, str):
-            code_path = Path(code_path)
-        code_path.parent.mkdir(parents=True, exist_ok=True)
-        code_path.write_text(code)
-        logger.info(f"Saving Code to {code_path}")
-
-    async def run(self, template_func: str, source_code: str) -> str:
-        """Convert source_code to the same parameters and return type as template_func"""
-        prompt = CLONE_PROMPT.format(source_code=source_code, template_func=template_func)
-        logger.info(f"query for CloneFunction: \n {prompt}")
-        code = await self.write_code(prompt)
-        logger.info(f"CloneFunction code is \n {highlight(code)}")
-        return code
-
-
-def run_function_code(func_code: str, func_name: str, *args, **kwargs):
-    """Run function code from string code."""
-    try:
-        locals_ = {}
-        exec(func_code, locals_)
-        func = locals_[func_name]
-        return func(*args, **kwargs), ""
-    except Exception:
-        return "", traceback.format_exc()
-
-
-def run_function_script(code_script_path: str, func_name: str, *args, **kwargs):
-    """Run function code from script."""
-    if isinstance(code_script_path, str):
-        code_path = Path(code_script_path)
-    code = code_path.read_text(encoding="utf-8")
-    return run_function_code(code, func_name, *args, **kwargs)
@@ -15,7 +15,6 @@ from pydantic import Field
 from metagpt.actions.action import Action
 from metagpt.config import CONFIG
 from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
-from metagpt.llm import LLM, BaseGPTAPI
 from metagpt.logs import logger
 from metagpt.schema import RunCodeContext, RunCodeResult
 from metagpt.utils.common import CodeParser
@@ -52,7 +51,6 @@ Now you should start rewriting the code:
 class DebugError(Action):
     name: str = "DebugError"
     context: RunCodeContext = Field(default_factory=RunCodeContext)
-    llm: BaseGPTAPI = Field(default_factory=LLM)

     async def run(self, *args, **kwargs) -> str:
         output_doc = await FileRepository.get_file(
@@ -13,8 +13,6 @@
 import json
 from pathlib import Path
 from typing import Optional

-from pydantic import Field
-
 from metagpt.actions import Action, ActionOutput
 from metagpt.actions.design_api_an import DESIGN_API_NODE, REFINED_DESIGN_NODES
 from metagpt.config import CONFIG
@@ -25,9 +23,7 @@ from metagpt.const import (
     SYSTEM_DESIGN_FILE_REPO,
     SYSTEM_DESIGN_PDF_FILE_REPO,
 )
-from metagpt.llm import LLM
 from metagpt.logs import logger
-from metagpt.provider.base_gpt_api import BaseGPTAPI
 from metagpt.schema import Document, Documents, Message
 from metagpt.utils.file_repository import FileRepository
 from metagpt.utils.mermaid import mermaid_to_file
@@ -44,7 +40,6 @@ NEW_REQ_TEMPLATE = """
 class WriteDesign(Action):
     name: str = ""
     context: Optional[str] = None
-    llm: BaseGPTAPI = Field(default_factory=LLM)
     desc: str = (
         "Based on the PRD, think about the system design, and design the corresponding APIs, "
         "data structures, library tables, processes, and paths. Please provide your design, feedback "
@@ -52,10 +47,10 @@ class WriteDesign(Action):
     )

     async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema):
-        # Use `git diff` to identify which PRD documents have been modified in the `docs/prds` directory.
+        # Use `git status` to identify which PRD documents have been modified in the `docs/prds` directory.
         prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO)
         changed_prds = prds_file_repo.changed_files
-        # Use `git diff` to identify which design documents in the `docs/system_designs` directory have undergone
+        # Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone
         # changes.
         system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
         changed_system_designs = system_design_file_repo.changed_files
@@ -79,7 +74,7 @@ class WriteDesign(Action):
             logger.info("Nothing has changed.")
         # Wait until all files under `docs/system_designs/` are processed before sending the publish message,
         # leaving room for global optimization in subsequent steps.
-        return ActionOutput(content=changed_files.json(), instruct_content=changed_files)
+        return ActionOutput(content=changed_files.model_dump_json(), instruct_content=changed_files)

     async def _new_system_design(self, context, schema=CONFIG.prompt_schema):
         node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
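The `.json(ensure_ascii=False)` → `.model_dump_json()` substitutions in this and the following hunks are the pydantic v2 serialization rename; v2 also stops ASCII-escaping non-ASCII text by default, which is why the `ensure_ascii=False` argument simply disappears. For illustration:

```python
from pydantic import BaseModel


class Doc(BaseModel):
    title: str


doc = Doc(title="贪吃蛇")
# pydantic v1 needed: doc.json(ensure_ascii=False) -> '{"title": "贪吃蛇"}'
# pydantic v2 keeps non-ASCII characters as-is by default:
print(doc.model_dump_json())  # '{"title":"贪吃蛇"}'
```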
@@ -88,7 +83,7 @@ class WriteDesign(Action):
     async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema):
         context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
         node = await REFINED_DESIGN_NODES.fill(context=context, llm=self.llm, schema=schema)
-        system_design_doc.content = node.instruct_content.json(ensure_ascii=False)
+        system_design_doc.content = node.instruct_content.model_dump_json()
         return system_design_doc

     async def _update_system_design(self, filename, prds_file_repo, system_design_file_repo) -> Document:
@@ -99,7 +94,7 @@ class WriteDesign(Action):
             doc = Document(
                 root_path=SYSTEM_DESIGN_FILE_REPO,
                 filename=filename,
-                content=system_design.instruct_content.json(ensure_ascii=False),
+                content=system_design.instruct_content.model_dump_json(),
             )
         else:
             doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc)
@@ -8,7 +8,6 @@
 from typing import List

 from metagpt.actions.action_node import ActionNode
-from metagpt.logs import logger
 from metagpt.utils.mermaid import MMC1, MMC1_REFINE, MMC2, MMC2_REFINE

 IMPLEMENTATION_APPROACH = ActionNode(
@@ -157,14 +156,3 @@ REFINE_NODES = [
 DESIGN_API_NODE = ActionNode.from_children("DesignAPI", NODES)
 INCREMENTAL_DESIGN_NODES = ActionNode.from_children("Incremental_Design_API", INC_NODES)
 REFINED_DESIGN_NODES = ActionNode.from_children("Refined_Design_API", REFINE_NODES)
-
-
-def main():
-    prompt = DESIGN_API_NODE.compile(context="")
-    logger.info(prompt)
-    prompt = REFINED_DESIGN_NODES.compile(context="")
-    logger.info(prompt)
-
-
-if __name__ == "__main__":
-    main()
@@ -8,17 +8,12 @@

 from typing import Optional

-from pydantic import Field
-
 from metagpt.actions.action import Action
-from metagpt.llm import LLM
-from metagpt.provider.base_gpt_api import BaseGPTAPI


 class DesignReview(Action):
     name: str = "DesignReview"
     context: Optional[str] = None
-    llm: BaseGPTAPI = Field(default_factory=LLM)

     async def run(self, prd, api_design):
         prompt = (
@@ -6,18 +6,14 @@
 @File    : execute_task.py
 """

-from pydantic import Field
-
 from metagpt.actions import Action
-from metagpt.llm import LLM
-from metagpt.provider.base_gpt_api import BaseGPTAPI
 from metagpt.schema import Message


 class ExecuteTask(Action):
     name: str = "ExecuteTask"
     context: list[Message] = []
-    llm: BaseGPTAPI = Field(default_factory=LLM)

-    def run(self, *args, **kwargs):
+    async def run(self, *args, **kwargs):
         pass
@@ -11,6 +11,3 @@ class FixBug(Action):
     """Fix bug action without any implementation details"""

     name: str = "FixBug"
-
-    async def run(self, *args, **kwargs):
-        raise NotImplementedError
@@ -26,7 +26,7 @@ from metagpt.prompts.invoice_ocr import (
     EXTRACT_OCR_MAIN_INFO_PROMPT,
     REPLY_OCR_QUESTION_PROMPT,
 )
-from metagpt.provider.base_gpt_api import BaseGPTAPI
+from metagpt.provider.base_llm import BaseLLM
 from metagpt.utils.common import OutputParser
 from metagpt.utils.file import File

@@ -42,7 +42,6 @@ class InvoiceOCR(Action):

     name: str = "InvoiceOCR"
     context: Optional[str] = None
-    llm: BaseGPTAPI = Field(default_factory=LLM)

     @staticmethod
     async def _check_file_type(file_path: Path) -> str:
@@ -132,7 +131,7 @@ class GenerateTable(Action):

     name: str = "GenerateTable"
     context: Optional[str] = None
-    llm: BaseGPTAPI = Field(default_factory=LLM)
+    llm: BaseLLM = Field(default_factory=LLM)
     language: str = "ch"

     async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]:
@@ -177,7 +176,7 @@ class ReplyQuestion(Action):

     name: str = "ReplyQuestion"
     context: Optional[str] = None
-    llm: BaseGPTAPI = Field(default_factory=LLM)
+    llm: BaseLLM = Field(default_factory=LLM)
     language: str = "ch"

     async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str:
@@ -11,13 +11,9 @@ import shutil
 from pathlib import Path
 from typing import Optional

-from pydantic import Field
-
 from metagpt.actions import Action, ActionOutput
 from metagpt.config import CONFIG
 from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
-from metagpt.llm import LLM
-from metagpt.provider.base_gpt_api import BaseGPTAPI
 from metagpt.schema import Document
 from metagpt.utils.file_repository import FileRepository
 from metagpt.utils.git_repository import GitRepository
@@ -28,17 +24,18 @@ class PrepareDocuments(Action):

     name: str = "PrepareDocuments"
     context: Optional[str] = None
-    llm: BaseGPTAPI = Field(default_factory=LLM)

     def _init_repo(self):
         """Initialize the Git environment."""
-        path = CONFIG.project_path
-        if not path:
+        if not CONFIG.project_path:
             name = CONFIG.project_name or FileRepository.new_filename()
             path = Path(CONFIG.workspace_path) / name
-
-        if Path(path).exists() and not CONFIG.inc:
+        else:
+            path = Path(CONFIG.project_path)
+        if path.exists() and not CONFIG.inc:
             shutil.rmtree(path)
         CONFIG.project_path = path
         CONFIG.project_name = path.name
         CONFIG.git_repo = GitRepository(local_path=path, auto_init=True)

     async def run(self, with_messages, **kwargs):
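The `_init_repo` change fixes the path handling: previously `path` could remain a plain string taken from `CONFIG.project_path`, so the `Path(path).exists()` check and later assignments mixed types. The new branching normalizes both cases to a `Path` first; in isolation, the control flow looks like this (function name invented for the sketch):

```python
from pathlib import Path
from typing import Optional


def resolve_project_path(project_path: Optional[str], workspace_path: str, project_name: str) -> Path:
    # Both branches yield a Path, so callers can rely on Path semantics afterwards.
    if not project_path:
        return Path(workspace_path) / project_name
    return Path(project_path)
```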
@@ -13,8 +13,6 @@
 import json
 from typing import Optional

-from pydantic import Field
-
 from metagpt.actions import ActionOutput
 from metagpt.actions.action import Action
 from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODES
@@ -25,9 +23,7 @@ from metagpt.const import (
     TASK_FILE_REPO,
     TASK_PDF_FILE_REPO,
 )
-from metagpt.llm import LLM
 from metagpt.logs import logger
-from metagpt.provider.base_gpt_api import BaseGPTAPI
 from metagpt.schema import Document, Documents
 from metagpt.utils.file_repository import FileRepository

@@ -43,7 +39,6 @@ NEW_REQ_TEMPLATE = """
 class WriteTasks(Action):
     name: str = "CreateTasks"
     context: Optional[str] = None
-    llm: BaseGPTAPI = Field(default_factory=LLM)

     async def run(self, with_messages, schema=CONFIG.prompt_schema):
         system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
@@ -73,7 +68,7 @@ class WriteTasks(Action):
             logger.info("Nothing has changed.")
         # Wait until all files under `docs/tasks/` are processed before sending the publish_message, leaving room for
         # global optimization in subsequent steps.
-        return ActionOutput(content=change_files.json(), instruct_content=change_files)
+        return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)

     async def _update_tasks(self, filename, system_design_file_repo, tasks_file_repo):
         system_design_doc = await system_design_file_repo.get(filename)
@@ -83,7 +78,7 @@ class WriteTasks(Action):
         else:
             rsp = await self._run_new_tasks(context=system_design_doc.content)
         task_doc = Document(
-            root_path=TASK_FILE_REPO, filename=filename, content=rsp.instruct_content.json(ensure_ascii=False)
+            root_path=TASK_FILE_REPO, filename=filename, content=rsp.instruct_content.model_dump_json()
         )
         await tasks_file_repo.save(
             filename=filename, content=task_doc.content, dependencies={system_design_doc.root_relative_path}
@@ -94,15 +89,12 @@ class WriteTasks(Action):

     async def _run_new_tasks(self, context, schema=CONFIG.prompt_schema):
         node = await PM_NODE.fill(context, self.llm, schema)
-        # prompt_template, format_example = get_template(templates, format)
-        # prompt = prompt_template.format(context=context, format_example=format_example)
-        # rsp = await self._aask_v1(prompt, "task", OUTPUT_MAPPING, format=format)
         return node

     async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document:
         context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content)
         node = await REFINED_PM_NODES.fill(context, self.llm, schema)
-        task_doc.content = node.instruct_content.json(ensure_ascii=False)
+        task_doc.content = node.instruct_content.model_dump_json()
         return task_doc

     @staticmethod
@@ -123,9 +115,3 @@ class WriteTasks(Action):
     @staticmethod
     async def _save_pdf(task_doc):
         await FileRepository.save_as(doc=task_doc, with_suffix=".md", relative_path=TASK_PDF_FILE_REPO)
-
-
-class AssignTasks(Action):
-    async def run(self, *args, **kwargs):
-        # Here you should implement the actual action
-        pass
metagpt/actions/rebuild_class_view.py (new file, 68 lines)
@@ -0,0 +1,68 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/12/19
+@Author  : mashenquan
+@File    : rebuild_class_view.py
+@Desc    : Rebuild class view info
+"""
+import re
+from pathlib import Path
+
+from metagpt.actions import Action
+from metagpt.config import CONFIG
+from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO
+from metagpt.repo_parser import RepoParser
+from metagpt.utils.di_graph_repository import DiGraphRepository
+from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
+
+
+class RebuildClassView(Action):
+    def __init__(self, name="", context=None, llm=None):
+        super().__init__(name=name, context=context, llm=llm)
+
+    async def run(self, with_messages=None, format=CONFIG.prompt_schema):
+        graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
+        graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
+        repo_parser = RepoParser(base_directory=self.context)
+        class_views = await repo_parser.rebuild_class_views(path=Path(self.context))  # use pylint
+        await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
+        symbols = repo_parser.generate_symbols()  # use ast
+        for file_info in symbols:
+            await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
+        await self._create_mermaid_class_view(graph_db=graph_db)
+        await self._save(graph_db=graph_db)
+
+    async def _create_mermaid_class_view(self, graph_db):
+        pass
+        # dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO)
+        # if not dataset:
+        #     logger.warning(f"No page info for {concat_namespace(filename, class_name)}")
+        #     return
+        # code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_)
+        # src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno)
+        # code_type = ""
+        # dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS)
+        # for spo in dataset:
+        #     if spo.object_ in ["javascript", "python"]:
+        #         code_type = spo.object_
+        #         break
+
+        # try:
+        #     node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format)
+        #     class_view = node.instruct_content.model_dump()["Class View"]
+        # except Exception as e:
+        #     class_view = RepoParser.rebuild_class_view(src_code, code_type)
+        # await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view)
+        # logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}")
+
+    async def _save(self, graph_db):
+        class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
+        dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW)
+        all_class_view = []
+        for spo in dataset:
+            title = f"---\ntitle: {spo.subject}\n---\n"
+            filename = re.sub(r"[/:]", "_", spo.subject) + ".mmd"
+            await class_view_file_repo.save(filename=filename, content=title + spo.object_)
+            all_class_view.append(spo.object_)
+        await class_view_file_repo.save(filename="all.mmd", content="\n".join(all_class_view))
metagpt/actions/rebuild_class_view_an.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/12/19
+@Author  : mashenquan
+@File    : rebuild_class_view_an.py
+@Desc    : Defines `ActionNode` objects used by rebuild_class_view.py
+"""
+from metagpt.actions.action_node import ActionNode
+
+CLASS_SOURCE_CODE_BLOCK = ActionNode(
+    key="Class View",
+    expected_type=str,
+    instruction='Generate the mermaid class diagram corresponding to source code in "context."',
+    example="""
+classDiagram
+    class A {
+        -int x
+        +int y
+        -int speed
+        -int direction
+        +__init__(x: int, y: int, speed: int, direction: int)
+        +change_direction(new_direction: int) None
+        +move() None
+    }
+""",
+)
+
+REBUILD_CLASS_VIEW_NODES = [
+    CLASS_SOURCE_CODE_BLOCK,
+]
+
+REBUILD_CLASS_VIEW_NODE = ActionNode.from_children("RebuildClassView", REBUILD_CLASS_VIEW_NODES)
@@ -11,7 +11,7 @@ from metagpt.actions import Action
 from metagpt.config import CONFIG
 from metagpt.llm import LLM
 from metagpt.logs import logger
-from metagpt.provider.base_gpt_api import BaseGPTAPI
+from metagpt.provider.base_llm import BaseLLM
 from metagpt.tools.search_engine import SearchEngine
 from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
 from metagpt.utils.common import OutputParser
@@ -82,8 +82,8 @@ class CollectLinks(Action):

     name: str = "CollectLinks"
     context: Optional[str] = None
-    llm: BaseGPTAPI = Field(default_factory=LLM)
     desc: str = "Collect links from a search engine."
+
     search_engine: SearchEngine = Field(default_factory=SearchEngine)
     rank_func: Optional[Callable[[list[str]], None]] = None

@@ -129,7 +129,8 @@ class CollectLinks(Action):
             if len(remove) == 0:
                 break

-        prompt = reduce_message_length(gen_msg(), self.llm.model, system_text, CONFIG.max_tokens_rsp)
+        model_name = CONFIG.get_model_name(CONFIG.get_default_llm_provider_enum())
+        prompt = reduce_message_length(gen_msg(), model_name, system_text, CONFIG.max_tokens_rsp)
         logger.debug(prompt)
         queries = await self._aask(prompt, [system_text])
         try:
@@ -177,7 +178,7 @@ class WebBrowseAndSummarize(Action):

     name: str = "WebBrowseAndSummarize"
     context: Optional[str] = None
-    llm: BaseGPTAPI = Field(default_factory=LLM)
+    llm: BaseLLM = Field(default_factory=LLM)
     desc: str = "Explore the web and provide summaries of articles and webpages."
     browse_func: Union[Callable[[list[str]], None], None] = None
     web_browser_engine: Optional[WebBrowserEngine] = None
@@ -248,7 +249,7 @@ class ConductResearch(Action):

     name: str = "ConductResearch"
     context: Optional[str] = None
-    llm: BaseGPTAPI = Field(default_factory=LLM)
+    llm: BaseLLM = Field(default_factory=LLM)

     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -22,7 +22,6 @@ from pydantic import Field

 from metagpt.actions.action import Action
 from metagpt.config import CONFIG
-from metagpt.llm import LLM, BaseGPTAPI
 from metagpt.logs import logger
 from metagpt.schema import RunCodeContext, RunCodeResult
 from metagpt.utils.exceptions import handle_exception
@@ -79,14 +78,15 @@ standard errors:
 class RunCode(Action):
     name: str = "RunCode"
     context: RunCodeContext = Field(default_factory=RunCodeContext)
-    llm: BaseGPTAPI = Field(default_factory=LLM)

     @classmethod
-    @handle_exception
     async def run_text(cls, code) -> Tuple[str, str]:
-        # We will document_store the result in this dictionary
-        namespace = {}
-        exec(code, namespace)
+        try:
+            # We will document_store the result in this dictionary
+            namespace = {}
+            exec(code, namespace)
+        except Exception as e:
+            return "", str(e)
         return namespace.get("result", ""), ""

     @classmethod
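`run_text` now catches exceptions from `exec` itself instead of relying on the removed `@handle_exception` decorator, so a failing snippet yields `("", error)` rather than raising. The pattern in isolation (a sync sketch of the async classmethod):

```python
from typing import Tuple


def run_text(code: str) -> Tuple[str, str]:
    """Execute a code string; the snippet is expected to bind a `result` variable."""
    try:
        namespace = {}
        exec(code, namespace)
    except Exception as e:
        return "", str(e)
    return namespace.get("result", ""), ""


print(run_text("result = 1 + 1"))  # (2, '')
print(run_text("1 / 0"))           # ('', 'division by zero')
```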
@@ -8,13 +8,11 @@
 from typing import Any, Optional

 import pydantic
-from pydantic import Field, root_validator
+from pydantic import Field, model_validator

 from metagpt.actions import Action
 from metagpt.config import CONFIG, Config
-from metagpt.llm import LLM
 from metagpt.logs import logger
-from metagpt.provider.base_gpt_api import BaseGPTAPI
 from metagpt.schema import Message
 from metagpt.tools import SearchEngineType
 from metagpt.tools.search_engine import SearchEngine
@@ -105,18 +103,18 @@ You are a member of a professional butler team and will provide helpful suggesti
 """


+# TOTEST
 class SearchAndSummarize(Action):
     name: str = ""
     content: Optional[str] = None
-    llm: BaseGPTAPI = Field(default_factory=LLM)
     config: None = Field(default_factory=Config)
     engine: Optional[SearchEngineType] = CONFIG.search_engine
     search_func: Optional[Any] = None
     search_engine: SearchEngine = None
-    result = ""
+    result: str = ""

-    @root_validator
+    @model_validator(mode="before")
+    @classmethod
     def validate_engine_and_run_func(cls, values):
         engine = values.get("engine")
         search_func = values.get("search_func")
metagpt/actions/skill_action.py (new file, 111 lines)
@@ -0,0 +1,111 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/8/28
+@Author  : mashenquan
+@File    : skill_action.py
+@Desc    : Call learned skill
+"""
+from __future__ import annotations
+
+import ast
+import importlib
+import traceback
+from copy import deepcopy
+from typing import Dict, Optional
+
+from metagpt.actions import Action
+from metagpt.learn.skill_loader import Skill
+from metagpt.logs import logger
+from metagpt.schema import Message
+
+
+# TOTEST
+class ArgumentsParingAction(Action):
+    skill: Skill
+    ask: str
+    rsp: Optional[Message] = None
+    args: Optional[Dict] = None
+
+    @property
+    def prompt(self):
+        prompt = "You are a function parser. You can convert spoken words into function parameters.\n"
+        prompt += "\n---\n"
+        prompt += f"{self.skill.name} function parameters description:\n"
+        for k, v in self.skill.arguments.items():
+            prompt += f"parameter `{k}`: {v}\n"
+        prompt += "\n---\n"
+        prompt += "Examples:\n"
+        for e in self.skill.examples:
+            prompt += f"If want you to do `{e.ask}`, return `{e.answer}` brief and clear.\n"
+        prompt += "\n---\n"
+        prompt += (
+            f"\nRefer to the `{self.skill.name}` function description, and fill in the function parameters according "
+            'to the example "I want you to do xx" in the Examples section.'
+            f"\nNow I want you to do `{self.ask}`, return function parameters in Examples format above, brief and "
+            "clear."
+        )
+        return prompt
+
+    async def run(self, with_message=None, **kwargs) -> Message:
+        prompt = self.prompt
+        rsp = await self.llm.aask(msg=prompt, system_msgs=[])
+        logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}")
+        self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp)
+        self.rsp = Message(content=rsp, role="assistant", instruct_content=self.args, cause_by=self)
+        return self.rsp
+
+    @staticmethod
+    def parse_arguments(skill_name, txt) -> dict:
+        prefix = skill_name + "("
+        if prefix not in txt:
+            logger.error(f"{skill_name} not in {txt}")
+            return None
+        if ")" not in txt:
+            logger.error(f"')' not in {txt}")
+            return None
+        begin_ix = txt.find(prefix)
+        end_ix = txt.rfind(")")
+        args_txt = txt[begin_ix + len(prefix) : end_ix]
+        logger.info(args_txt)
+        fake_expression = f"dict({args_txt})"
+        parsed_expression = ast.parse(fake_expression, mode="eval")
+        args = {}
+        for keyword in parsed_expression.body.keywords:
+            key = keyword.arg
+            value = ast.literal_eval(keyword.value)
+            args[key] = value
+        return args
+
+
+class SkillAction(Action):
+    skill: Skill
+    args: Dict
+    rsp: Optional[Message] = None
+
+    async def run(self, with_message=None, **kwargs) -> Message:
+        """Run action"""
+        options = deepcopy(kwargs)
+        if self.args:
+            for k in self.args.keys():
+                if k in options:
+                    options.pop(k)
+        try:
+            rsp = await self.find_and_call_function(self.skill.name, args=self.args, **options)
+            self.rsp = Message(content=rsp, role="assistant", cause_by=self)
+        except Exception as e:
+            logger.exception(f"{e}, traceback:{traceback.format_exc()}")
+            self.rsp = Message(content=f"Error: {e}", role="assistant", cause_by=self)
+        return self.rsp
+
+    @staticmethod
+    async def find_and_call_function(function_name, args, **kwargs) -> str:
+        try:
+            module = importlib.import_module("metagpt.learn")
+            function = getattr(module, function_name)
+            # Invoke function and return result
+            result = await function(**args, **kwargs)
+            return result
+        except (ModuleNotFoundError, AttributeError):
+            logger.error(f"{function_name} not found")
+            raise ValueError(f"{function_name} not found")
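`ArgumentsParingAction.parse_arguments` extracts keyword arguments from an LLM reply such as `text_to_speech(text="Hello")` without calling `eval`: it wraps the argument list in `dict(...)`, parses it with `ast`, and `literal_eval`s each keyword value. The core of the trick, reduced to a standalone sketch (skill name and reply are invented):

```python
import ast


def parse_arguments(skill_name: str, txt: str) -> dict:
    # Slice out the "skill_name(...)" call, then evaluate only literal keyword values.
    prefix = skill_name + "("
    begin_ix = txt.find(prefix)
    end_ix = txt.rfind(")")
    args_txt = txt[begin_ix + len(prefix) : end_ix]
    parsed = ast.parse(f"dict({args_txt})", mode="eval")
    return {kw.arg: ast.literal_eval(kw.value) for kw in parsed.body.keywords}


print(parse_arguments("text_to_speech", 'text_to_speech(text="Hello", lang="en")'))
# {'text': 'Hello', 'lang': 'en'}
```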
@@ -13,7 +13,6 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
 from metagpt.actions.action import Action
 from metagpt.config import CONFIG
 from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
-from metagpt.llm import LLM, BaseGPTAPI
 from metagpt.logs import logger
 from metagpt.schema import CodeSummarizeContext
 from metagpt.utils.file_repository import FileRepository
@@ -91,10 +90,10 @@ flowchart TB
 """


+# TOTEST
 class SummarizeCode(Action):
     name: str = "SummarizeCode"
     context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
-    llm: BaseGPTAPI = Field(default_factory=LLM)

     @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60))
     async def summarize_code(self, prompt):
metagpt/actions/talk_action.py (new file, 163 lines)
@@ -0,0 +1,163 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/8/28
+@Author  : mashenquan
+@File    : talk_action.py
+@Desc    : Act as it’s a talk
+"""
+from typing import Optional
+
+from metagpt.actions import Action
+from metagpt.config import CONFIG
+from metagpt.const import DEFAULT_LANGUAGE
+from metagpt.logs import logger
+from metagpt.schema import Message
+
+
+# TOTEST
+class TalkAction(Action):
+    context: str
+    history_summary: str = ""
+    knowledge: str = ""
+    rsp: Optional[Message] = None
+
+    @property
+    def prompt(self):
+        prompt = ""
+        if CONFIG.agent_description:
+            prompt = (
+                f"You are {CONFIG.agent_description}. Your responses should align with the role-play agreement, "
+                f"maintaining the character's persona and habits. When faced with unrelated questions, playfully "
+                f"decline to answer without revealing your AI nature to preserve the character's image.\n\n"
+            )
+        prompt += f"Knowledge:\n{self.knowledge}\n\n" if self.knowledge else ""
+        prompt += f"{self.history_summary}\n\n"
+        prompt += (
+            "If the information is insufficient, you can search in the historical conversation or knowledge above.\n"
+        )
+        language = CONFIG.language or DEFAULT_LANGUAGE
+        prompt += (
+            f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.\n "
+            f"{self.context}"
+        )
+        logger.debug(f"PROMPT: {prompt}")
+        return prompt
+
+    @property
+    def prompt_gpt4(self):
+        kvs = {
+            "{role}": CONFIG.agent_description or "",
+            "{history}": self.history_summary or "",
+            "{knowledge}": self.knowledge or "",
+            "{language}": CONFIG.language or DEFAULT_LANGUAGE,
+            "{ask}": self.context,
+        }
+        prompt = TalkActionPrompt.FORMATION_LOOSE
+        for k, v in kvs.items():
+            prompt = prompt.replace(k, v)
+        logger.info(f"PROMPT: {prompt}")
+        return prompt
+
+    # async def run_old(self, *args, **kwargs) -> ActionOutput:
+    #     prompt = self.prompt
+    #     rsp = await self.llm.aask(msg=prompt, system_msgs=[])
+    #     logger.debug(f"PROMPT:{prompt}\nRESULT:{rsp}\n")
+    #     self._rsp = ActionOutput(content=rsp)
+    #     return self._rsp
+
+    @property
+    def aask_args(self):
+        language = CONFIG.language or DEFAULT_LANGUAGE
+        system_msgs = [
+            f"You are {CONFIG.agent_description}.",
+            "Your responses should align with the role-play agreement, "
+            "maintaining the character's persona and habits. When faced with unrelated questions, playfully "
+            "decline to answer without revealing your AI nature to preserve the character's image.",
+            "If the information is insufficient, you can search in the context or knowledge.",
+            f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.",
+        ]
+        format_msgs = []
+        if self.knowledge:
+            format_msgs.append({"role": "assistant", "content": self.knowledge})
+        if self.history_summary:
+            format_msgs.append({"role": "assistant", "content": self.history_summary})
+        return self.context, format_msgs, system_msgs
+
+    async def run(self, with_message=None, **kwargs) -> Message:
+        msg, format_msgs, system_msgs = self.aask_args
+        rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs)
+        self.rsp = Message(content=rsp, role="assistant", cause_by=self)
+        return self.rsp
+
+
+class TalkActionPrompt:
+    FORMATION = """Formation: "Capacity and role" defines the role you are currently playing;
+  "[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation;
+  "[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses;
+  "Statement" defines the work detail you need to complete at this stage;
+  "[ASK_BEGIN]" and [ASK_END] tags enclose the questions;
+  "Constraint" defines the conditions that your responses must comply with.
+  "Personality" defines your language style。
+  "Insight" provides a deeper understanding of the characters' inner traits.
+  "Initial" defines the initial setup of a character.
+
+Capacity and role: {role}
+Statement: Your responses should align with the role-play agreement, maintaining the
+ character's persona and habits. When faced with unrelated questions, playfully decline to answer without revealing
+ your AI nature to preserve the character's image.
+
+[HISTORY_BEGIN]
+
+{history}
+
+[HISTORY_END]
+
+[KNOWLEDGE_BEGIN]
+
+{knowledge}
+
+[KNOWLEDGE_END]
+
+Statement: If the information is insufficient, you can search in the historical conversation or knowledge.
+Statement: Unless you are a language professional, answer the following questions strictly in {language}
+, and the answers must follow the Markdown format. Strictly excluding any tag likes "[HISTORY_BEGIN]"
+, "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" in responses.
+
+
+{ask}
+"""
+
+    FORMATION_LOOSE = """Formation: "Capacity and role" defines the role you are currently playing;
+  "[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation;
+  "[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses;
+  "Statement" defines the work detail you need to complete at this stage;
+  "Constraint" defines the conditions that your responses must comply with.
+  "Personality" defines your language style。
+  "Insight" provides a deeper understanding of the characters' inner traits.
+  "Initial" defines the initial setup of a character.
+
+Capacity and role: {role}
+Statement: Your responses should maintaining the character's persona and habits. When faced with unrelated questions
+, playfully decline to answer without revealing your AI nature to preserve the character's image.
+
+[HISTORY_BEGIN]
+
+{history}
+
+[HISTORY_END]
+
+[KNOWLEDGE_BEGIN]
+
+{knowledge}
+
+[KNOWLEDGE_END]
+
+Statement: If the information is insufficient, you can search in the historical conversation or knowledge.
+Statement: Unless you are a language professional, answer the following questions strictly in {language}
+, and the answers must follow the Markdown format. Strictly excluding any tag likes "[HISTORY_BEGIN]"
+, "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" in responses.
+
+
+{ask}
+"""
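A hypothetical use of the new `TalkAction` (assumes a configured LLM and CONFIG values; the question text is invented):

```python
import asyncio

from metagpt.actions.talk_action import TalkAction


async def main():
    action = TalkAction(context="What should I cook tonight?", knowledge="", history_summary="")
    msg = await action.run()
    print(msg.content)


asyncio.run(main())
```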
@@ -31,9 +31,7 @@ from metagpt.const import (
     TASK_FILE_REPO,
     TEST_OUTPUTS_FILE_REPO,
 )
-from metagpt.llm import LLM
 from metagpt.logs import logger
-from metagpt.provider.base_gpt_api import BaseGPTAPI
 from metagpt.schema import CodingContext, Document, RunCodeResult
 from metagpt.utils.common import CodeParser
 from metagpt.utils.file_repository import FileRepository
@@ -92,7 +90,6 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenced "Format example".
 class WriteCode(Action):
     name: str = "WriteCode"
     context: Document = Field(default_factory=Document)
-    llm: BaseGPTAPI = Field(default_factory=LLM)

     @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
     async def write_code(self, prompt) -> str:
@@ -14,9 +14,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
 from metagpt.actions import WriteCode
 from metagpt.actions.action import Action
 from metagpt.config import CONFIG
-from metagpt.llm import LLM
 from metagpt.logs import logger
-from metagpt.provider.base_gpt_api import BaseGPTAPI
 from metagpt.schema import CodingContext
 from metagpt.utils.common import CodeParser

@@ -123,7 +121,6 @@ REWRITE_CODE_TEMPLATE = """
 class WriteCodeReview(Action):
     name: str = "WriteCodeReview"
     context: CodingContext = Field(default_factory=CodingContext)
-    llm: BaseGPTAPI = Field(default_factory=LLM)

     @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
     async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename):
@@ -21,15 +21,14 @@ Example:
     This script uses the 'fire' library to create a command-line interface. It generates docstrings for the given Python code using
     the specified docstring style and adds them to the code.
 """
+from __future__ import annotations
+
 import ast
 from pathlib import Path
 from typing import Literal, Optional

-from pydantic import Field
-
 from metagpt.actions.action import Action
-from metagpt.llm import LLM
-from metagpt.provider.base_gpt_api import BaseGPTAPI
-from metagpt.utils.common import OutputParser
+from metagpt.utils.common import OutputParser, aread, awrite
 from metagpt.utils.pycst import merge_docstring

 PYTHON_DOCSTRING_SYSTEM = """### Requirements
@@ -163,7 +162,6 @@ class WriteDocstring(Action):

     desc: str = "Write docstring for code."
     context: Optional[str] = None
-    llm: BaseGPTAPI = Field(default_factory=LLM)

     async def run(
         self,
@@ -187,6 +185,16 @@ class WriteDocstring(Action):
         documented_code = OutputParser.parse_python_code(documented_code)
         return merge_docstring(code, documented_code)

+    @staticmethod
+    async def write_docstring(
+        filename: str | Path, overwrite: bool = False, style: Literal["google", "numpy", "sphinx"] = "google"
+    ) -> str:
+        data = await aread(str(filename))
+        code = await WriteDocstring().run(data, style=style)
+        if overwrite:
+            await awrite(filename, code)
+        return code
+

 def _simplify_python_code(code: str) -> None:
     """Simplifies the given Python code by removing expressions and the last if statement.
@@ -207,13 +215,4 @@ def _simplify_python_code(code: str) -> None:
if __name__ == "__main__":
    import fire

    async def run(filename: str, overwrite: bool = False, style: Literal["google", "numpy", "sphinx"] = "google"):
        with open(filename) as f:
            code = f.read()
        code = await WriteDocstring().run(code, style=style)
        if overwrite:
            with open(filename, "w") as f:
                f.write(code)
        return code

    fire.Fire(run)
    fire.Fire(WriteDocstring.write_docstring)
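Note: with the CLI now delegating to WriteDocstring.write_docstring, the same entry point is reusable from code. A usage sketch, assuming the module path metagpt.actions.write_docstring, a configured LLM, and a local file my_module.py:

import asyncio

from metagpt.actions.write_docstring import WriteDocstring

async def main():
    # returns the documented source; pass overwrite=True to also rewrite the file
    documented = await WriteDocstring.write_docstring("my_module.py", overwrite=False, style="google")
    print(documented)

asyncio.run(main())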
@@ -17,12 +17,11 @@ import json
from pathlib import Path
from typing import Optional

from pydantic import Field

from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.fix_bug import FixBug
from metagpt.actions.write_prd_an import (
    PROJECT_NAME,
    REFINE_PRD_NODE,
    REFINE_PRD_TEMPLATE,
    WP_IS_RELATIVE_NODE,
@@ -38,9 +37,7 @@ from metagpt.const import (
    PRDS_FILE_REPO,
    REQUIREMENT_FILENAME,
)
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.schema import BugFixContext, Document, Documents, Message
from metagpt.utils.common import CodeParser
from metagpt.utils.file_repository import FileRepository
@@ -67,9 +64,8 @@ NEW_REQ_TEMPLATE = """


class WritePRD(Action):
    name: str = ""
    name: str = "WritePRD"
    content: Optional[str] = None
    llm: BaseGPTAPI = Field(default_factory=LLM)

    async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
        # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
@@ -81,7 +77,7 @@ class WritePRD(Action):
            await docs_file_repo.save(filename=REQUIREMENT_FILENAME, content="")
            bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
            return Message(
                content=bug_fix.json(),
                content=bug_fix.model_dump_json(),
                instruct_content=bug_fix,
                role="",
                cause_by=FixBug,
@@ -113,7 +109,7 @@ class WritePRD(Action):
        # Once all files under 'docs/prds/' have been compared with the newly added requirements, trigger the
        # 'publish' message to transition the workflow to the next stage. This design allows room for global
        # optimization in subsequent steps.
        return ActionOutput(content=change_files.json(), instruct_content=change_files)
        return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)

    async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput:
        # sas = SearchAndSummarize()
@@ -125,7 +121,8 @@ class WritePRD(Action):
        # logger.info(rsp)
        project_name = CONFIG.project_name if CONFIG.project_name else ""
        context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
        node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, schema=schema)
        exclude = [PROJECT_NAME.key] if project_name else []
        node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude)  # schema=schema
        await self._rename_workspace(node)
        return node
@@ -137,15 +134,13 @@ class WritePRD(Action):
    async def _merge(self, new_requirement_doc, prd_doc, schema=CONFIG.prompt_schema) -> Document:
        if not CONFIG.project_name:
            CONFIG.project_name = Path(CONFIG.project_path).name

        project_name = CONFIG.project_name if CONFIG.project_name else ""
        prompt = REFINE_PRD_TEMPLATE.format(
            requirements=new_requirement_doc.content,
            old_prd=prd_doc.content,
            project_name=project_name,
            project_name=CONFIG.project_name,
        )
        node = await REFINE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=schema)
        prd_doc.content = node.instruct_content.json(ensure_ascii=False)
        prd_doc.content = node.instruct_content.model_dump_json()
        await self._rename_workspace(node)
        return prd_doc
@@ -157,7 +152,7 @@ class WritePRD(Action):
            new_prd_doc = Document(
                root_path=PRDS_FILE_REPO,
                filename=FileRepository.new_filename() + ".json",
                content=prd.instruct_content.json(ensure_ascii=False),
                content=prd.instruct_content.model_dump_json(),
            )
        elif await self._is_relative(requirement_doc, prd_doc):
            new_prd_doc = await self._merge(requirement_doc, prd_doc)
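Note: the recurring `.json()` -> `.model_dump_json()` replacements in this file follow the pydantic v1-to-v2 migration; v2's `model_dump_json()` does not ASCII-escape non-ASCII characters, which is why the old `ensure_ascii=False` argument can simply be dropped. A minimal sketch (the `Prd` model is a hypothetical stand-in for the instruct_content models):

from pydantic import BaseModel

class Prd(BaseModel):  # hypothetical stand-in
    title: str

prd = Prd(title="需求文档")
# prd.json(ensure_ascii=False)   # pydantic v1 spelling, deprecated under v2
print(prd.model_dump_json())     # v2: '{"title":"需求文档"}', non-ASCII kept as-is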
@@ -187,18 +182,13 @@ class WritePRD(Action):

    @staticmethod
    async def _rename_workspace(prd):
        if CONFIG.project_path:  # Updating on the old version has already been specified if it's valid. According to
            # Section 2.2.3.10 of RFC 135
            if not CONFIG.project_name:
                CONFIG.project_name = Path(CONFIG.project_path).name
            return

        if not CONFIG.project_name:
            if isinstance(prd, (ActionOutput, ActionNode)):
                ws_name = prd.instruct_content.dict()["Project Name"]
                ws_name = prd.instruct_content.model_dump()["Project Name"]
            else:
                ws_name = CodeParser.parse_str(block="Project Name", text=prd)
            CONFIG.project_name = ws_name
            if ws_name:
                CONFIG.project_name = ws_name
        CONFIG.git_repo.rename_root(CONFIG.project_name)

    async def _is_bugfix(self, context) -> bool:
@@ -42,7 +42,7 @@ REFINED_REQUIREMENTS = ActionNode(
PROJECT_NAME = ActionNode(
    key="Project Name",
    expected_type=str,
    instruction="Name the project using snake case style, like 'game_2048' or 'simple_crm'.",
    instruction="According to the content of \"Original Requirements,\" name the project using snake case style, like 'game_2048' or 'simple_crm'.",
    example="game_2048",
)
@@ -8,17 +8,13 @@

from typing import Optional

from pydantic import Field

from metagpt.actions.action import Action
from metagpt.llm import LLM
from metagpt.provider.base_gpt_api import BaseGPTAPI


class WritePRDReview(Action):
    name: str = ""
    context: Optional[str] = None
    llm: BaseGPTAPI = Field(default_factory=LLM)

    prd: Optional[str] = None
    desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback"
    prd_review_prompt_template: str = """
@@ -6,12 +6,8 @@
"""
from typing import List

from pydantic import Field

from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
from metagpt.llm import LLM
from metagpt.provider.base_gpt_api import BaseGPTAPI

REVIEW = ActionNode(
    key="Review",
@@ -38,7 +34,6 @@ class WriteReview(Action):
    """Write a review for the given context."""

    name: str = "WriteReview"
    llm: BaseGPTAPI = Field(default_factory=LLM)

    async def run(self, context):
        return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json")
188 metagpt/actions/write_teaching_plan.py Normal file
@@ -0,0 +1,188 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/7/27
@Author : mashenquan
@File : write_teaching_plan.py
"""
from typing import Optional

from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.logs import logger


class WriteTeachingPlanPart(Action):
    """Write Teaching Plan Part"""

    context: Optional[str] = None
    topic: str = ""
    language: str = "Chinese"
    rsp: Optional[str] = None

    async def run(self, with_message=None, **kwargs):
        statement_patterns = TeachingPlanBlock.TOPIC_STATEMENTS.get(self.topic, [])
        statements = []
        for p in statement_patterns:
            s = self.format_value(p)
            statements.append(s)
        formatter = (
            TeachingPlanBlock.PROMPT_TITLE_TEMPLATE
            if self.topic == TeachingPlanBlock.COURSE_TITLE
            else TeachingPlanBlock.PROMPT_TEMPLATE
        )
        prompt = formatter.format(
            formation=TeachingPlanBlock.FORMATION,
            role=self.prefix,
            statements="\n".join(statements),
            lesson=self.context,
            topic=self.topic,
            language=self.language,
        )

        logger.debug(prompt)
        rsp = await self._aask(prompt=prompt)
        logger.debug(rsp)
        self._set_result(rsp)
        return self.rsp

    def _set_result(self, rsp):
        if TeachingPlanBlock.DATA_BEGIN_TAG in rsp:
            ix = rsp.index(TeachingPlanBlock.DATA_BEGIN_TAG)
            rsp = rsp[ix + len(TeachingPlanBlock.DATA_BEGIN_TAG) :]
        if TeachingPlanBlock.DATA_END_TAG in rsp:
            ix = rsp.index(TeachingPlanBlock.DATA_END_TAG)
            rsp = rsp[0:ix]
        self.rsp = rsp.strip()
        if self.topic != TeachingPlanBlock.COURSE_TITLE:
            return
        if "#" not in self.rsp or self.rsp.index("#") != 0:
            self.rsp = "# " + self.rsp

    def __str__(self):
        """Return `topic` value when str()"""
        return self.topic

    def __repr__(self):
        """Show `topic` value when debug"""
        return self.topic

    @staticmethod
    def format_value(value):
        """Fill parameters inside `value` with `options`."""
        if not isinstance(value, str):
            return value
        if "{" not in value:
            return value

        merged_opts = CONFIG.options or {}
        try:
            return value.format(**merged_opts)
        except KeyError as e:
            logger.warning(f"Parameter is missing:{e}")

        for k, v in merged_opts.items():
            value = value.replace("{" + f"{k}" + "}", str(v))
        return value


class TeachingPlanBlock:
    FORMATION = (
        '"Capacity and role" defines the role you are currently playing;\n'
        '\t"[LESSON_BEGIN]" and "[LESSON_END]" tags enclose the content of textbook;\n'
        '\t"Statement" defines the work detail you need to complete at this stage;\n'
        '\t"Answer options" defines the format requirements for your responses;\n'
        '\t"Constraint" defines the conditions that your responses must comply with.'
    )

    COURSE_TITLE = "Title"
    TOPICS = [
        COURSE_TITLE,
        "Teaching Hours",
        "Teaching Objectives",
        "Teaching Content",
        "Teaching Methods and Strategies",
        "Learning Activities",
        "Teaching Time Allocation",
        "Assessment and Feedback",
        "Teaching Summary and Improvement",
        "Vocabulary Cloze",
        "Choice Questions",
        "Grammar Questions",
        "Translation Questions",
    ]

    TOPIC_STATEMENTS = {
        COURSE_TITLE: [
            "Statement: Find and return the title of the lesson only in markdown first-level header format, "
            "without anything else."
        ],
        "Teaching Content": [
            'Statement: "Teaching Content" must include vocabulary, analysis, and examples of various grammar '
            "structures that appear in the textbook, as well as the listening materials and key points.",
            'Statement: "Teaching Content" must include more examples.',
        ],
        "Teaching Time Allocation": [
            'Statement: "Teaching Time Allocation" must include how much time is allocated to each '
            "part of the textbook content."
        ],
        "Teaching Methods and Strategies": [
            'Statement: "Teaching Methods and Strategies" must include teaching focus, difficulties, materials, '
            "procedures, in detail."
        ],
        "Vocabulary Cloze": [
            'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", '
            "create vocabulary cloze. The cloze should include 10 {language} questions with {teaching_language} "
            "answers, and it should also include 10 {teaching_language} questions with {language} answers. "
            "The key-related vocabulary and phrases in the textbook content must all be included in the exercises.",
        ],
        "Grammar Questions": [
            'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", '
            "create grammar questions. 10 questions."
        ],
        "Choice Questions": [
            'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", '
            "create choice questions. 10 questions."
        ],
        "Translation Questions": [
            'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", '
            "create translation questions. The translation should include 10 {language} questions with "
            "{teaching_language} answers, and it should also include 10 {teaching_language} questions with "
            "{language} answers."
        ],
    }

    # Teaching plan title
    PROMPT_TITLE_TEMPLATE = (
        "Do not refer to the context of the previous conversation records, "
        "start the conversation anew.\n\n"
        "Formation: {formation}\n\n"
        "{statements}\n"
        "Constraint: Writing in {language}.\n"
        'Answer options: Encloses the lesson title with "[TEACHING_PLAN_BEGIN]" '
        'and "[TEACHING_PLAN_END]" tags.\n'
        "[LESSON_BEGIN]\n"
        "{lesson}\n"
        "[LESSON_END]"
    )

    # Teaching plan parts:
    PROMPT_TEMPLATE = (
        "Do not refer to the context of the previous conversation records, "
        "start the conversation anew.\n\n"
        "Formation: {formation}\n\n"
        "Capacity and role: {role}\n"
        'Statement: Write the "{topic}" part of teaching plan, '
        'WITHOUT ANY content unrelated to "{topic}"!!\n'
        "{statements}\n"
        'Answer options: Enclose the teaching plan content with "[TEACHING_PLAN_BEGIN]" '
        'and "[TEACHING_PLAN_END]" tags.\n'
        "Answer options: Using proper markdown format from second-level header format.\n"
        "Constraint: Writing in {language}.\n"
        "[LESSON_BEGIN]\n"
        "{lesson}\n"
        "[LESSON_END]"
    )

    DATA_BEGIN_TAG = "[TEACHING_PLAN_BEGIN]"
    DATA_END_TAG = "[TEACHING_PLAN_END]"
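Note: WriteTeachingPlanPart.format_value fills {placeholders} from CONFIG.options in two stages — a strict str.format first, then a per-key replace fallback so that unknown placeholders survive instead of raising. The behavior in isolation (the option values here are hypothetical):

template = "Teach {course} in {language}"
opts = {"language": "English"}  # note: no "course" key
try:
    result = template.format(**opts)  # raises KeyError('course')
except KeyError:
    result = template
    for k, v in opts.items():  # fallback: substitute only the known keys
        result = result.replace("{" + k + "}", str(v))
print(result)  # -> "Teach {course} in English"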
@@ -10,14 +10,10 @@

from typing import Optional

from pydantic import Field

from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.const import TEST_CODES_FILE_REPO
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.schema import Document, TestingContext
from metagpt.utils.common import CodeParser
@@ -44,8 +40,7 @@ you should correctly import the necessary classes based on these file locations!

class WriteTest(Action):
    name: str = "WriteTest"
    context: Optional[str] = None
    llm: BaseGPTAPI = Field(default_factory=LLM)
    context: Optional[TestingContext] = None

    async def write_code(self, prompt):
        code_rsp = await self._aask(prompt)
@@ -9,12 +9,8 @@

from typing import Dict

from pydantic import Field

from metagpt.actions import Action
from metagpt.llm import LLM
from metagpt.prompts.tutorial_assistant import CONTENT_PROMPT, DIRECTORY_PROMPT
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.utils.common import OutputParser
@@ -27,7 +23,6 @@ class WriteDirectory(Action):
    """

    name: str = "WriteDirectory"
    llm: BaseGPTAPI = Field(default_factory=LLM)
    language: str = "Chinese"

    async def run(self, topic: str, *args, **kwargs) -> Dict:
@@ -54,7 +49,6 @@ class WriteContent(Action):
    """

    name: str = "WriteContent"
    llm: BaseGPTAPI = Field(default_factory=LLM)
    directory: dict = dict()
    language: str = "Chinese"
@@ -6,12 +6,15 @@ Provide configuration, singleton
1. According to Section 2.2.3.11 of RFC 135, add git repository support.
2. Add the parameter `src_workspace` for the old version project path.
"""
import datetime
import json
import os
import warnings
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any
from uuid import uuid4

import yaml
@@ -19,6 +22,7 @@ from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS
from metagpt.logs import logger
from metagpt.tools import SearchEngineType, WebBrowserEngineType
from metagpt.utils.common import require_python_version
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.singleton import Singleton
@@ -42,6 +46,8 @@ class LLMProviderEnum(Enum):
    FIREWORKS = "fireworks"
    OPEN_LLM = "open_llm"
    GEMINI = "gemini"
    METAGPT = "metagpt"
    AZURE_OPENAI = "azure_openai"
    OLLAMA = "ollama"
@@ -58,7 +64,7 @@ class Config(metaclass=Singleton):
    key_yaml_file = METAGPT_ROOT / "config/key.yaml"
    default_yaml_file = METAGPT_ROOT / "config/config.yaml"

    def __init__(self, yaml_file=default_yaml_file):
    def __init__(self, yaml_file=default_yaml_file, cost_data=""):
        global_options = OPTIONS.get()
        # cli paras
        self.project_path = ""
@@ -66,34 +72,64 @@ class Config(metaclass=Singleton):
        self.inc = False
        self.reqa_file = ""
        self.max_auto_summarize_code = 0
        self.git_reinit = False

        self._init_with_config_files_and_env(yaml_file)
        # The agent needs to be billed per user, so billing information cannot be destroyed when the session ends.
        self.cost_manager = CostManager(**json.loads(cost_data)) if cost_data else CostManager()
        self._update()
        global_options.update(OPTIONS.get())
        logger.debug("Config loading done.")

    def get_default_llm_provider_enum(self) -> LLMProviderEnum:
        for k, v in [
            (self.openai_api_key, LLMProviderEnum.OPENAI),
            (self.anthropic_api_key, LLMProviderEnum.ANTHROPIC),
            (self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI),
            (self.fireworks_api_key, LLMProviderEnum.FIREWORKS),
            (self.open_llm_api_base, LLMProviderEnum.OPEN_LLM),
            (self.gemini_api_key, LLMProviderEnum.GEMINI),
            (self.ollama_api_base, LLMProviderEnum.OLLAMA),  # reuse logic. but not a key
        ]:
            if self._is_valid_llm_key(k):
                # logger.debug(f"Use LLMProvider: {v.value}")
                if v == LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
                    warnings.warn("Use Gemini requires Python >= 3.10")
                if self.openai_api_key and self.openai_api_model:
                    logger.info(f"OpenAI API Model: {self.openai_api_model}")
                return v
        """Get first valid LLM provider enum"""
        mappings = {
            LLMProviderEnum.OPENAI: bool(
                self._is_valid_llm_key(self.OPENAI_API_KEY) and not self.OPENAI_API_TYPE and self.OPENAI_API_MODEL
            ),
            LLMProviderEnum.ANTHROPIC: self._is_valid_llm_key(self.ANTHROPIC_API_KEY),
            LLMProviderEnum.ZHIPUAI: self._is_valid_llm_key(self.ZHIPUAI_API_KEY),
            LLMProviderEnum.FIREWORKS: self._is_valid_llm_key(self.FIREWORKS_API_KEY),
            LLMProviderEnum.OPEN_LLM: self._is_valid_llm_key(self.OPEN_LLM_API_BASE),
            LLMProviderEnum.GEMINI: self._is_valid_llm_key(self.GEMINI_API_KEY),
            LLMProviderEnum.METAGPT: bool(
                self._is_valid_llm_key(self.OPENAI_API_KEY) and self.OPENAI_API_TYPE == "metagpt"
            ),
            LLMProviderEnum.AZURE_OPENAI: bool(
                self._is_valid_llm_key(self.OPENAI_API_KEY)
                and self.OPENAI_API_TYPE == "azure"
                and self.DEPLOYMENT_NAME
                and self.OPENAI_API_VERSION
            ),
            LLMProviderEnum.OLLAMA: self._is_valid_llm_key(self.OLLAMA_API_BASE),
        }
        provider = None
        for k, v in mappings.items():
            if v:
                provider = k
                break

        if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
            warnings.warn("Use Gemini requires Python >= 3.10")
        model_name = self.get_model_name(provider=provider)
        if model_name:
            logger.info(f"{provider} Model: {model_name}")
        if provider:
            logger.info(f"API: {provider}")
            return provider
        raise NotConfiguredException("You should config a LLM configuration first")

    def get_model_name(self, provider=None) -> str:
        provider = provider or self.get_default_llm_provider_enum()
        model_mappings = {
            LLMProviderEnum.OPENAI: self.OPENAI_API_MODEL,
            LLMProviderEnum.AZURE_OPENAI: self.DEPLOYMENT_NAME,
        }
        return model_mappings.get(provider, "")

    @staticmethod
    def _is_valid_llm_key(k: str) -> bool:
        return k and k != "YOUR_API_KEY"
        return bool(k and k != "YOUR_API_KEY")

    def _update(self):
        self.global_proxy = self._get("GLOBAL_PROXY")
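Note: the rewritten get_default_llm_provider_enum replaces the old list scan with a dict of eligibility checks; the dict's insertion order is the provider priority, and the first truthy entry wins. The selection pattern reduced to its core (enum members and check values below are hypothetical):

from enum import Enum

class Provider(Enum):
    OPENAI = "openai"
    AZURE_OPENAI = "azure_openai"
    OLLAMA = "ollama"

checks = {  # dict order = priority, exactly as in Config
    Provider.OPENAI: False,
    Provider.AZURE_OPENAI: True,
    Provider.OLLAMA: True,
}
provider = next((p for p, ok in checks.items() if ok), None)
print(provider)  # Provider.AZURE_OPENAI — the first truthy entry wins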
@@ -142,8 +178,7 @@ class Config(metaclass=Singleton):
        self.long_term_memory = self._get("LONG_TERM_MEMORY", False)
        if self.long_term_memory:
            logger.warning("LONG_TERM_MEMORY is True")
        self.max_budget = self._get("MAX_BUDGET", 10.0)
        self.total_cost = 0.0
        self.cost_manager.max_budget = self._get("MAX_BUDGET", 10.0)
        self.code_review_k_times = 2

        self.puppeteer_config = self._get("PUPPETEER_CONFIG", "")
@@ -154,10 +189,18 @@ class Config(metaclass=Singleton):
        self.mermaid_engine = self._get("MERMAID_ENGINE", "nodejs")
        self.pyppeteer_executable_path = self._get("PYPPETEER_EXECUTABLE_PATH", "")

        workspace_uid = (
            self._get("WORKSPACE_UID") or f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}"
        )
        self.repair_llm_output = self._get("REPAIR_LLM_OUTPUT", False)
        self.prompt_schema = self._get("PROMPT_FORMAT", "json")
        self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT))
        val = self._get("WORKSPACE_PATH_WITH_UID")
        if val and val.lower() == "true":  # for agent
            self.workspace_path = self.workspace_path / workspace_uid
        self._ensure_workspace_exists()
        self.max_auto_summarize_code = self.max_auto_summarize_code or self._get("MAX_AUTO_SUMMARIZE_CODE", 1)
        self.timeout = int(self._get("TIMEOUT", 3))

    def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
        """update config via cli"""
@@ -198,7 +241,8 @@ class Config(metaclass=Singleton):
        return i.get(*args, **kwargs)

    def get(self, key, *args, **kwargs):
        """Search for a value in config/key.yaml, config/config.yaml, and env; raise an error if not found"""
        """Retrieve values from config/key.yaml, config/config.yaml, and environment variables.
        Throw an error if not found."""
        value = self._get(key, *args, **kwargs)
        if value is None:
            raise ValueError(f"Key '{key}' not found in environment variables or in the YAML file")
@@ -48,10 +48,12 @@ def get_metagpt_root():

# METAGPT PROJECT ROOT AND VARS

METAGPT_ROOT = get_metagpt_root()
METAGPT_ROOT = get_metagpt_root()  # Dependent on METAGPT_PROJECT_ROOT
DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"

EXAMPLE_PATH = METAGPT_ROOT / "examples"
DATA_PATH = METAGPT_ROOT / "data"
TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
RESEARCH_PATH = DATA_PATH / "research"
TUTORIAL_PATH = DATA_PATH / "tutorial_docx"
INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table"
@@ -100,7 +102,27 @@ TEST_CODES_FILE_REPO = "tests"
TEST_OUTPUTS_FILE_REPO = "test_outputs"
CODE_SUMMARIES_FILE_REPO = "docs/code_summaries"
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summaries"
RESOURCES_FILE_REPO = "resources"
SD_OUTPUT_FILE_REPO = "resources/SD_Output"
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
CLASS_VIEW_FILE_REPO = "docs/class_views"

YAPI_URL = "http://yapi.deepwisdomai.com/"

DEFAULT_LANGUAGE = "English"
DEFAULT_MAX_TOKENS = 1500
COMMAND_TOKENS = 500
BRAIN_MEMORY = "BRAIN_MEMORY"
SKILL_PATH = "SKILL_PATH"
SERPER_API_KEY = "SERPER_API_KEY"
DEFAULT_TOKEN_SIZE = 500

# format
BASE64_FORMAT = "base64"

# REDIS
REDIS_KEY = "REDIS_KEY"
LLM_API_TIMEOUT = 300

# Message id
IGNORED_MESSAGE_ID = "0"
@@ -17,11 +17,9 @@ from langchain.document_loaders import (
    UnstructuredWordDocumentLoader,
)
from langchain.text_splitter import CharacterTextSplitter
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from tqdm import tqdm

from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
@@ -103,6 +101,7 @@ class Document(BaseModel):
            raise ValueError("File path is not set.")

        self.path.parent.mkdir(parents=True, exist_ok=True)
        # TODO: excel, csv, json, etc.
        self.path.write_text(self.content, encoding="utf-8")

    def persist(self):
@@ -117,22 +116,23 @@ class IndexableDocument(Document):
    Advanced document handling: For vector databases or search engines.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    data: Union[pd.DataFrame, list]
    content_col: Optional[str] = Field(default="")
    meta_col: Optional[str] = Field(default="")

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def from_path(cls, data_path: Path, content_col="content", meta_col="metadata"):
        if not data_path.exists():
            raise FileNotFoundError(f"File {data_path} not found.")
        data = read_data(data_path)
        content = data_path.read_text()
        if isinstance(data, pd.DataFrame):
            validate_cols(content_col, data)
            return cls(data=data, content=content, content_col=content_col, meta_col=meta_col)
            return cls(data=data, content=str(data), content_col=content_col, meta_col=meta_col)
        else:
            content = data_path.read_text()
            return cls(data=data, content=content, content_col=content_col, meta_col=meta_col)

    def _get_docs_and_metadatas_by_df(self) -> (list, list):
        df = self.data
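Note: the `class Config` -> `model_config = ConfigDict(...)` change above is the standard pydantic v2 migration for model settings; both spellings enable the same option. Side by side, as a sketch:

from pydantic import BaseModel, ConfigDict

class OldStyle(BaseModel):
    class Config:  # pydantic v1 spelling (still accepted, but deprecated, under v2)
        arbitrary_types_allowed = True

class NewStyle(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)  # pydantic v2 spelling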
@@ -214,7 +214,7 @@ class Repo(BaseModel):
        self.assets[path] = doc
        return doc

    def set(self, content: str, filename: str):
    def set(self, filename: str, content: str):
        """Set a document and persist it to disk."""
        path = self._path(filename)
        doc = self._set(content, path)
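Note: swapping set(content, filename) to set(filename, content) is a breaking change for positional callers; call sites such as the demo functions removed in the next hunk used the old order. Against the new signature, a call reads:

repo.set("docs/prd.md", "# PRD\n...")  # filename first, content second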
@@ -233,24 +233,3 @@ class Repo(BaseModel):
        n_chars = sum(sum(len(j.content) for j in i.values()) for i in [self.docs, self.codes, self.assets])
        symbols = RepoParser(base_directory=self.path).generate_symbols()
        return RepoMetadata(name=self.name, n_docs=n_docs, n_chars=n_chars, symbols=symbols)


def set_existing_repo(path=CONFIG.workspace_path / "t1"):
    repo1 = Repo.from_path(path)
    repo1.set("wtf content", "doc/wtf_file.md")
    repo1.set("wtf code", "code/wtf_file.py")
    logger.info(repo1)  # check doc


def load_existing_repo(path=CONFIG.workspace_path / "web_tetris"):
    repo = Repo.from_path(path)
    logger.info(repo)
    logger.info(repo.eda())


def main():
    load_existing_repo()


if __name__ == "__main__":
    main()
@@ -1,81 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/6/8 14:03
@Author : alexanderwu
@File : document.py
@Desc : Classes and Operations Related to Vector Files in the Vector Database. Still under design.
"""
from pathlib import Path

import pandas as pd
from langchain.document_loaders import (
    TextLoader,
    UnstructuredPDFLoader,
    UnstructuredWordDocumentLoader,
)
from langchain.text_splitter import CharacterTextSplitter
from tqdm import tqdm


def validate_cols(content_col: str, df: pd.DataFrame):
    if content_col not in df.columns:
        raise ValueError


def read_data(data_path: Path):
    suffix = data_path.suffix
    if ".xlsx" == suffix:
        data = pd.read_excel(data_path)
    elif ".csv" == suffix:
        data = pd.read_csv(data_path)
    elif ".json" == suffix:
        data = pd.read_json(data_path)
    elif suffix in (".docx", ".doc"):
        data = UnstructuredWordDocumentLoader(str(data_path), mode="elements").load()
    elif ".txt" == suffix:
        data = TextLoader(str(data_path)).load()
        text_splitter = CharacterTextSplitter(separator="\n", chunk_size=256, chunk_overlap=0)
        texts = text_splitter.split_documents(data)
        data = texts
    elif ".pdf" == suffix:
        data = UnstructuredPDFLoader(str(data_path), mode="elements").load()
    else:
        raise NotImplementedError
    return data


class Document:
    def __init__(self, data_path, content_col="content", meta_col="metadata"):
        self.data = read_data(data_path)
        if isinstance(self.data, pd.DataFrame):
            validate_cols(content_col, self.data)
        self.content_col = content_col
        self.meta_col = meta_col

    def _get_docs_and_metadatas_by_df(self) -> (list, list):
        df = self.data
        docs = []
        metadatas = []
        for i in tqdm(range(len(df))):
            docs.append(df[self.content_col].iloc[i])
            if self.meta_col:
                metadatas.append({self.meta_col: df[self.meta_col].iloc[i]})
            else:
                metadatas.append({})

        return docs, metadatas

    def _get_docs_and_metadatas_by_langchain(self) -> (list, list):
        data = self.data
        docs = [i.page_content for i in data]
        metadatas = [i.metadata for i in data]
        return docs, metadatas

    def get_docs_and_metadatas(self) -> (list, list):
        if isinstance(self.data, pd.DataFrame):
            return self._get_docs_and_metadatas_by_df()
        elif isinstance(self.data, list):
            return self._get_docs_and_metadatas_by_langchain()
        else:
            raise NotImplementedError
@@ -13,7 +13,7 @@ from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain_core.embeddings import Embeddings

from metagpt.const import DATA_PATH
from metagpt.config import CONFIG
from metagpt.document import IndexableDocument
from metagpt.document_store.base_store import LocalStore
from metagpt.logs import logger
@@ -25,7 +25,9 @@ class FaissStore(LocalStore):
    ):
        self.meta_col = meta_col
        self.content_col = content_col
        self.embedding = embedding or OpenAIEmbeddings()
        self.embedding = embedding or OpenAIEmbeddings(
            openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url
        )
        super().__init__(raw_data, cache_dir)

    def _load(self) -> Optional["FaissStore"]:
@@ -73,10 +75,3 @@ class FaissStore(LocalStore):
    def delete(self, *args, **kwargs):
        """Currently, langchain does not provide a delete interface."""
        raise NotImplementedError


if __name__ == "__main__":
    faiss_store = FaissStore(DATA_PATH / "qcs/qcs_4w.json")
    logger.info(faiss_store.search("Oily Skin Facial Cleanser"))
    faiss_store.add([f"Oily Skin Facial Cleanser-{i}" for i in range(3)])
    logger.info(faiss_store.search("Oily Skin Facial Cleanser"))
@@ -1,111 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/28 00:00
@Author : alexanderwu
@File : milvus_store.py
"""
from typing import TypedDict

import numpy as np
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections

from metagpt.document_store.base_store import BaseStore

type_mapping = {int: DataType.INT64, str: DataType.VARCHAR, float: DataType.DOUBLE, np.ndarray: DataType.FLOAT_VECTOR}


def columns_to_milvus_schema(columns: dict, primary_col_name: str = "", desc: str = ""):
    """Assume the structure of columns is str: regular type"""
    fields = []
    for col, ctype in columns.items():
        if ctype == str:
            mcol = FieldSchema(name=col, dtype=type_mapping[ctype], max_length=100)
        elif ctype == np.ndarray:
            mcol = FieldSchema(name=col, dtype=type_mapping[ctype], dim=2)
        else:
            mcol = FieldSchema(name=col, dtype=type_mapping[ctype], is_primary=(col == primary_col_name))
        fields.append(mcol)
    schema = CollectionSchema(fields, description=desc)
    return schema


class MilvusConnection(TypedDict):
    alias: str
    host: str
    port: str


class MilvusStore(BaseStore):
    """
    FIXME: ADD TESTS
    https://milvus.io/docs/v2.0.x/create_collection.md
    """

    def __init__(self, connection):
        connections.connect(**connection)
        self.collection = None

    def _create_collection(self, name, schema):
        collection = Collection(name=name, schema=schema, using="default", shards_num=2, consistency_level="Strong")
        return collection

    def create_collection(self, name, columns):
        schema = columns_to_milvus_schema(columns, "idx")
        self.collection = self._create_collection(name, schema)
        return self.collection

    def drop(self, name):
        Collection(name).drop()

    def load_collection(self):
        self.collection.load()

    def build_index(self, field="emb"):
        self.collection.create_index(field, {"index_type": "FLAT", "metric_type": "L2", "params": {}})

    def search(self, query: list[list[float]], *args, **kwargs):
        """
        FIXME: ADD TESTS
        https://milvus.io/docs/v2.0.x/search.md
        All search and query operations within Milvus are executed in memory. Load the collection to memory before conducting a vector similarity search.
        Note the above description, is this logic serious? This should take a long time, right?
        """
        search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
        results = self.collection.search(
            data=query,
            anns_field=kwargs.get("field", "emb"),
            param=search_params,
            limit=10,
            expr=None,
            consistency_level="Strong",
        )
        # FIXME: results contain id, but to get the actual value from the id, we still need to call the query interface
        return results

    def write(self, name, schema, *args, **kwargs):
        """
        FIXME: ADD TESTS
        https://milvus.io/docs/v2.0.x/create_collection.md
        :param args:
        :param kwargs:
        :return:
        """
        raise NotImplementedError

    def add(self, data, *args, **kwargs):
        """
        FIXME: ADD TESTS
        https://milvus.io/docs/v2.0.x/insert_data.md
        import random
        data = [
            [i for i in range(2000)],
            [i for i in range(10000, 12000)],
            [[random.random() for _ in range(2)] for _ in range(2000)],
        ]

        :param args:
        :param kwargs:
        :return:
        """
        self.collection.insert(data)
@@ -15,10 +15,11 @@ import asyncio
from pathlib import Path
from typing import Iterable, Set

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator

from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.roles.role import Role, role_subclass_registry
from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.utils.common import is_subscribed, read_json_file, write_json_file
@@ -28,30 +29,17 @@ class Environment(BaseModel):
    Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    desc: str = Field(default="")  # environment description
    roles: dict[str, Role] = Field(default_factory=dict)
    members: dict[Role, Set] = Field(default_factory=dict)
    roles: dict[str, SerializeAsAny[Role]] = Field(default_factory=dict, validate_default=True)
    members: dict[Role, Set] = Field(default_factory=dict, exclude=True)
    history: str = ""  # For debug

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, **kwargs):
        roles = []
        for role_key, role in kwargs.get("roles", {}).items():
            current_role = kwargs["roles"][role_key]
            if isinstance(current_role, dict):
                item_class_name = current_role.get("builtin_class_name", None)
                for name, subclass in role_subclass_registry.items():
                    registery_class_name = subclass.__fields__["builtin_class_name"].default
                    if item_class_name == registery_class_name:
                        current_role = subclass(**current_role)
                        break
                kwargs["roles"][role_key] = current_role
            roles.append(current_role)
        super().__init__(**kwargs)

        self.add_roles(roles)  # add_roles again to init the Role.set_env

    @model_validator(mode="after")
    def init_roles(self):
        self.add_roles(self.roles.values())
        return self

    def serialize(self, stg_path: Path):
        roles_path = stg_path.joinpath("roles.json")
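Note: replacing the hand-written __init__ with a model_validator(mode="after") lets pydantic v2 build and validate the roles dict first, then run the wiring once the model is fully constructed. The mechanism in miniature (the fields here are hypothetical):

from pydantic import BaseModel, model_validator

class Env(BaseModel):
    roles: dict[str, str] = {}
    wired: bool = False

    @model_validator(mode="after")
    def init_roles(self):
        self.wired = True  # runs after all fields are validated
        return self

print(Env(roles={"pm": "Alice"}).wired)  # True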
@@ -108,7 +96,7 @@ class Environment(BaseModel):
        for role in roles:  # setup system message with roles
            role.set_env(self)

    def publish_message(self, message: Message) -> bool:
    def publish_message(self, message: Message, peekable: bool = True) -> bool:
        """
        Distribute the message to the recipients.
        In accordance with the Message routing structure design in Chapter 2.2.1 of RFC 116, as already planned
@@ -173,3 +161,8 @@ class Environment(BaseModel):
    def set_subscription(self, obj, tags):
        """Set the labels for message to be consumed by the object"""
        self.members[obj] = tags

    @staticmethod
    def archive(auto_archive=True):
        if auto_archive and CONFIG.git_repo:
            CONFIG.git_repo.archive()
@@ -5,3 +5,9 @@
@Author : alexanderwu
@File : __init__.py
"""

from metagpt.learn.text_to_image import text_to_image
from metagpt.learn.text_to_speech import text_to_speech
from metagpt.learn.google_search import google_search

__all__ = ["text_to_image", "text_to_speech", "google_search"]
12 metagpt/learn/google_search.py Normal file
@@ -0,0 +1,12 @@
from metagpt.tools.search_engine import SearchEngine


async def google_search(query: str, max_results: int = 6, **kwargs):
    """Perform a web search and retrieve search results.

    :param query: The search query.
    :param max_results: The number of search results to retrieve
    :return: The web search results in markdown format.
    """
    results = await SearchEngine().run(query, max_results=max_results, as_string=False)
    return "\n".join(f"{i}. [{j['title']}]({j['link']}): {j['snippet']}" for i, j in enumerate(results, 1))
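Note: a usage sketch for the new helper, assuming a search engine backend (e.g. a Serper/SerpAPI key) is configured:

import asyncio

from metagpt.learn.google_search import google_search

markdown = asyncio.run(google_search("MetaGPT multi-agent framework", max_results=3))
print(markdown)  # numbered markdown list: "1. [title](link): snippet"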
100 metagpt/learn/skill_loader.py Normal file
@@ -0,0 +1,100 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/18
@Author : mashenquan
@File : skill_loader.py
@Desc : Skill YAML Configuration Loader.
"""
from pathlib import Path
from typing import Dict, List, Optional

import aiofiles
import yaml
from pydantic import BaseModel, Field

from metagpt.config import CONFIG


class Example(BaseModel):
    ask: str
    answer: str


class Returns(BaseModel):
    type: str
    format: Optional[str] = None


class Parameter(BaseModel):
    type: str
    description: str = None


class Skill(BaseModel):
    name: str
    description: str = None
    id: str = None
    x_prerequisite: Dict = Field(default=None, alias="x-prerequisite")
    parameters: Dict[str, Parameter] = None
    examples: List[Example]
    returns: Returns

    @property
    def arguments(self) -> Dict:
        if not self.parameters:
            return {}
        ret = {}
        for k, v in self.parameters.items():
            ret[k] = v.description if v.description else ""
        return ret


class Entity(BaseModel):
    name: str = None
    skills: List[Skill]


class Components(BaseModel):
    pass


class SkillsDeclaration(BaseModel):
    skillapi: str
    entities: Dict[str, Entity]
    components: Components = None

    @staticmethod
    async def load(skill_yaml_file_name: Path = None) -> "SkillsDeclaration":
        if not skill_yaml_file_name:
            skill_yaml_file_name = Path(__file__).parent.parent.parent / ".well-known/skills.yaml"
        async with aiofiles.open(str(skill_yaml_file_name), mode="r") as reader:
            data = await reader.read(-1)
        skill_data = yaml.safe_load(data)
        return SkillsDeclaration(**skill_data)

    def get_skill_list(self, entity_name: str = "Assistant") -> Dict:
        """Return the skill name based on the skill description."""
        entity = self.entities.get(entity_name)
        if not entity:
            return {}

        # List of skills that the agent chooses to activate.
        agent_skills = CONFIG.agent_skills
        if not agent_skills:
            return {}

        class _AgentSkill(BaseModel):
            name: str

        names = [_AgentSkill(**i).name for i in agent_skills]
        return {s.description: s.name for s in entity.skills if s.name in names}

    def get_skill(self, name, entity_name: str = "Assistant") -> Skill:
        """Return a skill by name."""
        entity = self.entities.get(entity_name)
        if not entity:
            return None
        for sk in entity.skills:
            if sk.name == name:
                return sk
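Note: a usage sketch for the loader, assuming a .well-known/skills.yaml exists and CONFIG.agent_skills lists the activated skills (the skill name below is hypothetical):

import asyncio

from metagpt.learn.skill_loader import SkillsDeclaration

async def main():
    decl = await SkillsDeclaration.load()
    print(decl.get_skill_list())            # {description: name} for activated skills
    print(decl.get_skill("text_to_image"))  # Skill object, or None if absent

asyncio.run(main())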
24 metagpt/learn/text_to_embedding.py Normal file
@@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/18
@Author : mashenquan
@File : text_to_embedding.py
@Desc : Text-to-Embedding skill, which provides text-to-embedding functionality.
"""

from metagpt.config import CONFIG
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding


async def text_to_embedding(text, model="text-embedding-ada-002", openai_api_key="", **kwargs):
    """Text to embedding

    :param text: The text used for embedding.
    :param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`.
    :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
    :return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
    """
    if CONFIG.OPENAI_API_KEY or openai_api_key:
        return await oas3_openai_text_to_embedding(text, model=model, openai_api_key=openai_api_key)
    raise EnvironmentError
40 metagpt/learn/text_to_image.py Normal file
@@ -0,0 +1,40 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/18
@Author : mashenquan
@File : text_to_image.py
@Desc : Text-to-Image skill, which provides text-to-image functionality.
"""
import base64

from metagpt.config import CONFIG
from metagpt.const import BASE64_FORMAT
from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image
from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image
from metagpt.utils.s3 import S3


async def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url="", **kwargs):
    """Text to image

    :param text: The text used for image conversion.
    :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
    :param size_type: If using OPENAI, the available size options are ['256x256', '512x512', '1024x1024'], while for MetaGPT, the options are ['512x512', '512x768'].
    :param model_url: MetaGPT model url
    :return: The image data is returned in Base64 encoding.
    """
    image_declaration = "data:image/png;base64,"
    if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url:
        binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
    elif CONFIG.OPENAI_API_KEY or openai_api_key:
        binary_data = await oas3_openai_text_to_image(text, size_type)
    else:
        raise ValueError("Missing necessary parameters.")
    base64_data = base64.b64encode(binary_data).decode("utf-8")

    s3 = S3()
    url = await s3.cache(data=base64_data, file_ext=".png", format=BASE64_FORMAT) if s3.is_valid else ""
    if url:
        return f"![{text}]({url})"
    return image_declaration + base64_data if base64_data else ""
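Note: a usage sketch, assuming either OPENAI_API_KEY or METAGPT_TEXT_TO_IMAGE_MODEL_URL is configured; without S3 caching the result is a data URI:

import asyncio

from metagpt.learn.text_to_image import text_to_image

result = asyncio.run(text_to_image("a minimalist robot logo", size_type="512x512"))
print(result[:40])  # "data:image/png;base64,..." or a markdown image link to the S3 cache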
70 metagpt/learn/text_to_speech.py Normal file
@@ -0,0 +1,70 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/17
@Author : mashenquan
@File : text_to_speech.py
@Desc : Text-to-Speech skill, which provides text-to-speech functionality
"""

from metagpt.config import CONFIG
from metagpt.const import BASE64_FORMAT
from metagpt.tools.azure_tts import oas3_azsure_tts
from metagpt.tools.iflytek_tts import oas3_iflytek_tts
from metagpt.utils.s3 import S3


async def text_to_speech(
    text,
    lang="zh-CN",
    voice="zh-CN-XiaomoNeural",
    style="affectionate",
    role="Girl",
    subscription_key="",
    region="",
    iflytek_app_id="",
    iflytek_api_key="",
    iflytek_api_secret="",
    **kwargs,
):
    """Text to speech
    For more details, check out: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`

    :param lang: The value can contain a language code such as en (English), or a locale such as en-US (English - United States). For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
    :param voice: For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`, `https://speech.microsoft.com/portal/voicegallery`
    :param style: Speaking style to express different emotions like cheerfulness, empathy, and calm. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
    :param role: With roles, the same voice can act as a different age and gender. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
    :param text: The text used for voice conversion.
    :param subscription_key: key is used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint`
    :param region: This is the location (or region) of your resource. You may need to use this field when making calls to this API.
    :param iflytek_app_id: Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`
    :param iflytek_api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
    :param iflytek_api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
    :return: Returns the Base64-encoded .wav/.mp3 file data if successful, otherwise an empty string.

    """

    if (CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_REGION) or (subscription_key and region):
        audio_declaration = "data:audio/wav;base64,"
        base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
        s3 = S3()
        url = await s3.cache(data=base64_data, file_ext=".wav", format=BASE64_FORMAT) if s3.is_valid else ""
        if url:
            return f"[{text}]({url})"
        return audio_declaration + base64_data if base64_data else base64_data
    if (CONFIG.IFLYTEK_APP_ID and CONFIG.IFLYTEK_API_KEY and CONFIG.IFLYTEK_API_SECRET) or (
        iflytek_app_id and iflytek_api_key and iflytek_api_secret
    ):
        audio_declaration = "data:audio/mp3;base64,"
        base64_data = await oas3_iflytek_tts(
            text=text, app_id=iflytek_app_id, api_key=iflytek_api_key, api_secret=iflytek_api_secret
        )
        s3 = S3()
        url = await s3.cache(data=base64_data, file_ext=".mp3", format=BASE64_FORMAT) if s3.is_valid else ""
        if url:
            return f"[{text}]({url})"
        return audio_declaration + base64_data if base64_data else base64_data

    raise ValueError(
        "AZURE_TTS_SUBSCRIPTION_KEY, AZURE_TTS_REGION, IFLYTEK_APP_ID, IFLYTEK_API_KEY, IFLYTEK_API_SECRET error"
    )
|
@ -9,14 +9,14 @@
|
|||
from typing import Optional
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
from metagpt.provider.llm_provider_registry import LLM_REGISTRY
|
||||
|
||||
_ = HumanProvider() # Avoid pre-commit error
|
||||
|
||||
|
||||
def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseGPTAPI:
|
||||
def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseLLM:
|
||||
"""get the default llm provider"""
|
||||
if provider is None:
|
||||
provider = CONFIG.get_default_llm_provider_enum()
|
||||
|
|
|
|||
|
|
@@ -4,11 +4,11 @@
@Time : 2023/6/5 01:44
@Author : alexanderwu
@File : skill_manager.py
@Modified By: mashenquan, 2023/8/20. Remove useless `llm`
"""
from metagpt.actions import Action
from metagpt.const import PROMPT_PATH
from metagpt.document_store.chromadb_store import ChromaStore
from metagpt.llm import LLM
from metagpt.logs import logger

Skill = Action
@@ -18,7 +18,6 @@ class SkillManager:
    """Used to manage all skills"""

    def __init__(self):
        self._llm = LLM()
        self._store = ChromaStore("skill_manager")
        self._skills: dict[str:Skill] = {}
@@ -29,7 +28,7 @@ class SkillManager:
        :return:
        """
        self._skills[skill.name] = skill
        self._store.add(skill.desc, {}, skill.name)
        self._store.add(skill.desc, {"name": skill.name, "desc": skill.desc}, skill.name)

    def del_skill(self, skill_name: str):
        """
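Note: storing {"name": ..., "desc": ...} instead of an empty dict means ChromaStore keeps retrievable metadata next to each embedded description. A sketch of the effect (the module path is assumed; Skill is the Action alias defined above):

from metagpt.management.skill_manager import Skill, SkillManager  # path assumed

manager = SkillManager()
manager.add_skill(Skill(name="text_to_speech", desc="Convert text to audio."))
# The vector-store entry now carries the skill name/description as metadata,
# so search hits can be mapped back to a skill without a second lookup.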
metagpt/memory/brain_memory.py (new file, 331 lines)

@@ -0,0 +1,331 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/8/18
+@Author  : mashenquan
+@File    : brain_memory.py
+@Desc    : Used by AgentStore. Used for long-term storage and automatic compression.
+@Modified By: mashenquan, 2023/9/4. + redis memory cache.
+@Modified By: mashenquan, 2023/12/25. Simplify functionality.
+"""
+import json
+import re
+from typing import Dict, List, Optional
+
+from pydantic import BaseModel, Field
+
+from metagpt.config import CONFIG
+from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
+from metagpt.logs import logger
+from metagpt.provider import MetaGPTLLM
+from metagpt.provider.base_llm import BaseLLM
+from metagpt.schema import Message, SimpleMessage
+from metagpt.utils.redis import Redis
+
+
+class BrainMemory(BaseModel):
+    history: List[Message] = Field(default_factory=list)
+    knowledge: List[Message] = Field(default_factory=list)
+    historical_summary: str = ""
+    last_history_id: str = ""
+    is_dirty: bool = False
+    last_talk: Optional[str] = None
+    cacheable: bool = True
+    llm: Optional[BaseLLM] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def add_talk(self, msg: Message):
+        """
+        Add message from user.
+        """
+        msg.role = "user"
+        self.add_history(msg)
+        self.is_dirty = True
+
+    def add_answer(self, msg: Message):
+        """Add message from LLM"""
+        msg.role = "assistant"
+        self.add_history(msg)
+        self.is_dirty = True
+
+    def get_knowledge(self) -> str:
+        texts = [m.content for m in self.knowledge]
+        return "\n".join(texts)
+
+    @staticmethod
+    async def loads(redis_key: str) -> "BrainMemory":
+        redis = Redis()
+        if not redis.is_valid or not redis_key:
+            return BrainMemory()
+        v = await redis.get(key=redis_key)
+        logger.debug(f"REDIS GET {redis_key} {v}")
+        if v:
+            bm = BrainMemory.parse_raw(v)
+            bm.is_dirty = False
+            return bm
+        return BrainMemory()
+
+    async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60):
+        if not self.is_dirty:
+            return
+        redis = Redis()
+        if not redis.is_valid or not redis_key:
+            return False
+        v = self.model_dump_json()
+        if self.cacheable:
+            await redis.set(key=redis_key, data=v, timeout_sec=timeout_sec)
+            logger.debug(f"REDIS SET {redis_key} {v}")
+        self.is_dirty = False
+
+    @staticmethod
+    def to_redis_key(prefix: str, user_id: str, chat_id: str):
+        return f"{prefix}:{user_id}:{chat_id}"
+
+    async def set_history_summary(self, history_summary, redis_key, redis_conf):
+        if self.historical_summary == history_summary:
+            if self.is_dirty:
+                await self.dumps(redis_key=redis_key)
+                self.is_dirty = False
+            return
+
+        self.historical_summary = history_summary
+        self.history = []
+        await self.dumps(redis_key=redis_key)
+        self.is_dirty = False
+
+    def add_history(self, msg: Message):
+        if msg.id:
+            if self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1):
+                return
+
+        self.history.append(msg)
+        self.last_history_id = str(msg.id)
+        self.is_dirty = True
+
+    def exists(self, text) -> bool:
+        for m in reversed(self.history):
+            if m.content == text:
+                return True
+        return False
+
+    @staticmethod
+    def to_int(v, default_value):
+        try:
+            return int(v)
+        except (TypeError, ValueError):
+            return default_value
+
+    def pop_last_talk(self):
+        v = self.last_talk
+        self.last_talk = None
+        return v
+
+    async def summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs):
+        if isinstance(llm, MetaGPTLLM):
+            return await self._metagpt_summarize(max_words=max_words)
+
+        self.llm = llm
+        return await self._openai_summarize(llm=llm, max_words=max_words, keep_language=keep_language, limit=limit)
+
+    async def _openai_summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1):
+        texts = [self.historical_summary]
+        for m in self.history:
+            texts.append(m.content)
+        text = "\n".join(texts)
+
+        text_length = len(text)
+        if limit > 0 and text_length < limit:
+            return text
+        summary = await self._summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit)
+        if summary:
+            await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS)
+            return summary
+        raise ValueError(f"text too long: {text_length}")
+
+    async def _metagpt_summarize(self, max_words=200):
+        if not self.history:
+            return ""
+
+        total_length = 0
+        msgs = []
+        for m in reversed(self.history):
+            delta = len(m.content)
+            if total_length + delta > max_words:
+                left = max_words - total_length
+                if left == 0:
+                    break
+                m.content = m.content[0:left]
+                msgs.append(m)
+                break
+            msgs.append(m)
+            total_length += delta
+        msgs.reverse()
+        self.history = msgs
+        self.is_dirty = True
+        await self.dumps(redis_key=CONFIG.REDIS_KEY)
+        self.is_dirty = False
+
+        return BrainMemory.to_metagpt_history_format(self.history)
+
+    @staticmethod
+    def to_metagpt_history_format(history) -> str:
+        mmsg = [SimpleMessage(role=m.role, content=m.content).model_dump() for m in history]
+        return json.dumps(mmsg, ensure_ascii=False)
+
+    async def get_title(self, llm, max_words=5, **kwargs) -> str:
+        """Generate text title"""
+        if isinstance(llm, MetaGPTLLM):
+            return self.history[0].content if self.history else "New"
+
+        summary = await self.summarize(llm=llm, max_words=500)
+
+        language = CONFIG.language or DEFAULT_LANGUAGE
+        command = f"Translate the above summary into a {language} title of less than {max_words} words."
+        summaries = [summary, command]
+        msg = "\n".join(summaries)
+        logger.debug(f"title ask: {msg}")
+        response = await llm.aask(msg=msg, system_msgs=[])
+        logger.debug(f"title rsp: {response}")
+        return response
+
+    async def is_related(self, text1, text2, llm):
+        if isinstance(llm, MetaGPTLLM):
+            return await self._metagpt_is_related(text1=text1, text2=text2, llm=llm)
+        return await self._openai_is_related(text1=text1, text2=text2, llm=llm)
+
+    @staticmethod
+    async def _metagpt_is_related(**kwargs):
+        return False
+
+    @staticmethod
+    async def _openai_is_related(text1, text2, llm, **kwargs):
+        command = (
+            f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf there is "
+            "any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear."
+        )
+        rsp = await llm.aask(msg=command, system_msgs=[])
+        result = True if "TRUE" in rsp else False
+        p2 = text2.replace("\n", "")
+        p1 = text1.replace("\n", "")
+        logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n")
+        return result
+
+    async def rewrite(self, sentence: str, context: str, llm):
+        if isinstance(llm, MetaGPTLLM):
+            return await self._metagpt_rewrite(sentence=sentence, context=context, llm=llm)
+        return await self._openai_rewrite(sentence=sentence, context=context, llm=llm)
+
+    @staticmethod
+    async def _metagpt_rewrite(sentence: str, **kwargs):
+        return sentence
+
+    @staticmethod
+    async def _openai_rewrite(sentence: str, context: str, llm):
+        command = (
+            f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly "
+            f"supplement or rewrite the following text, brief and clear:\n{sentence}"
+        )
+        rsp = await llm.aask(msg=command, system_msgs=[])
+        logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n")
+        return rsp
+
+    @staticmethod
+    def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"):
+        match = re.match(pattern, input_string)
+        if match:
+            return match.group(1), match.group(2)
+        else:
+            return None, input_string
+
+    @property
+    def is_history_available(self):
+        return bool(self.history or self.historical_summary)
+
+    @property
+    def history_text(self):
+        if len(self.history) == 0 and not self.historical_summary:
+            return ""
+        texts = [self.historical_summary] if self.historical_summary else []
+        for m in self.history[:-1]:
+            if isinstance(m, Dict):
+                t = Message(**m).content
+            elif isinstance(m, Message):
+                t = m.content
+            else:
+                continue
+            texts.append(t)
+
+        return "\n".join(texts)
+
+    async def _summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1) -> str:
+        max_token_count = DEFAULT_MAX_TOKENS
+        max_count = 100
+        text_length = len(text)
+        if limit > 0 and text_length < limit:
+            return text
+        summary = ""
+        while max_count > 0:
+            if text_length < max_token_count:
+                summary = await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
+                break
+
+            padding_size = 20 if max_token_count > 20 else 0
+            text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
+            part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
+            summaries = []
+            for ws in text_windows:
+                response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
+                summaries.append(response)
+            if len(summaries) == 1:
+                summary = summaries[0]
+                break
+
+            # Merge and retry
+            text = "\n".join(summaries)
+            text_length = len(text)
+
+            max_count -= 1  # safeguard
+        return summary
+
+    async def _get_summary(self, text: str, max_words=20, keep_language: bool = False):
+        """Generate text summary"""
+        if len(text) < max_words:
+            return text
+        if keep_language:
+            command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
+        else:
+            command = f"Translate the above content into a summary of less than {max_words} words."
+        msg = text + "\n\n" + command
+        logger.debug(f"summary ask: {msg}")
+        response = await self.llm.aask(msg=msg, system_msgs=[])
+        logger.debug(f"summary rsp: {response}")
+        return response
+
+    @staticmethod
+    def split_texts(text: str, window_size) -> List[str]:
+        """Split long text into sliding-window chunks"""
+        if window_size <= 0:
+            window_size = DEFAULT_TOKEN_SIZE
+        total_len = len(text)
+        if total_len <= window_size:
+            return [text]
+
+        padding_size = 20 if window_size > 20 else 0
+        windows = []
+        idx = 0
+        data_len = window_size - padding_size
+        while idx < total_len:
+            if window_size + idx > total_len:  # less than one full window remains
+                windows.append(text[idx:])
+                break
+            # Advancing each window by padding_size less than its width yields the sliding-window
+            # effect, e.g. for [1, 2, 3, 4, 5, 6, 7, ....]
+            # window_size=3, padding_size=1:
+            # [1, 2, 3], [3, 4, 5], [5, 6, 7], ....
+            # idx=2,     idx=5,     idx=8, ...
+            w = text[idx : idx + window_size]
+            windows.append(w)
+            idx += data_len
+
+        return windows
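The sliding-window logic in split_texts above is easy to check in isolation. Below is a standalone re-implementation (a minimal sketch with no metagpt imports; the DEFAULT_TOKEN_SIZE fallback is replaced by a literal) showing how consecutive windows overlap by padding_size characters:

def split_texts(text: str, window_size: int, padding_size: int = 20) -> list:
    # Mirrors BrainMemory.split_texts: fixed windows that overlap by padding_size.
    if window_size <= 0:
        window_size = 500  # stand-in for DEFAULT_TOKEN_SIZE
    total_len = len(text)
    if total_len <= window_size:
        return [text]
    padding = padding_size if window_size > padding_size else 0
    windows, idx, step = [], 0, window_size - padding
    while idx < total_len:
        if window_size + idx > total_len:  # less than one full window remains
            windows.append(text[idx:])
            break
        windows.append(text[idx : idx + window_size])
        idx += step
    return windows

print(split_texts("abcdefghij", window_size=4, padding_size=1))
# ['abcd', 'defg', 'ghij', 'j'] -- each window re-reads the last char of the previous one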
@@ -2,15 +2,17 @@
 # -*- coding: utf-8 -*-
 """
 @Desc   : the implementation of long-term memory
 @Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
 """

 from typing import Optional

-from pydantic import Field
+from pydantic import ConfigDict, Field

 from metagpt.logs import logger
 from metagpt.memory import Memory
 from metagpt.memory.memory_storage import MemoryStorage
 from metagpt.roles.role import RoleContext
 from metagpt.schema import Message
@@ -21,14 +23,13 @@ class LongTermMemory(Memory):
     - update memory when it changed
     """

+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     memory_storage: MemoryStorage = Field(default_factory=MemoryStorage)
-    rc: Optional["RoleContext"] = None
+    rc: Optional[RoleContext] = None
     msg_from_recover: bool = False

-    class Config:
-        arbitrary_types_allowed = True
-
-    def recover_memory(self, role_id: str, rc: "RoleContext"):
+    def recover_memory(self, role_id: str, rc: RoleContext):
         messages = self.memory_storage.recover_memory(role_id)
         self.rc = rc
         if not self.memory_storage.is_initialized:
@@ -8,10 +8,11 @@
 """
 from collections import defaultdict
 from pathlib import Path
-from typing import Iterable, Set
+from typing import DefaultDict, Iterable, Set

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SerializeAsAny

+from metagpt.const import IGNORED_MESSAGE_ID
 from metagpt.schema import Message
 from metagpt.utils.common import (
     any_to_str,
@@ -24,22 +25,14 @@ from metagpt.utils.common import (
 class Memory(BaseModel):
     """The most basic memory: super-memory"""

-    storage: list[Message] = []
-    index: dict[str, list[Message]] = Field(default_factory=defaultdict(list))
-
-    def __init__(self, **kwargs):
-        index = kwargs.get("index", {})
-        new_index = defaultdict(list)
-        for action_str, value in index.items():
-            new_index[action_str] = [Message(**item_dict) for item_dict in value]
-        kwargs["index"] = new_index
-        super(Memory, self).__init__(**kwargs)
-        self.index = new_index
+    storage: list[SerializeAsAny[Message]] = []
+    index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list))
+    ignore_id: bool = False

     def serialize(self, stg_path: Path):
         """stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
         memory_path = stg_path.joinpath("memory.json")
-        storage = self.dict()
+        storage = self.model_dump()
         write_json_file(memory_path, storage)

     @classmethod
@@ -54,6 +47,8 @@ class Memory(BaseModel):

     def add(self, message: Message):
         """Add a new message to storage, while updating the index"""
+        if self.ignore_id:
+            message.id = IGNORED_MESSAGE_ID
         if message in self.storage:
             return
         self.storage.append(message)
@@ -84,6 +79,8 @@ class Memory(BaseModel):

     def delete(self, message: Message):
         """Delete the specified message from storage, while updating the index"""
+        if self.ignore_id:
+            message.id = IGNORED_MESSAGE_ID
         self.storage.remove(message)
         if message.cause_by and message in self.index[message.cause_by]:
             self.index[message.cause_by].remove(message)
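The cause_by index above is simpler than it looks: a defaultdict keyed by the action (string) that produced each message, so retrieval by action never raises. A toy sketch of that lookup behaviour (plain strings, no metagpt imports):

from collections import defaultdict

index = defaultdict(list)  # cause_by -> messages, as in Memory.index above
for cause_by, content in [("WritePRD", "prd v1"), ("WriteDesign", "api design"), ("WritePRD", "prd v2")]:
    index[cause_by].append(content)

print(index["WritePRD"])  # ['prd v1', 'prd v2']
print(index["RunCode"])   # [] -- missing keys yield an empty list, never a KeyError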
@@ -1,11 +1,16 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-# @Desc : the implementation of memory storage
+"""
+@Desc   : the implementation of memory storage
+@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
+"""

 from pathlib import Path
-from typing import List
+from typing import Optional

 from langchain.embeddings import OpenAIEmbeddings
 from langchain.vectorstores.faiss import FAISS
+from langchain_core.embeddings import Embeddings

 from metagpt.const import DATA_PATH, MEM_TTL
 from metagpt.document_store.faiss_store import FaissStore
@@ -19,20 +24,30 @@ class MemoryStorage(FaissStore):
     The memory storage with Faiss as ANN search engine
     """

-    def __init__(self, mem_ttl: int = MEM_TTL):
+    def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
         self.role_id: str = None
         self.role_mem_path: str = None
         self.mem_ttl: int = mem_ttl  # later use
         self.threshold: float = 0.1  # experience value. TODO The threshold to filter similar memories
         self._initialized: bool = False

+        self.embedding = embedding or OpenAIEmbeddings()
         self.store: FAISS = None  # Faiss engine

     @property
     def is_initialized(self) -> bool:
         return self._initialized

-    def recover_memory(self, role_id: str) -> List[Message]:
+    def _load(self) -> Optional["FaissStore"]:
+        index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss")  # langchain FAISS uses .faiss
+
+        if not (index_file.exists() and store_file.exists()):
+            logger.info("Missing at least one of index_file/store_file, load failed and return None")
+            return None
+
+        return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)
+
+    def recover_memory(self, role_id: str) -> list[Message]:
         self.role_id = role_id
         self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
         self.role_mem_path.mkdir(parents=True, exist_ok=True)
@@ -49,16 +64,16 @@ class MemoryStorage(FaissStore):

         return messages

-    def _get_index_and_store_fname(self):
+    def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
         if not self.role_mem_path:
             logger.error(f"You should call {self.__class__.__name__}.recover_memory first when using LongTermMemory")
             return None, None
-        index_fpath = Path(self.role_mem_path / f"{self.role_id}.index")
-        storage_fpath = Path(self.role_mem_path / f"{self.role_id}.pkl")
+        index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
+        storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
         return index_fpath, storage_fpath

     def persist(self):
         super().persist()
         self.store.save_local(self.role_mem_path, self.role_id)
         logger.debug(f"Agent {self.role_id} persist memory into local")

     def add(self, message: Message) -> bool:
@@ -74,7 +89,7 @@ class MemoryStorage(FaissStore):
         self.persist()
         logger.info(f"Agent {self.role_id}'s memory_storage add a message")

-    def search_dissimilar(self, message: Message, k=4) -> List[Message]:
+    def search_dissimilar(self, message: Message, k=4) -> list[Message]:
         """search for dissimilar messages"""
         if not self.store:
             return []
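For orientation, the per-role file layout implied by recover_memory and _get_index_and_store_fname looks like the sketch below (paths are illustrative; DATA_PATH is assumed to resolve to a local ./data directory, and the role id is hypothetical):

from pathlib import Path

role_id = "ProductManager_Alice"                       # hypothetical role id
role_mem_path = Path("./data") / f"role_mem/{role_id}"
index_file = role_mem_path / f"{role_id}.faiss"        # index_ext=".faiss" for langchain FAISS
store_file = role_mem_path / f"{role_id}.pkl"          # pickled docstore
print(index_file)
print(store_file)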
@@ -1,22 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-@Time    : 2023/5/30 10:09
-@Author  : alexanderwu
-@File    : decompose.py
-"""
-
-DECOMPOSE_SYSTEM = """SYSTEM:
-You serve as an assistant that helps me play Minecraft.
-I will give you my goal in the game, please break it down as a tree-structure plan to achieve this goal.
-The requirements of the tree-structure plan are:
-1. The plan tree should be exactly of depth 2.
-2. Describe each step in one line.
-3. You should index the two levels like '1.', '1.1.', '1.2.', '2.', '2.1.', etc.
-4. The sub-goals at the bottom level should be basic actions so that I can easily execute them in the game.
-"""
-
-
-DECOMPOSE_USER = """USER:
-The goal is to {goal description}. Generate the plan according to the requirements.
-"""
@@ -1,22 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-@Time    : 2023/5/30 10:12
-@Author  : alexanderwu
-@File    : structure_action.py
-"""
-
-ACTION_SYSTEM = """SYSTEM:
-You serve as an assistant that helps me play Minecraft.
-I will give you a sentence. Please convert this sentence into one or several actions according to the following instructions.
-Each action should be a tuple of four items, written in the form ('verb', 'object', 'tools', 'materials')
-'verb' is the verb of this action.
-'object' refers to the target object of the action.
-'tools' specifies the tools required for the action.
-'material' specifies the materials required for the action.
-If some of the items are not required, set them to be 'None'.
-"""
-
-ACTION_USER = """USER:
-The sentence is {sentence}. Generate the action tuple according to the requirements.
-"""
@@ -1,46 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-@Time    : 2023/5/30 09:51
-@Author  : alexanderwu
-@File    : structure_goal.py
-"""
-
-GOAL_SYSTEM = """SYSTEM:
-You are an assistant for the game Minecraft.
-I will give you some target object and some knowledge related to the object. Please write the obtaining of the object as a goal in the standard form.
-The standard form of the goal is as follows:
-{
-"object": "the name of the target object",
-"count": "the target quantity",
-"material": "the materials required for this goal, a dictionary in the form {material_name: material_quantity}. If no material is required, set it to None",
-"tool": "the tool used for this goal. If multiple tools can be used for this goal, only write the most basic one. If no tool is required, set it to None",
-"info": "the knowledge related to this goal"
-}
-The information I will give you:
-Target object: the name and the quantity of the target object
-Knowledge: some knowledge related to the object.
-Requirements:
-1. You must generate the goal based on the provided knowledge instead of purely depending on your own knowledge.
-2. The "info" should be as compact as possible, at most 3 sentences. The knowledge I give you may be raw texts from Wiki documents. Please extract and summarize important information instead of directly copying all the texts.
-Goal Example:
-{
-"object": "iron_ore",
-"count": 1,
-"material": None,
-"tool": "stone_pickaxe",
-"info": "iron ore is obtained by mining iron ore. iron ore is most found in level 53. iron ore can only be mined with a stone pickaxe or better; using a wooden or gold pickaxe will yield nothing."
-}
-{
-"object": "wooden_pickaxe",
-"count": 1,
-"material": {"planks": 3, "stick": 2},
-"tool": "crafting_table",
-"info": "wooden pickaxe can be crafted with 3 planks and 2 stick as the material and crafting table as the tool."
-}
-"""
-
-GOAL_USER = """USER:
-Target object: {object quantity} {object name}
-Knowledge: {related knowledge}
-"""
@@ -1,88 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-@Time    : 2023/5/30 10:45
-@Author  : alexanderwu
-@File    : use_lib_sop.py
-"""
-
-SOP_SYSTEM = """SYSTEM:
-You serve as an assistant that helps me play the game Minecraft.
-I will give you a goal in the game. Please think of a plan to achieve the goal, and then write a sequence of actions to realize the plan. The requirements and instructions are as follows:
-1. You can only use the following functions. Don't make plans purely based on your experience, think about how to use these functions.
-explore(object, strategy)
-Move around to find the object with the strategy: used to find objects including block items and entities. This action is finished once the object is visible (maybe at the distance).
-Arguments:
-- object: a string, the object to explore.
-- strategy: a string, the strategy for exploration.
-approach(object)
-Move close to a visible object: used to approach the object you want to attack or mine. It may fail if the target object is not accessible.
-Arguments:
-- object: a string, the object to approach.
-craft(object, materials, tool)
-Craft the object with the materials and tool: used for crafting new object that is not in the inventory or is not enough. The required materials must be in the inventory and will be consumed, and the newly crafted objects will be added to the inventory. The tools like the crafting table and furnace should be in the inventory and this action will directly use them. Don't try to place or approach the crafting table or furnace, you will get failed since this action does not support using tools placed on the ground. You don't need to collect the items after crafting. If the quantity you require is more than a unit, this action will craft the objects one unit by one unit. If the materials run out halfway through, this action will stop, and you will only get part of the objects you want that have been crafted.
-Arguments:
-- object: a dict, whose key is the name of the object and value is the object quantity.
-- materials: a dict, whose keys are the names of the materials and values are the quantities.
-- tool: a string, the tool used for crafting. Set to null if no tool is required.
-mine(object, tool)
-Mine the object with the tool: can only mine the object within reach, cannot mine object from a distance. If there are enough objects within reach, this action will mine as many as you specify. The obtained objects will be added to the inventory.
-Arguments:
-- object: a string, the object to mine.
-- tool: a string, the tool used for mining. Set to null if no tool is required.
-attack(object, tool)
-Attack the object with the tool: used to attack the object within reach. This action will keep track of and attack the object until it is killed.
-Arguments:
-- object: a string, the object to attack.
-- tool: a string, the tool used for mining. Set to null if no tool is required.
-equip(object)
-Equip the object from the inventory: used to equip equipment, including tools, weapons, and armor. The object must be in the inventory and belong to the items for equipping.
-Arguments:
-- object: a string, the object to equip.
-digdown(object, tool)
-Dig down to the y-level with the tool: the only action you can take if you want to go underground for mining some ore.
-Arguments:
-- object: an int, the y-level (absolute y coordinate) to dig to.
-- tool: a string, the tool used for digging. Set to null if no tool is required.
-go_back_to_ground(tool)
-Go back to the ground from underground: the only action you can take for going back to the ground if you are underground.
-Arguments:
-- tool: a string, the tool used for digging. Set to null if no tool is required.
-apply(object, tool)
-Apply the tool on the object: used for fetching water, milk, lava with the tool bucket, pooling water or lava to the object with the tool water bucket or lava bucket, shearing sheep with the tool shears, blocking attacks with the tool shield.
-Arguments:
-- object: a string, the object to apply to.
-- tool: a string, the tool used to apply.
-2. You cannot define any new function. Note that the "Generated structures" world creation option is turned off.
-3. There is an inventory that stores all the objects I have. It is not an entity, but objects can be added to it or retrieved from it anytime at anywhere without specific actions. The mined or crafted objects will be added to this inventory, and the materials and tools to use are also from this inventory. Objects in the inventory can be directly used. Don't write the code to obtain them. If you plan to use some object not in the inventory, you should first plan to obtain it. You can view the inventory as one of my states, and it is written in form of a dictionary whose keys are the name of the objects I have and the values are their quantities.
-4. You will get the following information about my current state:
-- inventory: a dict representing the inventory mentioned above, whose keys are the name of the objects and the values are their quantities
-- environment: a string including my surrounding biome, the y-level of my current location, and whether I am on the ground or underground
-Pay attention to this information. Choose the easiest way to achieve the goal conditioned on my current state. Do not provide options, always make the final decision.
-5. You must describe your thoughts on the plan in natural language at the beginning. After that, you should write all the actions together. The response should follow the format:
-{
-"explanation": "explain why the last action failed, set to null for the first planning",
-"thoughts": "Your thoughts on the plan in natural language",
-"action_list": [
-{"name": "action name", "args": {"arg name": value}, "expectation": "describe the expected results of this action"},
-{"name": "action name", "args": {"arg name": value}, "expectation": "describe the expected results of this action"},
-{"name": "action name", "args": {"arg name": value}, "expectation": "describe the expected results of this action"}
-]
-}
-The action_list can contain an arbitrary number of actions. The args of each action should correspond to the type mentioned in the Arguments part. Remember to add ```dict``` at the beginning and the end of the dict. Ensure that your response can be parsed by Python json.loads
-6. I will execute your code step by step and give you feedback. If some action fails, I will stop at that action and will not execute its following actions. The feedback will include error messages about the failed action. At that time, you should replan and write the new code just starting from that failed action.
-"""
-
-
-SOP_USER = """USER:
-My current state:
-- inventory: {inventory}
-- environment: {environment}
-The goal is to {goal}.
-Here is one plan to achieve similar goal for reference: {reference plan}.
-Begin your plan. Remember to follow the response format.
-or Action {successful action} succeeded, and {feedback message}. Continue your
-plan. Do not repeat successful action. Remember to follow the response format.
-or Action {failed action} failed, because {feedback message}. Revise your plan from
-the failed action. Remember to follow the response format.
-"""
@@ -6,11 +6,22 @@
 @File    : __init__.py
 """

-from metagpt.provider.fireworks_api import FireWorksGPTAPI
-from metagpt.provider.google_gemini_api import GeminiGPTAPI
-from metagpt.provider.ollama_api import OllamaGPTAPI
-from metagpt.provider.open_llm_api import OpenLLMGPTAPI
-from metagpt.provider.openai_api import OpenAIGPTAPI
-from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
+from metagpt.provider.fireworks_api import FireworksLLM
+from metagpt.provider.google_gemini_api import GeminiLLM
+from metagpt.provider.ollama_api import OllamaLLM
+from metagpt.provider.open_llm_api import OpenLLM
+from metagpt.provider.openai_api import OpenAILLM
+from metagpt.provider.zhipuai_api import ZhiPuAILLM
+from metagpt.provider.azure_openai_api import AzureOpenAILLM
+from metagpt.provider.metagpt_api import MetaGPTLLM

-__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI", "OllamaGPTAPI"]
+__all__ = [
+    "FireworksLLM",
+    "GeminiLLM",
+    "OpenLLM",
+    "OpenAILLM",
+    "ZhiPuAILLM",
+    "AzureOpenAILLM",
+    "MetaGPTLLM",
+    "OllamaLLM",
+]
@@ -7,13 +7,13 @@
 """

 import anthropic
-from anthropic import Anthropic
+from anthropic import Anthropic, AsyncAnthropic

 from metagpt.config import CONFIG


 class Claude2:
-    def ask(self, prompt):
+    def ask(self, prompt: str) -> str:
         client = Anthropic(api_key=CONFIG.anthropic_api_key)

         res = client.completions.create(
@@ -23,10 +23,10 @@ class Claude2:
         )
         return res.completion

-    async def aask(self, prompt):
-        client = Anthropic(api_key=CONFIG.anthropic_api_key)
+    async def aask(self, prompt: str) -> str:
+        aclient = AsyncAnthropic(api_key=CONFIG.anthropic_api_key)

-        res = client.completions.create(
+        res = await aclient.completions.create(
             model="claude-2",
             prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
             max_tokens_to_sample=1000,
metagpt/provider/azure_openai_api.py (new file, 45 lines)

@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/5/5 23:08
+@Author  : alexanderwu
+@File    : openai.py
+@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation;
+            Change cost control from global to company level.
+@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
+@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
+"""
+
+from openai import AsyncAzureOpenAI
+from openai._base_client import AsyncHttpxClientWrapper
+
+from metagpt.config import LLMProviderEnum
+from metagpt.provider.llm_provider_registry import register_provider
+from metagpt.provider.openai_api import OpenAILLM
+
+
+@register_provider(LLMProviderEnum.AZURE_OPENAI)
+class AzureOpenAILLM(OpenAILLM):
+    """
+    Check https://platform.openai.com/examples for examples
+    """
+
+    def _init_client(self):
+        kwargs = self._make_client_kwargs()
+        # https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
+        self.aclient = AsyncAzureOpenAI(**kwargs)
+        self.model = self.config.DEPLOYMENT_NAME  # Used in _calc_usage & _cons_kwargs
+
+    def _make_client_kwargs(self) -> dict:
+        kwargs = dict(
+            api_key=self.config.OPENAI_API_KEY,
+            api_version=self.config.OPENAI_API_VERSION,
+            azure_endpoint=self.config.OPENAI_BASE_URL,
+        )
+
+        # to use a proxy, openai v1 needs an http_client
+        proxy_params = self._get_proxy_params()
+        if proxy_params:
+            kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
+
+        return kwargs
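The mapping from config keys to AsyncAzureOpenAI arguments is the whole trick in _make_client_kwargs. A hedged sketch with placeholder values (not a working endpoint; the api_version value is an assumption):

# Placeholder values only; in AzureOpenAILLM they come from self.config.
kwargs = dict(
    api_key="<AZURE_OPENAI_KEY>",                           # OPENAI_API_KEY
    api_version="2023-07-01-preview",                       # OPENAI_API_VERSION (assumed value)
    azure_endpoint="https://my-resource.openai.azure.com",  # OPENAI_BASE_URL
)
# AsyncAzureOpenAI(**kwargs) then routes requests to the deployment named by
# DEPLOYMENT_NAME, which _init_client assigns to self.model for cost accounting.
print(kwargs)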
@@ -1,29 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-@Time    : 2023/5/5 23:00
-@Author  : alexanderwu
-@File    : base_chatbot.py
-"""
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-
-
-@dataclass
-class BaseChatbot(ABC):
-    """Abstract GPT class"""
-
-    mode: str = "API"
-    use_system_prompt: bool = True
-
-    @abstractmethod
-    def ask(self, msg: str) -> str:
-        """Ask GPT a question and get an answer"""
-
-    @abstractmethod
-    def ask_batch(self, msgs: list) -> str:
-        """Ask GPT multiple questions and get a series of answers"""
-
-    @abstractmethod
-    def ask_code(self, msgs: list) -> str:
-        """Ask GPT multiple questions and get a piece of code"""
@@ -3,19 +3,18 @@
 """
 @Time    : 2023/5/5 23:04
 @Author  : alexanderwu
-@File    : base_gpt_api.py
+@File    : base_llm.py
 @Desc    : mashenquan, 2023/8/22. + try catch
 """
 import json
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from typing import Optional

 from metagpt.logs import logger
-from metagpt.provider.base_chatbot import BaseChatbot


-class BaseGPTAPI(BaseChatbot):
-    """GPT API abstract class, requiring all inheritors to provide a series of standard capabilities"""
+class BaseLLM(ABC):
+    """LLM API abstract class, requiring all inheritors to provide a series of standard capabilities"""

+    use_system_prompt: bool = True
     system_prompt = "You are a helpful assistant."

     def _user_msg(self, msg: str) -> dict[str, str]:
@@ -33,72 +32,44 @@ class BaseGPTAPI(BaseChatbot):
     def _default_system_msg(self):
         return self._system_msg(self.system_prompt)

-    def ask(self, msg: str) -> str:
-        message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)]
-        rsp = self.completion(message)
-        return self.get_choice_text(rsp)
-
-    async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream=True) -> str:
+    async def aask(
+        self,
+        msg: str,
+        system_msgs: Optional[list[str]] = None,
+        format_msgs: Optional[list[dict[str, str]]] = None,
+        timeout=3,
+        stream=True,
+    ) -> str:
         if system_msgs:
-            message = (
-                self._system_msgs(system_msgs) + [self._user_msg(msg)]
-                if self.use_system_prompt
-                else [self._user_msg(msg)]
-            )
+            message = self._system_msgs(system_msgs)
         else:
-            message = (
-                [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)]
-            )
-        logger.debug(message)
-        rsp = await self.acompletion_text(message, stream=stream)
-        # logger.debug(rsp)
+            message = [self._default_system_msg()] if self.use_system_prompt else []
+        if format_msgs:
+            message.extend(format_msgs)
+        message.append(self._user_msg(msg))
+        rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
         return rsp

     def _extract_assistant_rsp(self, context):
         return "\n".join([i["content"] for i in context if i["role"] == "assistant"])

-    def ask_batch(self, msgs: list) -> str:
-        context = []
-        for msg in msgs:
-            umsg = self._user_msg(msg)
-            context.append(umsg)
-            rsp = self.completion(context)
-            rsp_text = self.get_choice_text(rsp)
-            context.append(self._assistant_msg(rsp_text))
-        return self._extract_assistant_rsp(context)
-
-    async def aask_batch(self, msgs: list) -> str:
+    async def aask_batch(self, msgs: list, timeout=3) -> str:
         """Sequential questioning"""
         context = []
         for msg in msgs:
             umsg = self._user_msg(msg)
             context.append(umsg)
-            rsp_text = await self.acompletion_text(context)
+            rsp_text = await self.acompletion_text(context, timeout=timeout)
             context.append(self._assistant_msg(rsp_text))
         return self._extract_assistant_rsp(context)

-    def ask_code(self, msgs: list[str]) -> str:
-        """FIXME: No code segment filtering has been done here, and all results are actually displayed"""
-        rsp_text = self.ask_batch(msgs)
-        return rsp_text
-
-    async def aask_code(self, msgs: list[str]) -> str:
+    async def aask_code(self, msgs: list[str], timeout=3) -> str:
         """FIXME: No code segment filtering has been done here, and all results are actually displayed"""
-        rsp_text = await self.aask_batch(msgs)
+        rsp_text = await self.aask_batch(msgs, timeout=timeout)
         return rsp_text

-    @abstractmethod
-    def completion(self, messages: list[dict]):
-        """All GPTAPIs are required to provide the standard OpenAI completion interface
-        [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": "hello, show me python hello world code"},
-            # {"role": "assistant", "content": ...},  # If there is an answer in the history, also include it
-        ]
-        """
-
     @abstractmethod
-    async def acompletion(self, messages: list[dict]):
+    async def acompletion(self, messages: list[dict], timeout=3):
         """Asynchronous version of completion
         All GPTAPIs are required to provide the standard OpenAI completion interface
         [
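The new aask signature changes how the message list is assembled: explicit system_msgs win outright, and format_msgs are spliced in before the user turn. A standalone sketch of that assembly logic (assuming use_system_prompt=True and the default system prompt; not the metagpt implementation itself):

def build_messages(msg, system_msgs=None, format_msgs=None, use_system_prompt=True):
    if system_msgs:
        messages = [{"role": "system", "content": m} for m in system_msgs]
    else:
        messages = [{"role": "system", "content": "You are a helpful assistant."}] if use_system_prompt else []
    if format_msgs:
        messages.extend(format_msgs)  # e.g. few-shot turns or prior assistant output
    messages.append({"role": "user", "content": msg})
    return messages

print(build_messages("hello", format_msgs=[{"role": "assistant", "content": "previous answer"}]))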
@@ -109,7 +80,7 @@ class BaseGPTAPI(BaseChatbot):
         """

     @abstractmethod
-    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         """Asynchronous version of completion. Return str. Support stream-print"""

     def get_choice_text(self, rsp: dict) -> str:
@@ -145,7 +116,7 @@ class BaseGPTAPI(BaseChatbot):
         :return dict: return the first function of choice, for example,
         {'name': 'execute', 'arguments': '{\n  "language": "python",\n  "code": "print(\'Hello, World!\')"\n}'}
         """
-        return rsp.get("choices")[0]["message"]["tool_calls"][0]["function"].to_dict()
+        return rsp.get("choices")[0]["message"]["tool_calls"][0]["function"]

     def get_choice_function_arguments(self, rsp: dict) -> dict:
         """Required to provide the first function arguments of choice.
@@ -155,11 +126,3 @@ class BaseGPTAPI(BaseChatbot):
         {'language': 'python', 'code': "print('Hello, World!')"}
         """
         return json.loads(self.get_choice_function(rsp)["arguments"])
-
-    def messages_to_prompt(self, messages: list[dict]):
-        """[{"role": "user", "content": msg}] to user: <msg> etc."""
-        return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
-
-    def messages_to_dict(self, messages):
-        """objects to [{"role": "user", "content": msg}] etc."""
-        return [i.to_dict() for i in messages]
@@ -2,24 +2,139 @@
 # -*- coding: utf-8 -*-
 # @Desc : fireworks.ai's api

-import openai
+import re

-from metagpt.config import CONFIG, LLMProviderEnum
+from openai import APIConnectionError, AsyncStream
+from openai.types import CompletionUsage
+from openai.types.chat import ChatCompletionChunk
+from tenacity import (
+    after_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
+
+from metagpt.config import CONFIG, Config, LLMProviderEnum
+from metagpt.logs import logger
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter
+from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
+from metagpt.utils.cost_manager import CostManager, Costs
+
+MODEL_GRADE_TOKEN_COSTS = {
+    "-1": {"prompt": 0.0, "completion": 0.0},  # abnormal condition
+    "16": {"prompt": 0.2, "completion": 0.8},  # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
+    "80": {"prompt": 0.7, "completion": 2.8},  # 80 means 16B < model size <= 80B
+    "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
+}
+
+
+class FireworksCostManager(CostManager):
+    def model_grade_token_costs(self, model: str) -> dict[str, float]:
+        def _get_model_size(model: str) -> float:
+            size = re.findall(".*-([0-9.]+)b", model)
+            size = float(size[0]) if len(size) > 0 else -1
+            return size
+
+        if "mixtral-8x7b" in model:
+            token_costs = MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"]
+        else:
+            model_size = _get_model_size(model)
+            if 0 < model_size <= 16:
+                token_costs = MODEL_GRADE_TOKEN_COSTS["16"]
+            elif 16 < model_size <= 80:
+                token_costs = MODEL_GRADE_TOKEN_COSTS["80"]
+            else:
+                token_costs = MODEL_GRADE_TOKEN_COSTS["-1"]
+        return token_costs
+
+    def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
+        """
+        Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
+        Update the total cost, prompt tokens, and completion tokens.
+
+        Args:
+            prompt_tokens (int): The number of tokens used in the prompt.
+            completion_tokens (int): The number of tokens used in the completion.
+            model (str): The model used for the API call.
+        """
+        self.total_prompt_tokens += prompt_tokens
+        self.total_completion_tokens += completion_tokens
+
+        token_costs = self.model_grade_token_costs(model)
+        cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
+        self.total_cost += cost
+        max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
+        logger.info(
+            f"Total running cost: ${self.total_cost:.4f} | Max budget: ${max_budget:.3f} | "
+            f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
+        )
+        CONFIG.total_cost = self.total_cost


 @register_provider(LLMProviderEnum.FIREWORKS)
-class FireWorksGPTAPI(OpenAIGPTAPI):
+class FireworksLLM(OpenAILLM):
     def __init__(self):
-        self.__init_fireworks(CONFIG)
-        self.llm = openai
-        self.model = CONFIG.fireworks_api_model
+        self.config: Config = CONFIG
+        self.__init_fireworks()
         self.auto_max_tokens = False
-        self._cost_manager = CostManager()
-        RateLimiter.__init__(self, rpm=self.rpm)
+        self._cost_manager = FireworksCostManager()

-    def __init_fireworks(self, config: "Config"):
-        openai.api_key = config.fireworks_api_key
-        openai.api_base = config.fireworks_api_base
-        self.rpm = int(config.get("RPM", 10))
+    def __init_fireworks(self):
+        self.is_azure = False
+        self.rpm = int(self.config.get("RPM", 10))
+        self._init_client()
+        self.model = self.config.fireworks_api_model  # `self.model` should be set after `_init_client` to rewrite it
+
+    def _make_client_kwargs(self) -> dict:
+        kwargs = dict(api_key=self.config.fireworks_api_key, base_url=self.config.fireworks_api_base)
+        return kwargs
+
+    def _update_costs(self, usage: CompletionUsage):
+        if self.config.calc_usage and usage:
+            try:
+                # use FireworksCostManager, not CONFIG.cost_manager
+                self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
+            except Exception as e:
+                logger.error(f"updating costs failed! exp: {e}")
+
+    def get_costs(self) -> Costs:
+        return self._cost_manager.get_costs()
+
+    async def _achat_completion_stream(self, messages: list[dict]) -> str:
+        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
+            **self._cons_kwargs(messages), stream=True
+        )
+
+        collected_content = []
+        usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
+        # iterate through the stream of events
+        async for chunk in response:
+            if chunk.choices:
+                choice = chunk.choices[0]
+                choice_delta = choice.delta
+                finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None
+                if choice_delta.content:
+                    collected_content.append(choice_delta.content)
+                    print(choice_delta.content, end="")
+                if finish_reason:
+                    # the fireworks api returns usage when finish_reason is not None
+                    usage = CompletionUsage(**chunk.usage)
+
+        full_content = "".join(collected_content)
+        self._update_costs(usage)
+        return full_content
+
+    @retry(
+        wait=wait_random_exponential(min=1, max=60),
+        stop=stop_after_attempt(6),
+        after=after_log(logger, logger.level("WARNING").name),
+        retry=retry_if_exception_type(APIConnectionError),
+        retry_error_callback=log_and_reraise,
+    )
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
+        """when streaming, print each token in place."""
+        if stream:
+            return await self._achat_completion_stream(messages)
+        rsp = await self._achat_completion(messages)
+        return self.get_choice_text(rsp)
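The pricing table buckets models by the parameter count parsed out of the model name. A standalone sketch of the lookup plus the per-call arithmetic (prices per 1M tokens, taken from the MODEL_GRADE_TOKEN_COSTS table above; the model name is illustrative):

import re

COSTS = {"-1": (0.0, 0.0), "16": (0.2, 0.8), "80": (0.7, 2.8), "mixtral-8x7b": (0.4, 1.6)}

def grade(model: str) -> str:
    if "mixtral-8x7b" in model:
        return "mixtral-8x7b"
    found = re.findall(r".*-([0-9.]+)b", model)
    size = float(found[0]) if found else -1
    if 0 < size <= 16:
        return "16"
    if 16 < size <= 80:
        return "80"
    return "-1"

prompt_price, completion_price = COSTS[grade("llama-v2-13b-chat")]  # 13B -> "16" bucket
cost = (3000 * prompt_price + 1000 * completion_price) / 1_000_000  # 3k prompt, 1k completion tokens
print(f"${cost:.6f}")  # $0.001400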
@@ -15,7 +15,6 @@ from enum import Enum
 from typing import (
     AsyncGenerator,
-    AsyncIterator,
     Callable,
     Dict,
     Iterator,
     Optional,
@@ -47,8 +46,7 @@
 # Has one attribute per thread, 'session'.
 _thread_context = threading.local()

-LLM_LOG = os.environ.get("LLM_LOG")
-LLM_LOG = "debug"
+LLM_LOG = os.environ.get("LLM_LOG", "debug")


 class ApiType(Enum):
@@ -101,7 +99,7 @@ def log_info(message, **params):
 def log_warn(message, **params):
     msg = logfmt(dict(message=message, **params))
-    print(msg, file=sys.stderr)
-    logger.warn(msg)
+    logger.warning(msg)


 def logfmt(props):
@@ -241,54 +239,6 @@ class APIRequestor:
         self.api_version = api_version or openai.api_version
         self.organization = organization or openai.organization

-    def _check_polling_response(self, response: OpenAIResponse, predicate: Callable[[OpenAIResponse], bool]):
-        if not predicate(response):
-            return
-        error_data = response.data["error"]
-        message = error_data.get("message", "Operation failed")
-        code = error_data.get("code")
-        raise openai.APIError(message=message, body=dict(code=code))
-
-    def _poll(
-        self, method, url, until, failed, params=None, headers=None, interval=None, delay=None
-    ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
-        if delay:
-            time.sleep(delay)
-
-        response, b, api_key = self.request(method, url, params, headers)
-        self._check_polling_response(response, failed)
-        start_time = time.time()
-        while not until(response):
-            if time.time() - start_time > TIMEOUT_SECS:
-                raise openai.APITimeoutError("Operation polling timed out.")
-
-            time.sleep(interval or response.retry_after or 10)
-            response, b, api_key = self.request(method, url, params, headers)
-            self._check_polling_response(response, failed)
-
-        response.data = response.data["result"]
-        return response, b, api_key
-
-    async def _apoll(
-        self, method, url, until, failed, params=None, headers=None, interval=None, delay=None
-    ) -> Tuple[Iterator[OpenAIResponse], bool, str]:
-        if delay:
-            await asyncio.sleep(delay)
-
-        response, b, api_key = await self.arequest(method, url, params, headers)
-        self._check_polling_response(response, failed)
-        start_time = time.time()
-        while not until(response):
-            if time.time() - start_time > TIMEOUT_SECS:
-                raise openai.APITimeoutError("Operation polling timed out.")
-
-            await asyncio.sleep(interval or response.retry_after or 10)
-            response, b, api_key = await self.arequest(method, url, params, headers)
-            self._check_polling_response(response, failed)
-
-        response.data = response.data["result"]
-        return response, b, api_key
-
     @overload
     def request(
         self,
@@ -470,55 +420,6 @@ class APIRequestor:
         await ctx.__aexit__(None, None, None)
         return resp, got_stream, self.api_key

-    def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
-        try:
-            error_data = resp["error"]
-        except (KeyError, TypeError):
-            raise openai.APIError(
-                "Invalid response object from API: %r (HTTP response code was %d)" % (rbody, rcode)
-            )
-
-        if "internal_message" in error_data:
-            error_data["message"] += "\n\n" + error_data["internal_message"]
-
-        log_info(
-            "LLM API error received",
-            error_code=error_data.get("code"),
-            error_type=error_data.get("type"),
-            error_message=error_data.get("message"),
-            error_param=error_data.get("param"),
-            stream_error=stream_error,
-        )
-
-        # Rate limits were previously coded as 400's with code 'rate_limit'
-        if rcode == 429:
-            return openai.RateLimitError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody)
-        elif rcode in [400, 404, 415]:
-            return openai.BadRequestError(
-                message=f'{error_data.get("message")}, {error_data.get("param")}, {error_data.get("code")} {rbody} {rcode} {resp} {rheaders}',
-                body=rbody,
-            )
-        elif rcode == 401:
-            return openai.AuthenticationError(
-                f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody
-            )
-        elif rcode == 403:
-            return openai.PermissionDeniedError(
-                f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody
-            )
-        elif rcode == 409:
-            return openai.ConflictError(f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", body=rbody)
-        elif stream_error:
-            # TODO: we will soon attach status codes to stream errors
-            parts = [error_data.get("message"), "(Error occurred while streaming.)"]
-            message = " ".join([p for p in parts if p is not None])
-            return openai.APIError(f"{message} {rbody} {rcode} {resp} {rheaders}", body=rbody)
-        else:
-            return openai.APIError(
-                f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}",
-                body=rbody,
-            )
-
     def request_headers(self, method: str, extra, request_id: Optional[str]) -> Dict[str, str]:
         user_agent = "LLM/v1 PythonBindings/%s" % (version.VERSION,)
@@ -21,9 +21,9 @@ from tenacity import (

 from metagpt.config import CONFIG, LLMProviderEnum
 from metagpt.logs import log_llm_stream, logger
-from metagpt.provider.base_gpt_api import BaseGPTAPI
+from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import CostManager, log_and_reraise
+from metagpt.provider.openai_api import log_and_reraise


 class GeminiGenerativeModel(GenerativeModel):
@@ -42,7 +42,7 @@ class GeminiGenerativeModel(GenerativeModel):


 @register_provider(LLMProviderEnum.GEMINI)
-class GeminiGPTAPI(BaseGPTAPI):
+class GeminiLLM(BaseLLM):
     """
     Refs to `https://ai.google.dev/tutorials/python_quickstart`
     """
@@ -53,13 +53,12 @@ class GeminiGPTAPI(BaseGPTAPI):
         self.__init_gemini(CONFIG)
         self.model = "gemini-pro"  # so far only one model
         self.llm = GeminiGenerativeModel(model_name=self.model)
-        self._cost_manager = CostManager()

     def __init_gemini(self, config: CONFIG):
         genai.configure(api_key=config.gemini_api_key)

     def _user_msg(self, msg: str) -> dict[str, str]:
-        # Not to change BaseGPTAPI default functions but update with Gemini's conversation format.
+        # Not to change BaseLLM default functions but update with Gemini's conversation format.
         # You should follow the format.
         return {"role": "user", "parts": [msg]}

@@ -76,7 +75,7 @@ class GeminiGPTAPI(BaseGPTAPI):
         try:
             prompt_tokens = int(usage.get("prompt_tokens", 0))
             completion_tokens = int(usage.get("completion_tokens", 0))
-            self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
+            CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
         except Exception as e:
             logger.error(f"google gemini updates costs failed! exp: {e}")

@@ -134,7 +133,7 @@ class GeminiGPTAPI(BaseGPTAPI):
         retry=retry_if_exception_type(ConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
         """response in async with stream or non-stream mode"""
         if stream:
             return await self._achat_completion_stream(messages)
@@ -6,32 +6,35 @@ Author: garylin2099
 from typing import Optional

 from metagpt.logs import logger
-from metagpt.provider.base_gpt_api import BaseGPTAPI
+from metagpt.provider.base_llm import BaseLLM


-class HumanProvider(BaseGPTAPI):
+class HumanProvider(BaseLLM):
     """Humans provide themselves as a 'model', which actually takes in human input as its response.
     This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction
     """

-    def ask(self, msg: str) -> str:
+    def ask(self, msg: str, timeout=3) -> str:
         logger.info("It's your turn, please type in your response. You may also refer to the context below")
         rsp = input(msg)
         if rsp in ["exit", "quit"]:
             exit()
         return rsp

-    async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
-        return self.ask(msg)
+    async def aask(
+        self,
+        msg: str,
+        system_msgs: Optional[list[str]] = None,
+        format_msgs: Optional[list[dict[str, str]]] = None,
+        generator: bool = False,
+        timeout=3,
+    ) -> str:
+        return self.ask(msg, timeout=timeout)

-    def completion(self, messages: list[dict]):
-        """dummy implementation of abstract method in base"""
-        return []
-
-    async def acompletion(self, messages: list[dict]):
+    async def acompletion(self, messages: list[dict], timeout=3):
         """dummy implementation of abstract method in base"""
         return []

-    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         """dummy implementation of abstract method in base"""
-        return []
+        return ""
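A minimal usage sketch of HumanProvider (the import path is an assumption based on this diff; typing 'exit' or 'quit' terminates the process, as implemented in ask above):

import asyncio
from metagpt.provider.human_provider import HumanProvider  # path assumed

async def main():
    human = HumanProvider()
    # aask just forwards to ask(), which reads the "completion" from stdin.
    answer = await human.aask("Approve this design? ")
    print(f"human said: {answer}")

asyncio.run(main())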
metagpt/provider/metagpt_api.py (new file, 16 lines)

@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/5/5 23:08
+@Author  : alexanderwu
+@File    : metagpt_api.py
+@Desc    : MetaGPT LLM provider.
+"""
+from metagpt.config import LLMProviderEnum
+from metagpt.provider import OpenAILLM
+from metagpt.provider.llm_provider_registry import register_provider
+
+
+@register_provider(LLMProviderEnum.METAGPT)
+class MetaGPTLLM(OpenAILLM):
+    def __init__(self):
+        super().__init__()
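MetaGPTLLM above is registered the same way as every other provider in this commit. A toy, dependency-free sketch of the register_provider decorator pattern (not the actual metagpt implementation):

REGISTRY = {}

def register_provider(key):
    def decorator(cls):
        REGISTRY[key] = cls  # map provider key -> provider class
        return cls
    return decorator

@register_provider("metagpt")
class DummyLLM:
    pass

print(REGISTRY["metagpt"].__name__)  # DummyLLM -- looked up by provider key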
@@ -16,10 +16,11 @@ from tenacity import (
 from metagpt.config import CONFIG, LLMProviderEnum
 from metagpt.const import LLM_API_TIMEOUT
 from metagpt.logs import log_llm_stream, logger
-from metagpt.provider.base_gpt_api import BaseGPTAPI
+from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.general_api_requestor import GeneralAPIRequestor
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import CostManager, log_and_reraise
+from metagpt.provider.openai_api import log_and_reraise
+from metagpt.utils.cost_manager import CostManager


 class OllamaCostManager(CostManager):
@@ -29,16 +30,16 @@ class OllamaCostManager(CostManager):
         """
         self.total_prompt_tokens += prompt_tokens
         self.total_completion_tokens += completion_tokens
+        max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
         logger.info(
-            f"Max budget: ${CONFIG.max_budget:.3f} | "
+            f"Max budget: ${max_budget:.3f} | "
             f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
         )
         CONFIG.total_cost = self.total_cost


 @register_provider(LLMProviderEnum.OLLAMA)
-class OllamaGPTAPI(BaseGPTAPI):
+class OllamaLLM(BaseLLM):
     """
     Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion`
     """
@@ -53,7 +54,6 @@ class OllamaGPTAPI(BaseGPTAPI):
     def __init_ollama(self, config: CONFIG):
         assert config.ollama_api_base
-
         self.model = config.ollama_api_model

     def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
@@ -83,18 +83,6 @@ class OllamaGPTAPI(BaseGPTAPI):
         chunk = chunk.decode(encoding)
         return json.loads(chunk)

-    def completion(self, messages: list[dict]) -> dict:
-        resp, _, _ = self.client.request(
-            method=self.http_method,
-            url=self.suffix_url,
-            params=self._const_kwargs(messages),
-            request_timeout=LLM_API_TIMEOUT,
-        )
-        resp = self._decode_and_load(resp)
-        usage = self.get_usage(resp)
-        self._update_costs(usage)
-        return resp
-
     async def _achat_completion(self, messages: list[dict]) -> dict:
         resp, _, _ = await self.client.arequest(
             method=self.http_method,
@@ -107,7 +95,7 @@ class OllamaGPTAPI(BaseGPTAPI):
         self._update_costs(usage)
         return resp

-    async def acompletion(self, messages: list[dict]) -> dict:
+    async def acompletion(self, messages: list[dict], timeout=3) -> dict:
         return await self._achat_completion(messages)

     async def _achat_completion_stream(self, messages: list[dict]) -> str:
@@ -143,7 +131,7 @@ class OllamaGPTAPI(BaseGPTAPI):
         retry=retry_if_exception_type(ConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
         """response in async with stream or non-stream mode"""
         if stream:
             return await self._achat_completion_stream(messages)
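For reference, the chat endpoint that `_const_kwargs` targets expects a JSON body along these lines (a sketch based on the Ollama API document linked in the class docstring; field names come from that doc, not from this diff, and the model name is a placeholder):

import json

# shape of the request body sent to Ollama's /api/chat endpoint
payload = {
    "model": "llama2",  # whatever ollama_api_model is configured to
    "messages": [{"role": "user", "content": "hello"}],
    "stream": False,  # _const_kwargs toggles this for streaming mode
}
print(json.dumps(payload))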
@@ -2,12 +2,14 @@
 # -*- coding: utf-8 -*-
 # @Desc   : self-host open llm model with openai-compatible interface

-import openai
+from openai.types import CompletionUsage

-from metagpt.config import CONFIG, LLMProviderEnum
+from metagpt.config import CONFIG, Config, LLMProviderEnum
 from metagpt.logs import logger
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter
+from metagpt.provider.openai_api import OpenAILLM
+from metagpt.utils.cost_manager import CostManager, Costs
 from metagpt.utils.token_counter import count_message_tokens, count_string_tokens


 class OpenLLMCostManager(CostManager):
@@ -24,25 +26,51 @@ class OpenLLMCostManager(CostManager):
         """
         self.total_prompt_tokens += prompt_tokens
         self.total_completion_tokens += completion_tokens
+        max_budget = CONFIG.max_budget if CONFIG.max_budget else CONFIG.cost_manager.max_budget
         logger.info(
-            f"Max budget: ${CONFIG.max_budget:.3f} | "
+            f"Max budget: ${max_budget:.3f} | reference "
             f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
         )
         CONFIG.total_cost = self.total_cost


 @register_provider(LLMProviderEnum.OPEN_LLM)
-class OpenLLMGPTAPI(OpenAIGPTAPI):
+class OpenLLM(OpenAILLM):
     def __init__(self):
-        self.__init_openllm(CONFIG)
-        self.llm = openai
-        self.model = CONFIG.open_llm_api_model
+        self.config: Config = CONFIG
+        self.__init_openllm()
         self.auto_max_tokens = False
         self._cost_manager = OpenLLMCostManager()
-        RateLimiter.__init__(self, rpm=self.rpm)

-    def __init_openllm(self, config: "Config"):
-        openai.api_key = "sk-xx"  # self-host api doesn't need an api key, use the default value
-        openai.api_base = config.open_llm_api_base
-        self.rpm = int(config.get("RPM", 10))
+    def __init_openllm(self):
+        self.is_azure = False
+        self.rpm = int(self.config.get("RPM", 10))
+        self._init_client()
+        self.model = self.config.open_llm_api_model  # `self.model` should come after `_make_client` to rewrite it
+
+    def _make_client_kwargs(self) -> dict:
+        kwargs = dict(api_key="sk-xxx", base_url=self.config.open_llm_api_base)
+        return kwargs
+
+    def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
+        usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
+        if not CONFIG.calc_usage:
+            return usage
+
+        try:
+            usage.prompt_tokens = count_message_tokens(messages, "open-llm-model")
+            usage.completion_tokens = count_string_tokens(rsp, "open-llm-model")
+        except Exception as e:
+            logger.error(f"usage calculation failed: {e}")
+
+        return usage
+
+    def _update_costs(self, usage: CompletionUsage):
+        if self.config.calc_usage and usage:
+            try:
+                # use OpenLLMCostManager, not CONFIG.cost_manager
+                self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
+            except Exception as e:
+                logger.error(f"updating costs failed!, exp: {e}")
+
+    def get_costs(self) -> Costs:
+        return self._cost_manager.get_costs()
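The `_make_client_kwargs` override above is the whole trick for self-hosted models: point the openai v1 client at the local base_url and pass a dummy key. A standalone sketch of the same idea (URL and model name are placeholders for whatever the local server serves):

import asyncio

from openai import AsyncOpenAI

async def main():
    # the self-hosted server ignores the key, but the sdk requires one
    client = AsyncOpenAI(api_key="sk-xxx", base_url="http://localhost:8000/v1")
    rsp = await client.chat.completions.create(
        model="open-llm-model",  # placeholder; use the served model's name
        messages=[{"role": "user", "content": "ping"}],
    )
    print(rsp.choices[0].message.content)

asyncio.run(main())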
@@ -3,21 +3,17 @@
 @Time    : 2023/5/5 23:08
 @Author  : alexanderwu
 @File    : openai.py
 @Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for isolation;
             Change cost control from global to company level.
 @Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
+@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
 """
-import asyncio
 import json
-import time
-from typing import NamedTuple, Union
+from typing import AsyncIterator, Union

-from openai import (
-    APIConnectionError,
-    AsyncAzureOpenAI,
-    AsyncOpenAI,
-    AsyncStream,
-    AzureOpenAI,
-    OpenAI,
-)
-from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
+from openai import APIConnectionError, AsyncOpenAI, AsyncStream
+from openai._base_client import AsyncHttpxClientWrapper
 from openai.types import CompletionUsage
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from tenacity import (
@@ -30,114 +26,19 @@ from tenacity import (
 from metagpt.config import CONFIG, Config, LLMProviderEnum
 from metagpt.logs import log_llm_stream, logger
-from metagpt.provider.base_gpt_api import BaseGPTAPI
+from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.schema import Message
 from metagpt.utils.cost_manager import Costs
 from metagpt.utils.exceptions import handle_exception
-from metagpt.utils.singleton import Singleton
 from metagpt.utils.token_counter import (
-    TOKEN_COSTS,
     count_message_tokens,
     count_string_tokens,
     get_max_completion_tokens,
 )


-class RateLimiter:
-    """Rate control class, each call goes through wait_if_needed, sleep if rate control is needed"""
-
-    def __init__(self, rpm):
-        self.last_call_time = 0
-        # Here 1.1 is used because even if the calls are made strictly according to time,
-        # they will still be QOS'd; consider switching to simple error retry later
-        self.interval = 1.1 * 60 / rpm
-        self.rpm = rpm
-
-    def split_batches(self, batch):
-        return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)]
-
-    async def wait_if_needed(self, num_requests):
-        current_time = time.time()
-        elapsed_time = current_time - self.last_call_time
-
-        if elapsed_time < self.interval * num_requests:
-            remaining_time = self.interval * num_requests - elapsed_time
-            logger.info(f"sleep {remaining_time}")
-            await asyncio.sleep(remaining_time)
-
-        self.last_call_time = time.time()
-
-
-class Costs(NamedTuple):
-    total_prompt_tokens: int
-    total_completion_tokens: int
-    total_cost: float
-    total_budget: float
-
-
-class CostManager(metaclass=Singleton):
-    """Track the cost of API calls"""
-
-    def __init__(self):
-        self.total_prompt_tokens = 0
-        self.total_completion_tokens = 0
-        self.total_cost = 0
-        self.total_budget = 0
-
-    def update_cost(self, prompt_tokens, completion_tokens, model):
-        """
-        Update the total cost, prompt tokens, and completion tokens.
-
-        Args:
-            prompt_tokens (int): The number of tokens used in the prompt.
-            completion_tokens (int): The number of tokens used in the completion.
-            model (str): The model used for the API call.
-        """
-        self.total_prompt_tokens += prompt_tokens
-        self.total_completion_tokens += completion_tokens
-        cost = (
-            prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
-        ) / 1000
-        self.total_cost += cost
-        logger.info(
-            f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
-            f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
-        )
-        CONFIG.total_cost = self.total_cost
-
-    def get_total_prompt_tokens(self):
-        """
-        Get the total number of prompt tokens.
-
-        Returns:
-            int: The total number of prompt tokens.
-        """
-        return self.total_prompt_tokens
-
-    def get_total_completion_tokens(self):
-        """
-        Get the total number of completion tokens.
-
-        Returns:
-            int: The total number of completion tokens.
-        """
-        return self.total_completion_tokens
-
-    def get_total_cost(self):
-        """
-        Get the total cost of API calls.
-
-        Returns:
-            float: The total cost of API calls.
-        """
-        return self.total_cost
-
-    def get_costs(self) -> Costs:
-        """Get all costs"""
-        return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
-
-
 def log_and_reraise(retry_state):
     logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
     logger.warning(
@@ -150,53 +51,31 @@ See FAQ 5.8


 @register_provider(LLMProviderEnum.OPENAI)
-class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
-    """
-    Check https://platform.openai.com/examples for examples
-    """
+class OpenAILLM(BaseLLM):
+    """Check https://platform.openai.com/examples for examples"""

     def __init__(self):
         self.config: Config = CONFIG
-        self.__init_openai()
+        self._init_openai()
+        self._init_client()
         self.auto_max_tokens = False
-        self._cost_manager = CostManager()
-        RateLimiter.__init__(self, rpm=self.rpm)

-    def __init_openai(self):
-        self.is_azure = self.config.openai_api_type == "azure"
-        self.model = self.config.deployment_name if self.is_azure else self.config.openai_api_model
-        self.rpm = int(self.config.get("RPM", 10))
-        self._make_client()
+    def _init_openai(self):
+        self.model = self.config.OPENAI_API_MODEL  # Used in _calc_usage & _cons_kwargs

-    def _make_client(self):
-        kwargs, async_kwargs = self._make_client_kwargs()
-
-        if self.is_azure:
-            self.client = AzureOpenAI(**kwargs)
-            self.async_client = AsyncAzureOpenAI(**async_kwargs)
-        else:
-            self.client = OpenAI(**kwargs)
-            self.async_client = AsyncOpenAI(**async_kwargs)
+    def _init_client(self):
+        """https://github.com/openai/openai-python#async-usage"""
+        kwargs = self._make_client_kwargs()
+        self.aclient = AsyncOpenAI(**kwargs)

-    def _make_client_kwargs(self) -> (dict, dict):
-        if self.is_azure:
-            kwargs = dict(
-                api_key=self.config.openai_api_key,
-                api_version=self.config.openai_api_version,
-                azure_endpoint=self.config.openai_base_url,
-            )
-        else:
-            kwargs = dict(api_key=self.config.openai_api_key, base_url=self.config.openai_base_url)
-
-        async_kwargs = kwargs.copy()
+    def _make_client_kwargs(self) -> dict:
+        kwargs = {"api_key": self.config.openai_api_key, "base_url": self.config.openai_base_url}

         # to use proxy, openai v1 needs http_client
-        proxy_params = self._get_proxy_params()
-        if proxy_params:
-            kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params)
-            async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
+        if proxy_params := self._get_proxy_params():
+            kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)

-        return kwargs, async_kwargs
+        return kwargs

     def _get_proxy_params(self) -> dict:
         params = {}
@@ -207,59 +86,37 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):

         return params

-    async def _achat_completion_stream(self, messages: list[dict]) -> str:
-        response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
-            **self._cons_kwargs(messages), stream=True
+    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
+        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
+            **self._cons_kwargs(messages, timeout=timeout), stream=True
         )

-        # create variables to collect the stream of chunks
-        collected_chunks = []
-        collected_messages = []
         # iterate through the stream of events
         async for chunk in response:
-            collected_chunks.append(chunk)  # save the event response
-            if chunk.choices:
-                chunk_message = chunk.choices[0].delta  # extract the message
-                collected_messages.append(chunk_message)  # save the message
-                if chunk_message.content:
-                    log_llm_stream(chunk_message.content)
-        print()
-
-        full_reply_content = "".join([m.content for m in collected_messages if m.content])
-        usage = self._calc_usage(messages, full_reply_content)
-        self._update_costs(usage)
-        return full_reply_content
+            chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else ""  # extract the message
+            yield chunk_message

-    def _cons_kwargs(self, messages: list[dict], **configs) -> dict:
+    def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
         kwargs = {
             "messages": messages,
-            "max_tokens": self.get_max_tokens(messages),
+            "max_tokens": self._get_max_tokens(messages),
             "n": 1,
             "stop": None,
             "temperature": 0.3,
-            "timeout": 3,
+            "model": self.model,
+            "timeout": max(CONFIG.timeout, timeout),
         }
-        if configs:
-            kwargs.update(configs)
+        if extra_kwargs:
+            kwargs.update(extra_kwargs)
         return kwargs

-    async def _achat_completion(self, messages: list[dict]) -> ChatCompletion:
-        rsp: ChatCompletion = await self.async_client.chat.completions.create(**self._cons_kwargs(messages))
+    async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
+        kwargs = self._cons_kwargs(messages, timeout=timeout)
+        rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
         self._update_costs(rsp.usage)
         return rsp

-    def _chat_completion(self, messages: list[dict]) -> ChatCompletion:
-        rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages))
-        self._update_costs(rsp.usage)
-        return rsp
-
-    def completion(self, messages: list[dict]) -> ChatCompletion:
-        return self._chat_completion(messages)
-
-    async def acompletion(self, messages: list[dict]) -> ChatCompletion:
-        return await self._achat_completion(messages)
+    async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
+        return await self._achat_completion(messages, timeout=timeout)

     @retry(
         wait=wait_random_exponential(min=1, max=60),
@@ -268,17 +125,26 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         retry=retry_if_exception_type(APIConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         """when streaming, print each token in place."""
         if stream:
-            return await self._achat_completion_stream(messages)
-        rsp = await self._achat_completion(messages)
+            resp = self._achat_completion_stream(messages, timeout=timeout)
+
+            collected_messages = []
+            async for i in resp:
+                log_llm_stream(i)
+                collected_messages.append(i)
+
+            full_reply_content = "".join(collected_messages)
+            usage = self._calc_usage(messages, full_reply_content)
+            self._update_costs(usage)
+            return full_reply_content
+
+        rsp = await self._achat_completion(messages, timeout=timeout)
         return self.get_choice_text(rsp)

-    def _func_configs(self, messages: list[dict], **kwargs) -> dict:
-        """
-        Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
-        """
+    def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict:
+        """Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create"""
         if "tools" not in kwargs:
             configs = {
                 "tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}],
@@ -286,17 +152,11 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
             }
             kwargs.update(configs)

-        return self._cons_kwargs(messages, **kwargs)
+        return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)

-    def _chat_completion_function(self, messages: list[dict], **kwargs) -> ChatCompletion:
-        rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs))
-        self._update_costs(rsp.usage)
-        return rsp
-
-    async def _achat_completion_function(self, messages: list[dict], **chat_configs) -> ChatCompletion:
-        rsp: ChatCompletion = await self.async_client.chat.completions.create(
-            **self._func_configs(messages, **chat_configs)
-        )
+    async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion:
+        kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs)
+        rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
         self._update_costs(rsp.usage)
         return rsp
@@ -316,52 +176,28 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         )
         return messages

-    def ask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
-        """Use function of tools to ask a code.
-
-        Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
-
-        Examples:
-
-        >>> llm = OpenAIGPTAPI()
-        >>> llm.ask_code("Write a python hello world code.")
-        {'language': 'python', 'code': "print('Hello, World!')"}
-        >>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
-        >>> llm.ask_code(msg)
-        {'language': 'python', 'code': "print('Hello, World!')"}
-        """
-        messages = self._process_message(messages)
-        rsp = self._chat_completion_function(messages, **kwargs)
-        return self.get_choice_function_arguments(rsp)
-
     async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
         """Use function of tools to ask a code.

-        Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
+        Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create

         Examples:

-        >>> llm = OpenAIGPTAPI()
-        >>> rsp = await llm.ask_code("Write a python hello world code.")
-        >>> rsp
-        {'language': 'python', 'code': "print('Hello, World!')"}
+        >>> llm = OpenAILLM()
         >>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
-        >>> rsp = await llm.aask_code(msg)  # -> {'language': 'python', 'code': "print('Hello, World!')"}
+        >>> rsp = await llm.aask_code(msg)
+        # -> {'language': 'python', 'code': "print('Hello, World!')"}
         """
         messages = self._process_message(messages)
         rsp = await self._achat_completion_function(messages, **kwargs)
         return self.get_choice_function_arguments(rsp)

+    @handle_exception
     def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
         """Required to provide the first function arguments of choice.

         :return dict: return the first function arguments of choice, for example,
             {'language': 'python', 'code': "print('Hello, World!')"}
         """
-        try:
-            return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
-        except json.JSONDecodeError:
-            return {}
+        return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)

     def get_choice_text(self, rsp: ChatCompletion) -> str:
         """Required to provide the first text of choice"""
@@ -376,51 +212,24 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
             usage.prompt_tokens = count_message_tokens(messages, self.model)
             usage.completion_tokens = count_string_tokens(rsp, self.model)
         except Exception as e:
-            logger.error(f"usage calculation failed!: {e}")
+            logger.error(f"usage calculation failed: {e}")

         return usage

-    async def acompletion_batch(self, batch: list[list[dict]]) -> list[ChatCompletion]:
-        """Return full JSON"""
-        split_batches = self.split_batches(batch)
-        all_results = []
-
-        for small_batch in split_batches:
-            logger.info(small_batch)
-            await self.wait_if_needed(len(small_batch))
-
-            future = [self.acompletion(prompt) for prompt in small_batch]
-            results = await asyncio.gather(*future)
-            logger.info(results)
-            all_results.extend(results)
-
-        return all_results
-
-    async def acompletion_batch_text(self, batch: list[list[dict]]) -> list[str]:
-        """Only return plain text"""
-        raw_results = await self.acompletion_batch(batch)
-        results = []
-        for idx, raw_result in enumerate(raw_results, start=1):
-            result = self.get_choice_text(raw_result)
-            results.append(result)
-            logger.info(f"Result of task {idx}: {result}")
-        return results
-
+    @handle_exception
     def _update_costs(self, usage: CompletionUsage):
         if CONFIG.calc_usage and usage:
-            try:
-                self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
-            except Exception as e:
-                logger.error("updating costs failed!", e)
+            CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)

     def get_costs(self) -> Costs:
-        return self._cost_manager.get_costs()
+        return CONFIG.cost_manager.get_costs()

-    def get_max_tokens(self, messages: list[dict]):
+    def _get_max_tokens(self, messages: list[dict]):
         if not self.auto_max_tokens:
             return CONFIG.max_tokens_rsp
         return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)

+    @handle_exception
     async def amoderation(self, content: Union[str, list[str]]):
-        return await self.async_client.moderations.create(input=content)
+        """Moderate content."""
+        return await self.aclient.moderations.create(input=content)
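With `_achat_completion_stream` now an async generator, `acompletion_text` consumes it as the hunk above shows. A standalone sketch of the same consumption pattern (the generator here is a stand-in for the real one):

import asyncio
from typing import AsyncIterator

async def fake_stream() -> AsyncIterator[str]:
    # stand-in for OpenAILLM._achat_completion_stream: yields content deltas
    for chunk in ["Hel", "lo, ", "world"]:
        yield chunk

async def collect() -> str:
    collected: list[str] = []
    async for piece in fake_stream():
        print(piece, end="", flush=True)  # the real code logs via log_llm_stream
        collected.append(piece)
    return "".join(collected)  # join the deltas into the full reply

print("\n->", asyncio.run(collect()))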
@@ -12,8 +12,8 @@ from metagpt.utils.repair_llm_raw_output import (
 )


-class BasePostPrecessPlugin(object):
-    model = None  # the plugin of the `model`, use to judge in `llm_postprecess`
+class BasePostProcessPlugin(object):
+    model = None  # the plugin of the `model`, use to judge in `llm_postprocess`

     def run_repair_llm_output(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]:
         """
@@ -4,17 +4,17 @@
 from typing import Union

-from metagpt.provider.postprecess.base_postprecess_plugin import BasePostPrecessPlugin
+from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin


-def llm_output_postprecess(
+def llm_output_postprocess(
     output: str, schema: dict, req_key: str = "[/CONTENT]", model_name: str = None
 ) -> Union[dict, str]:
     """
-    default use BasePostPrecessPlugin if there is no matched plugin.
+    default use BasePostProcessPlugin if there is no matched plugin.
     """
     # TODO choose a different model's plugin according to the model_name
-    postprecess_plugin = BasePostPrecessPlugin()
-
-    result = postprecess_plugin.run(output=output, schema=schema, req_key=req_key)
+    postprocess_plugin = BasePostProcessPlugin()
+
+    result = postprocess_plugin.run(output=output, schema=schema, req_key=req_key)
     return result
@@ -1,9 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 """
 @Time    : 2023/7/21 11:15
 @Author  : Leo Xiao
-@File    : anthropic_api.py
+@File    : spark_api.py
 """
 import _thread as thread
 import base64
@@ -13,7 +11,6 @@ import hmac
 import json
 import ssl
 from time import mktime
-from typing import Optional
 from urllib.parse import urlencode, urlparse
 from wsgiref.handlers import format_date_time

@@ -21,47 +18,29 @@ import websocket  # uses websocket_client
 from metagpt.config import CONFIG, LLMProviderEnum
 from metagpt.logs import logger
-from metagpt.provider.base_gpt_api import BaseGPTAPI
+from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import register_provider


 @register_provider(LLMProviderEnum.SPARK)
-class SparkAPI(BaseGPTAPI):
+class SparkLLM(BaseLLM):
     def __init__(self):
         logger.warning("This provider does not support async execution; acompletion calls cannot run in parallel.")

-    def ask(self, msg: str) -> str:
-        message = [self._default_system_msg(), self._user_msg(msg)]
-        rsp = self.completion(message)
-        return rsp
-
-    async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
-        if system_msgs:
-            message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
-        else:
-            message = [self._default_system_msg(), self._user_msg(msg)]
-        rsp = await self.acompletion(message)
-        logger.debug(message)
-        return rsp
-
     def get_choice_text(self, rsp: dict) -> str:
         return rsp["payload"]["choices"]["text"][-1]["content"]

-    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
-        # not supported
-        logger.error("This feature is disabled.")
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
+        w = GetMessageFromWeb(messages)
+        return w.run()

-    async def acompletion(self, messages: list[dict]):
+    async def acompletion(self, messages: list[dict], timeout=3):
         # async is not supported
         w = GetMessageFromWeb(messages)
         return w.run()

-    def completion(self, messages: list[dict]):
-        w = GetMessageFromWeb(messages)
-        return w.run()
-

 class GetMessageFromWeb:
     class WsParam:
@@ -2,6 +2,8 @@
 # -*- coding: utf-8 -*-
 # @Desc   : zhipu model api to support sync & async for invoke & sse_invoke

+import json
+
 import zhipuai
 from zhipuai.model_api.api import InvokeType, ModelAPI
 from zhipuai.utils.http_client import headers as zhipuai_default_headers
@@ -33,7 +35,7 @@ class ZhiPuModelAPI(ModelAPI):
         zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
         """
         arr = zhipu_api_url.split("/api/")
-        # ("https://open.bigmodel.cn/api/" , "/paas/v3/model-api/chatglm_turbo/invoke")
+        # ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke")
         return f"{arr[0]}/api", f"/{arr[1]}"

     @classmethod
@@ -51,7 +53,6 @@ class ZhiPuModelAPI(ModelAPI):
             params=kwargs,
             request_timeout=zhipuai.api_timeout_seconds,
         )
-
         return result

     @classmethod
@@ -61,6 +62,8 @@ class ZhiPuModelAPI(ModelAPI):
         resp = await cls.arequest(
             invoke_type=InvokeType.SYNC, stream=False, method="post", headers=headers, kwargs=kwargs
         )
+        resp = resp.decode("utf-8")
+        resp = json.loads(resp)
         return resp

     @classmethod
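A worked example of the URL-splitting logic in the hunk above (the enclosing method name is not visible in this hunk, so a local function stands in for it):

def split_zhipu_api_url(zhipu_api_url: str) -> tuple[str, str]:
    # mirror of the hunk above: split once on "/api/" and re-attach "/api"
    arr = zhipu_api_url.split("/api/")
    return f"{arr[0]}/api", f"/{arr[1]}"

base, suffix = split_zhipu_api_url(
    "https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_turbo/invoke"
)
assert base == "https://open.bigmodel.cn/api"
assert suffix == "/paas/v3/model-api/chatglm_turbo/invoke"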
@@ -18,9 +18,9 @@ from tenacity import (
 from metagpt.config import CONFIG, LLMProviderEnum
 from metagpt.logs import log_llm_stream, logger
-from metagpt.provider.base_gpt_api import BaseGPTAPI
+from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import CostManager, log_and_reraise
+from metagpt.provider.openai_api import log_and_reraise
 from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI


@@ -32,25 +32,25 @@ class ZhiPuEvent(Enum):


 @register_provider(LLMProviderEnum.ZHIPUAI)
-class ZhiPuAIGPTAPI(BaseGPTAPI):
+class ZhiPuAILLM(BaseLLM):
     """
     Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
     From now, there is only one model named `chatglm_turbo`
     """

+    use_system_prompt: bool = False  # zhipuai has no system prompt when using the api
+
     def __init__(self):
         self.__init_zhipuai(CONFIG)
         self.llm = ZhiPuModelAPI
         self.model = "chatglm_turbo"  # so far only one model, just use it
-        self._cost_manager = CostManager()
-        self.use_system_prompt: bool = False  # zhipuai has no system prompt when using the api

     def __init_zhipuai(self, config: CONFIG):
         assert config.zhipuai_api_key
         zhipuai.api_key = config.zhipuai_api_key
-        openai.api_key = zhipuai.api_key  # due to using the openai sdk, set the api_key, but it won't be used.
+        # openai.api_key = zhipuai.api_key  # due to using the openai sdk, set the api_key, but it won't be used.
         if config.openai_proxy:
             # FIXME: openai v1.x sdk has no proxy support
             openai.proxy = config.openai_proxy

     def _const_kwargs(self, messages: list[dict]) -> dict:
@@ -63,7 +63,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
         try:
             prompt_tokens = int(usage.get("prompt_tokens", 0))
             completion_tokens = int(usage.get("completion_tokens", 0))
-            self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
+            CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
         except Exception as e:
             logger.error(f"zhipuai updating costs failed! exp: {e}")

@@ -73,22 +73,22 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
         assert assist_msg["role"] == "assistant"
         return assist_msg.get("content")

-    def completion(self, messages: list[dict]) -> dict:
+    def completion(self, messages: list[dict], timeout=3) -> dict:
         resp = self.llm.invoke(**self._const_kwargs(messages))
         usage = resp.get("data").get("usage")
         self._update_costs(usage)
         return resp

-    async def _achat_completion(self, messages: list[dict]) -> dict:
+    async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
         resp = await self.llm.ainvoke(**self._const_kwargs(messages))
         usage = resp.get("data").get("usage")
         self._update_costs(usage)
         return resp

-    async def acompletion(self, messages: list[dict]) -> dict:
-        return await self._achat_completion(messages)
+    async def acompletion(self, messages: list[dict], timeout=3) -> dict:
+        return await self._achat_completion(messages, timeout=timeout)

-    async def _achat_completion_stream(self, messages: list[dict]) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
         response = await self.llm.asse_invoke(**self._const_kwargs(messages))
         collected_content = []
         usage = {}
@@ -100,7 +100,6 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
         elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value:
             content = event.data
             logger.error(f"event error: {content}", end="")
-            collected_content.append([content])
         elif event.event == ZhiPuEvent.FINISH.value:
             """
             event.meta
@@ -131,7 +130,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
         retry=retry_if_exception_type(ConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         """response in async with stream or non-stream mode"""
         if stream:
             return await self._achat_completion_stream(messages)
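The Ollama, ZhiPu, and OpenAI providers all wrap `acompletion_text` in the same tenacity policy. A self-contained sketch of that policy (the flaky call is simulated, and the stop clause is an assumption since it is not visible in these hunks):

import asyncio

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential

def log_and_reraise(retry_state):
    # mirrors the providers' retry_error_callback: surface the last failure
    raise retry_state.outcome.exception()

attempts = 0

@retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(3),  # assumed; the stop clause is not shown in these hunks
    retry=retry_if_exception_type(ConnectionError),
    retry_error_callback=log_and_reraise,
)
async def flaky_completion() -> str:
    global attempts
    attempts += 1
    if attempts < 3:
        raise ConnectionError("simulated transient failure")
    return "ok"

print(asyncio.run(flaky_completion()))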
@@ -5,19 +5,47 @@
 @Author  : alexanderwu
 @File    : repo_parser.py
 """
+from __future__ import annotations
+
 import ast
 import json
+import re
+import subprocess
 from pathlib import Path
 from pprint import pformat
+from typing import Dict, List, Optional, Tuple
+
+import aiofiles
 import pandas as pd
+from pydantic import BaseModel, Field

 from metagpt.config import CONFIG
 from metagpt.logs import logger
+from metagpt.utils.common import any_to_str
+from metagpt.utils.exceptions import handle_exception
+
+
+class RepoFileInfo(BaseModel):
+    file: str
+    classes: List = Field(default_factory=list)
+    functions: List = Field(default_factory=list)
+    globals: List = Field(default_factory=list)
+    page_info: List = Field(default_factory=list)
+
+
+class CodeBlockInfo(BaseModel):
+    lineno: int
+    end_lineno: int
+    type_name: str
+    tokens: List = Field(default_factory=list)
+    properties: Dict = Field(default_factory=dict)
+
+
+class ClassInfo(BaseModel):
+    name: str
+    package: Optional[str] = None
+    attributes: Dict[str, str] = Field(default_factory=dict)
+    methods: Dict[str, str] = Field(default_factory=dict)


 class RepoParser(BaseModel):
     base_directory: Path = Field(default=None)

@@ -27,31 +55,32 @@ class RepoParser(BaseModel):
         """Parse a Python file in the repository."""
         return ast.parse(file_path.read_text()).body

-    def extract_class_and_function_info(self, tree, file_path):
+    def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo:
         """Extract class, function, and global variable information from the AST."""
-        file_info = {
-            "file": str(file_path.relative_to(self.base_directory)),
-            "classes": [],
-            "functions": [],
-            "globals": [],
-        }
-
+        file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
         for node in tree:
+            info = RepoParser.node_to_str(node)
+            file_info.page_info.append(info)
             if isinstance(node, ast.ClassDef):
                 class_methods = [m.name for m in node.body if is_func(m)]
-                file_info["classes"].append({"name": node.name, "methods": class_methods})
+                file_info.classes.append({"name": node.name, "methods": class_methods})
             elif is_func(node):
-                file_info["functions"].append(node.name)
+                file_info.functions.append(node.name)
             elif isinstance(node, (ast.Assign, ast.AnnAssign)):
                 for target in node.targets if isinstance(node, ast.Assign) else [node.target]:
                     if isinstance(target, ast.Name):
-                        file_info["globals"].append(target.id)
+                        file_info.globals.append(target.id)
         return file_info

-    def generate_symbols(self):
+    def generate_symbols(self) -> List[RepoFileInfo]:
         files_classes = []
         directory = self.base_directory
-        for path in directory.rglob("*.py"):
+
+        matching_files = []
+        extensions = ["*.py", "*.js"]
+        for ext in extensions:
+            matching_files += directory.rglob(ext)
+        for path in matching_files:
             tree = self._parse_file(path)
             file_info = self.extract_class_and_function_info(tree, path)
             files_classes.append(file_info)

@@ -60,16 +89,16 @@ class RepoParser(BaseModel):
     def generate_json_structure(self, output_path):
         """Generate a JSON file documenting the repository structure."""
-        files_classes = self.generate_symbols()
+        files_classes = [i.model_dump() for i in self.generate_symbols()]
         output_path.write_text(json.dumps(files_classes, indent=4))

     def generate_dataframe_structure(self, output_path):
         """Generate a DataFrame documenting the repository structure and save as CSV."""
-        files_classes = self.generate_symbols()
+        files_classes = [i.model_dump() for i in self.generate_symbols()]
         df = pd.DataFrame(files_classes)
         df.to_csv(output_path, index=False)

-    def generate_structure(self, output_path=None, mode="json"):
+    def generate_structure(self, output_path=None, mode="json") -> Path:
         """Generate the structure of the repository as a specified format."""
         output_file = self.base_directory / f"{self.base_directory.name}-structure.{mode}"
         output_path = Path(output_path) if output_path else output_file

@@ -78,22 +107,217 @@ class RepoParser(BaseModel):
             self.generate_json_structure(output_path)
         elif mode == "csv":
             self.generate_dataframe_structure(output_path)
+        return output_path
+
+    @staticmethod
+    def node_to_str(node) -> (int, int, str, str | Tuple):
+        if any_to_str(node) == any_to_str(ast.Expr):
+            return CodeBlockInfo(
+                lineno=node.lineno,
+                end_lineno=node.end_lineno,
+                type_name=any_to_str(node),
+                tokens=RepoParser._parse_expr(node),
+            )
+        mappings = {
+            any_to_str(ast.Import): lambda x: [RepoParser._parse_name(n) for n in x.names],
+            any_to_str(ast.Assign): RepoParser._parse_assign,
+            any_to_str(ast.ClassDef): lambda x: x.name,
+            any_to_str(ast.FunctionDef): lambda x: x.name,
+            any_to_str(ast.ImportFrom): lambda x: {
+                "module": x.module,
+                "names": [RepoParser._parse_name(n) for n in x.names],
+            },
+            any_to_str(ast.If): RepoParser._parse_if,
+            any_to_str(ast.AsyncFunctionDef): lambda x: x.name,
+        }
+        func = mappings.get(any_to_str(node))
+        if func:
+            code_block = CodeBlockInfo(lineno=node.lineno, end_lineno=node.end_lineno, type_name=any_to_str(node))
+            val = func(node)
+            if isinstance(val, dict):
+                code_block.properties = val
+            elif isinstance(val, list):
+                code_block.tokens = val
+            elif isinstance(val, str):
+                code_block.tokens = [val]
+            else:
+                raise NotImplementedError(f"Not implemented: {val}")
+            return code_block
+        raise NotImplementedError(f"Not implemented code block: {node.lineno}, {node.end_lineno}, {any_to_str(node)}")
+
+    @staticmethod
+    def _parse_expr(node) -> List:
+        funcs = {
+            any_to_str(ast.Constant): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
+            any_to_str(ast.Call): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value.func)],
+        }
+        func = funcs.get(any_to_str(node.value))
+        if func:
+            return func(node)
+        raise NotImplementedError(f"Not implemented: {node.value}")
+
+    @staticmethod
+    def _parse_name(n):
+        if n.asname:
+            return f"{n.name} as {n.asname}"
+        return n.name
+
+    @staticmethod
+    def _parse_if(n):
+        tokens = [RepoParser._parse_variable(n.test.left)]
+        for item in n.test.comparators:
+            tokens.append(RepoParser._parse_variable(item))
+        return tokens
+
+    @staticmethod
+    def _parse_variable(node):
+        funcs = {
+            any_to_str(ast.Constant): lambda x: x.value,
+            any_to_str(ast.Name): lambda x: x.id,
+            any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}",
+        }
+        func = funcs.get(any_to_str(node))
+        if not func:
+            raise NotImplementedError(f"Not implemented: {node}")
+        return func(node)
+
+    @staticmethod
+    def _parse_assign(node):
+        return [RepoParser._parse_variable(t) for t in node.targets]
+
+    async def rebuild_class_views(self, path: str | Path = None):
+        if not path:
+            path = self.base_directory
+        path = Path(path)
+        if not path.exists():
+            return
+        command = f"pyreverse {str(path)} -o dot"
+        result = subprocess.run(command, shell=True, check=True, cwd=str(path))
+        if result.returncode != 0:
+            raise ValueError(f"{result}")
+        class_view_pathname = path / "classes.dot"
+        class_views = await self._parse_classes(class_view_pathname)
+        packages_pathname = path / "packages.dot"
+        class_views = RepoParser._repair_namespaces(class_views=class_views, path=path)
+        class_view_pathname.unlink(missing_ok=True)
+        packages_pathname.unlink(missing_ok=True)
+        return class_views
+
+    async def _parse_classes(self, class_view_pathname):
+        class_views = []
+        if not class_view_pathname.exists():
+            return class_views
+        async with aiofiles.open(str(class_view_pathname), mode="r") as reader:
+            lines = await reader.readlines()
+        for line in lines:
+            package_name, info = RepoParser._split_class_line(line)
+            if not package_name:
+                continue
+            class_name, members, functions = re.split(r"(?<!\\)\|", info)
+            class_info = ClassInfo(name=class_name)
+            class_info.package = package_name
+            for m in members.split("\n"):
+                if not m:
+                    continue
+                member_name = m.split(":", 1)[0].strip() if ":" in m else m.strip()
+                class_info.attributes[member_name] = m
+            for f in functions.split("\n"):
+                if not f:
+                    continue
+                function_name, _ = f.split("(", 1)
+                class_info.methods[function_name] = f
+            class_views.append(class_info)
+        return class_views
+
+    @staticmethod
+    def _split_class_line(line):
+        part_splitor = '" ['
+        if part_splitor not in line:
+            return None, None
+        ix = line.find(part_splitor)
+        class_name = line[0:ix].replace('"', "")
+        left = line[ix:]
+        begin_flag = "label=<{"
+        end_flag = "}>"
+        if begin_flag not in left or end_flag not in left:
+            return None, None
+        bix = left.find(begin_flag)
+        eix = left.rfind(end_flag)
+        info = left[bix + len(begin_flag) : eix]
+        info = re.sub(r"<br[^>]*>", "\n", info)
+        return class_name, info
+
+    @staticmethod
+    def _create_path_mapping(path: str | Path) -> Dict[str, str]:
+        mappings = {
+            str(path).replace("/", "."): str(path),
+        }
+        files = []
+        try:
+            directory_path = Path(path)
+            if not directory_path.exists():
+                return mappings
+            for file_path in directory_path.iterdir():
+                if file_path.is_file():
+                    files.append(str(file_path))
+                else:
+                    subfolder_files = RepoParser._create_path_mapping(path=file_path)
+                    mappings.update(subfolder_files)
+        except Exception as e:
+            logger.error(f"Error: {e}")
+        for f in files:
+            mappings[str(Path(f).with_suffix("")).replace("/", ".")] = str(f)
+
+        return mappings
+
+    @staticmethod
+    def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]:
+        if not class_views:
+            return []
+        c = class_views[0]
+        full_key = str(path).lstrip("/").replace("/", ".")
+        root_namespace = RepoParser._find_root(full_key, c.package)
+        root_path = root_namespace.replace(".", "/")
+
+        mappings = RepoParser._create_path_mapping(path=path)
+        new_mappings = {}
+        ix_root_namespace = len(root_namespace)
+        ix_root_path = len(root_path)
+        for k, v in mappings.items():
+            nk = k[ix_root_namespace:]
+            nv = v[ix_root_path:]
+            new_mappings[nk] = nv
+
+        for c in class_views:
+            c.package = RepoParser._repair_ns(c.package, new_mappings)
+        return class_views
+
+    @staticmethod
+    def _repair_ns(package, mappings):
+        file_ns = package
+        ix = len(file_ns)  # keeps ix defined even if the first lookup succeeds
+        while file_ns != "":
+            if file_ns not in mappings:
+                ix = file_ns.rfind(".")
+                file_ns = file_ns[0:ix]
+                continue
+            break
+        internal_ns = package[ix + 1 :]
+        ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":")
+        return ns
+
+    @staticmethod
+    def _find_root(full_key, package) -> str:
+        left = full_key
+        while left != "":
+            if left in package:
+                break
+            if "." not in left:
+                break
+            ix = left.find(".")
+            left = left[ix + 1 :]
+        ix = full_key.rfind(left)
+        return "." + full_key[0:ix]


 def is_func(node):
     return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))


 def main():
     repo_parser = RepoParser(base_directory=CONFIG.workspace_path / "web_2048")
     symbols = repo_parser.generate_symbols()
     logger.info(pformat(symbols))


 def error():
     """raise Exception and logs it"""
     RepoParser._parse_file(Path("test.py"))


 if __name__ == "__main__":
     main()
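A worked example of `_split_class_line` on the kind of node line pyreverse emits (the dot line below is a hand-made sample in pyreverse's label format, not output from a real run; the helper is re-implemented locally so the snippet is self-contained):

import re

def split_class_line(line: str):
    # mirror of RepoParser._split_class_line from the diff above
    part_splitor = '" ['
    if part_splitor not in line:
        return None, None
    ix = line.find(part_splitor)
    class_name = line[0:ix].replace('"', "")
    left = line[ix:]
    begin_flag, end_flag = "label=<{", "}>"
    if begin_flag not in left or end_flag not in left:
        return None, None
    bix = left.find(begin_flag)
    eix = left.rfind(end_flag)
    info = re.sub(r"<br[^>]*>", "\n", left[bix + len(begin_flag) : eix])
    return class_name, info

line = '"classes.Engineer" [label=<{Engineer|use_code_review : bool<br ALIGN="LEFT"/>|run(msg)<br ALIGN="LEFT"/>}>, shape="record"];'
name, info = split_class_line(line)
print(name)  # classes.Engineer
print(info)  # Engineer|use_code_review : bool\n|run(msg)\n  -- three "|"-separated parts for _parse_classes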
139  metagpt/roles/assistant.py  Normal file
@@ -0,0 +1,139 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/8/7
+@Author  : mashenquan
+@File    : assistant.py
+@Desc    : I am attempting to incorporate certain symbol concepts from UML into MetaGPT, enabling it to have the
+        ability to freely construct flows through symbol concatenation. Simultaneously, I am also striving to
+        make these symbols configurable and standardized, making the process of building flows more convenient.
+        For more about `fork` node in activity diagrams, see: `https://www.uml-diagrams.org/activity-diagrams.html`
+        This file defines a `fork` style meta role capable of generating arbitrary roles at runtime based on a
+        configuration file.
+@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false
+        indicates that further reasoning cannot continue.
+"""
+from enum import Enum
+from pathlib import Path
+from typing import Optional
+
+from pydantic import Field
+
+from metagpt.actions.skill_action import ArgumentsParingAction, SkillAction
+from metagpt.actions.talk_action import TalkAction
+from metagpt.config import CONFIG
+from metagpt.learn.skill_loader import SkillsDeclaration
+from metagpt.logs import logger
+from metagpt.memory.brain_memory import BrainMemory
+from metagpt.roles import Role
+from metagpt.schema import Message
+
+
+class MessageType(Enum):
+    Talk = "TALK"
+    Skill = "SKILL"
+
+
+class Assistant(Role):
+    """Assistant for solving common issues."""
+
+    name: str = "Lily"
+    profile: str = "An assistant"
+    goal: str = "Help to solve problem"
+    constraints: str = "Talk in {language}"
+    desc: str = ""
+    memory: BrainMemory = Field(default_factory=BrainMemory)
+    skills: Optional[SkillsDeclaration] = None
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.constraints = self.constraints.format(language=kwargs.get("language") or CONFIG.language or "Chinese")
+
+    async def think(self) -> bool:
+        """Everything will be done part by part."""
+        last_talk = await self.refine_memory()
+        if not last_talk:
+            return False
+        if not self.skills:
+            skill_path = Path(CONFIG.SKILL_PATH) if CONFIG.SKILL_PATH else None
+            self.skills = await SkillsDeclaration.load(skill_yaml_file_name=skill_path)
+
+        prompt = ""
+        skills = self.skills.get_skill_list()
+        for desc, name in skills.items():
+            prompt += f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n"
+        prompt += 'Otherwise, return `[TALK]: {talk}` brief and clear. For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n'
+        prompt += f"Now what specific action is explicitly mentioned in the text: {last_talk}\n"
+        rsp = await self.llm.aask(prompt, [])
+        logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n")
+        return await self._plan(rsp, last_talk=last_talk)
+
+    async def act(self) -> Message:
+        result = await self.rc.todo.run()
+        if not result:
+            return None
+        if isinstance(result, str):
+            msg = Message(content=result, role="assistant", cause_by=self.rc.todo)
+        elif isinstance(result, Message):
+            msg = result
+        else:
+            msg = Message(content=result.content, instruct_content=result.instruct_content, cause_by=type(self.rc.todo))
+        self.memory.add_answer(msg)
+        return msg
+
+    async def talk(self, text):
+        self.memory.add_talk(Message(content=text))
+
+    async def _plan(self, rsp: str, **kwargs) -> bool:
+        skill, text = BrainMemory.extract_info(input_string=rsp)
+        handlers = {
+            MessageType.Talk.value: self.talk_handler,
+            MessageType.Skill.value: self.skill_handler,
+        }
+        handler = handlers.get(skill, self.talk_handler)
+        return await handler(text, **kwargs)
+
+    async def talk_handler(self, text, **kwargs) -> bool:
+        history = self.memory.history_text
+        text = kwargs.get("last_talk") or text
+        self.rc.todo = TalkAction(
+            context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm, **kwargs
+        )
+        return True
+
+    async def skill_handler(self, text, **kwargs) -> bool:
+        last_talk = kwargs.get("last_talk")
+        skill = self.skills.get_skill(text)
+        if not skill:
+            logger.info(f"skill not found: {text}")
+            return await self.talk_handler(text=last_talk, **kwargs)
+        action = ArgumentsParingAction(skill=skill, llm=self.llm, ask=last_talk, **kwargs)
+        await action.run(**kwargs)
+        if action.args is None:
+            return await self.talk_handler(text=last_talk, **kwargs)
+        self.rc.todo = SkillAction(skill=skill, args=action.args, llm=self.llm, name=skill.name, desc=skill.description)
+        return True
+
+    async def refine_memory(self) -> str:
+        last_talk = self.memory.pop_last_talk()
+        if last_talk is None:  # No user feedback, unsure if past conversation is finished.
+            return None
+        if not self.memory.is_history_available:
+            return last_talk
+        history_summary = await self.memory.summarize(max_words=800, keep_language=True, llm=self.llm)
+        if last_talk and await self.memory.is_related(text1=last_talk, text2=history_summary, llm=self.llm):
+            # Merge relevant content.
+            merged = await self.memory.rewrite(sentence=last_talk, context=history_summary, llm=self.llm)
+            return f"{merged} {last_talk}"
+
+        return last_talk
+
+    def get_memory(self) -> str:
+        return self.memory.model_dump_json()
+
+    def load_memory(self, m):
+        try:
+            self.memory = BrainMemory(**m)
+        except Exception as e:
+            logger.exception(f"load error:{e}, data:{m}")
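A minimal interaction sketch for the new Assistant role (a sketch assuming a configured LLM provider and, optionally, a skills YAML at CONFIG.SKILL_PATH; the loop shape follows the think/act methods above):

import asyncio

from metagpt.roles.assistant import Assistant

async def chat():
    assistant = Assistant(language="English")
    await assistant.talk("Summarize the key ideas of UML activity diagrams.")
    # think() routes the last talk to a TALK or SKILL action; act() runs it
    while await assistant.think():
        msg = await assistant.act()
        print(msg.content)

asyncio.run(chat())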
@@ -7,12 +7,11 @@
 """
 from typing import Optional

+from pydantic import Field
+
+from metagpt.document_store.base_store import BaseStore
 from metagpt.roles import Sales

 # from metagpt.actions import SearchAndSummarize
 # from metagpt.tools import SearchEngineType
-
-

 DESC = """
 ## Principles (all things must not bypass the principles)

@@ -29,4 +28,4 @@ class CustomerService(Sales):
     name: str = "Xiaomei"
     profile: str = "Human customer service"
     desc: str = DESC
-    store: Optional[str] = None
+    store: Optional[BaseStore] = Field(default=None, exclude=True)
@@ -48,7 +48,7 @@ from metagpt.schema import (
Documents,
Message,
)
from metagpt.utils.common import any_to_str, any_to_str_set
from metagpt.utils.common import any_to_name, any_to_str, any_to_str_set

IS_PASS_PROMPT = """
{context}

@@ -83,13 +83,17 @@ class Engineer(Role):
n_borg: int = 1
use_code_review: bool = False
code_todos: list = []
summarize_todos = []
summarize_todos: list = []
next_todo_action: str = ""

def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

self._init_actions([WriteCode])
self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug])
self.code_todos = []
self.summarize_todos = []
self.next_todo_action = any_to_name(WriteCode)

@staticmethod
def _parse_tasks(task_msg: Document) -> list[str]:

@@ -110,7 +114,7 @@ class Engineer(Role):
coding_context = await todo.run(guideline=guideline)
# Code review
if review:
action = WriteCodeReview(context=coding_context, llm=self._llm)
action = WriteCodeReview(context=coding_context, llm=self.llm)
self._init_action_system_message(action)
coding_context = await action.run()
await src_file_repo.save(

@@ -119,9 +123,12 @@ class Engineer(Role):
content=coding_context.code_doc.content,
)
msg = Message(
content=coding_context.json(), instruct_content=coding_context, role=self.profile, cause_by=WriteCode
content=coding_context.model_dump_json(),
instruct_content=coding_context,
role=self.profile,
cause_by=WriteCode,
)
self._rc.memory.add(msg)
self.rc.memory.add(msg)

changed_files.add(coding_context.code_doc.filename)
if not changed_files:

@@ -130,11 +137,13 @@ class Engineer(Role):

async def _act(self) -> Message | None:
"""Determines the mode of action based on whether code review is used."""
if self._rc.todo is None:
if self.rc.todo is None:
return None
if isinstance(self._rc.todo, WriteCode):
if isinstance(self.rc.todo, WriteCode):
self.next_todo_action = any_to_name(SummarizeCode)
return await self._act_write_code()
if isinstance(self._rc.todo, SummarizeCode):
if isinstance(self.rc.todo, SummarizeCode):
self.next_todo_action = any_to_name(WriteCode)
return await self._act_summarize()
return None

@@ -173,7 +182,7 @@ class Engineer(Role):
tasks.append(todo.context.dict())
await code_summaries_file_repo.save(
filename=Path(todo.context.design_filename).name,
content=todo.context.json(),
content=todo.context.model_dump_json(),
dependencies=dependencies,
)
else:

@@ -196,7 +205,7 @@ class Engineer(Role):
)

async def _is_pass(self, summary) -> (str, str):
rsp = await self._llm.aask(msg=IS_PASS_PROMPT.format(context=summary), stream=False)
rsp = await self.llm.aask(msg=IS_PASS_PROMPT.format(context=summary), stream=False)
logger.info(rsp)
if "YES" in rsp:
return True, rsp

@@ -207,17 +216,17 @@ class Engineer(Role):
CONFIG.src_workspace = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name
write_code_filters = any_to_str_set([WriteTasks, SummarizeCode, FixBug])
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
if not self._rc.news:
if not self.rc.news:
return None
msg = self._rc.news[0]
msg = self.rc.news[0]
if msg.cause_by in write_code_filters:
logger.debug(f"TODO WriteCode:{msg.json()}")
logger.debug(f"TODO WriteCode:{msg.model_dump_json()}")
await self._new_code_actions(bug_fix=msg.cause_by == any_to_str(FixBug))
return self._rc.todo
return self.rc.todo
if msg.cause_by in summarize_code_filters and msg.sent_from == any_to_str(self):
logger.debug(f"TODO SummarizeCode:{msg.json()}")
logger.debug(f"TODO SummarizeCode:{msg.model_dump_json()}")
await self._new_summarize_actions()
return self._rc.todo
return self.rc.todo
return None

@staticmethod

@@ -235,7 +244,9 @@ class Engineer(Role):
task_doc = await task_file_repo.get(i.name)
elif str(i.parent) == SYSTEM_DESIGN_FILE_REPO:
design_doc = await design_file_repo.get(i.name)
# FIXME: the design doc is not loaded here; it is None
if not task_doc or not design_doc:
logger.error(f'Detected source code "{filename}" from an unknown origin.')
raise ValueError(f'Detected source code "{filename}" from an unknown origin.')
context = CodingContext(filename=filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc)
return context

@@ -244,7 +255,9 @@ class Engineer(Role):
context = await Engineer._new_coding_context(
filename, src_file_repo, task_file_repo, design_file_repo, dependency
)
coding_doc = Document(root_path=str(src_file_repo.root_path), filename=filename, content=context.json())
coding_doc = Document(
root_path=str(src_file_repo.root_path), filename=filename, content=context.model_dump_json()
)
return coding_doc

async def _new_code_actions(self, bug_fix=False):

@@ -269,15 +282,15 @@ class Engineer(Role):
filename=task_filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc
)
coding_doc = Document(
root_path=str(src_file_repo.root_path), filename=task_filename, content=context.json()
root_path=str(src_file_repo.root_path), filename=task_filename, content=context.model_dump_json()
)
if task_filename in changed_files.docs:
logger.warning(
f"Log to expose potential conflicts: {coding_doc.json()} & "
f"{changed_files.docs[task_filename].json()}"
f"Log to expose potential conflicts: {coding_doc.model_dump_json()} & "
f"{changed_files.docs[task_filename].model_dump_json()}"
)
changed_files.docs[task_filename] = coding_doc
self.code_todos = [WriteCode(context=i, llm=self._llm) for i in changed_files.docs.values()]
self.code_todos = [WriteCode(context=i, llm=self.llm) for i in changed_files.docs.values()]
# Code directly modified by the user.
dependency = await CONFIG.git_repo.get_dependency()
for filename in changed_src_files:

@@ -291,10 +304,10 @@ class Engineer(Role):
dependency=dependency,
)
changed_files.docs[filename] = coding_doc
self.code_todos.append(WriteCode(context=coding_doc, llm=self._llm))
self.code_todos.append(WriteCode(context=coding_doc, llm=self.llm))

if self.code_todos:
self._rc.todo = self.code_todos[0]
self.rc.todo = self.code_todos[0]

async def _new_summarize_actions(self):
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)

@@ -307,9 +320,14 @@ class Engineer(Role):
summarizations[ctx].append(filename)
for ctx, filenames in summarizations.items():
ctx.codes_filenames = filenames
self.summarize_todos.append(SummarizeCode(context=ctx, llm=self._llm))
self.summarize_todos.append(SummarizeCode(context=ctx, llm=self.llm))
if self.summarize_todos:
self._rc.todo = self.summarize_todos[0]
self.rc.todo = self.summarize_todos[0]

@property
def todo(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
return self.next_todo_action

async def _write_code_guideline(self):
logger.info("Writing code guideline..")

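Most of the mechanical churn in this commit is the pydantic v1 -> v2 API migration that recurs across these files. A minimal sketch of the rename mapping, using a hypothetical model:

    from pydantic import BaseModel

    class Doc(BaseModel):  # hypothetical stand-in for CodingContext, Message, etc.
        filename: str = "game.py"

    doc = Doc()
    doc.model_dump()         # v2 replacement for v1 doc.dict()
    doc.model_dump_json()    # v2 replacement for v1 doc.json()
    Doc.model_json_schema()  # v2 replacement for v1 Doc.schema()
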
@@ -69,8 +69,8 @@ class InvoiceOCRAssistant(Role):
Returns:
A message containing the result of the action.
"""
msg = self._rc.memory.get(k=1)[0]
todo = self._rc.todo
msg = self.rc.memory.get(k=1)[0]
todo = self.rc.todo
if isinstance(todo, InvoiceOCR):
self.origin_query = msg.content
invoice_path: InvoicePath = msg.instruct_content

@@ -87,11 +87,11 @@ class InvoiceOCRAssistant(Role):
else:
self._init_actions([GenerateTable])

self._rc.todo = None
self.rc.todo = None
content = INVOICE_OCR_SUCCESS
resp = OCRResults(ocr_result=json.dumps(resp))
msg = Message(content=content, instruct_content=resp)
self._rc.memory.add(msg)
self.rc.memory.add(msg)
return await super().react()
elif isinstance(todo, GenerateTable):
ocr_results: OCRResults = msg.instruct_content

@@ -108,5 +108,5 @@ class InvoiceOCRAssistant(Role):
resp = ReplyData(content=resp)

msg = Message(content=content, instruct_content=resp)
self._rc.memory.add(msg)
self.rc.memory.add(msg)
return msg

@@ -7,11 +7,11 @@
@Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135.
"""


from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.config import CONFIG
from metagpt.roles.role import Role
from metagpt.utils.common import any_to_name


class ProductManager(Role):

@@ -29,20 +29,29 @@ class ProductManager(Role):
profile: str = "Product Manager"
goal: str = "efficiently create a successful product that meets market demands and user expectations"
constraints: str = "utilize the same language as the user requirements for seamless communication"
todo_action: str = ""

def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

self._init_actions([PrepareDocuments, WritePRD])
self._watch([UserRequirement, PrepareDocuments])
self.todo_action = any_to_name(PrepareDocuments)

async def _think(self) -> None:
async def _think(self) -> bool:
"""Decide what to do"""
if CONFIG.git_repo:
if CONFIG.git_repo and not CONFIG.git_reinit:
self._set_state(1)
else:
self._set_state(0)
return self._rc.todo
CONFIG.git_reinit = False
self.todo_action = any_to_name(WritePRD)
return bool(self.rc.todo)

async def _observe(self, ignore_memory=False) -> int:
return await super()._observe(ignore_memory=True)

@property
def todo(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
return self.todo_action

@@ -69,7 +69,7 @@ class QaEngineer(Role):
)
logger.info(f"Writing {test_doc.filename}..")
context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc)
context = await WriteTest(context=context, llm=self._llm).run()
context = await WriteTest(context=context, llm=self.llm).run()
await tests_file_repo.save(
filename=context.test_doc.filename,
content=context.test_doc.content,

@@ -86,7 +86,7 @@ class QaEngineer(Role):
)
self.publish_message(
Message(
content=run_code_context.json(),
content=run_code_context.model_dump_json(),
role=self.profile,
cause_by=WriteTest,
sent_from=self,

@@ -106,11 +106,11 @@ class QaEngineer(Role):
return
run_code_context.code = src_doc.content
run_code_context.test_code = test_doc.content
result = await RunCode(context=run_code_context, llm=self._llm).run()
result = await RunCode(context=run_code_context, llm=self.llm).run()
run_code_context.output_filename = run_code_context.test_filename + ".json"
await CONFIG.git_repo.new_file_repository(TEST_OUTPUTS_FILE_REPO).save(
filename=run_code_context.output_filename,
content=result.json(),
content=result.model_dump_json(),
dependencies={src_doc.root_relative_path, test_doc.root_relative_path},
)
run_code_context.code = None

@@ -120,7 +120,7 @@ class QaEngineer(Role):
mappings = {"Engineer": "Alex", "QaEngineer": "Edward"}
self.publish_message(
Message(
content=run_code_context.json(),
content=run_code_context.model_dump_json(),
role=self.profile,
cause_by=RunCode,
sent_from=self,

@@ -130,14 +130,14 @@ class QaEngineer(Role):

async def _debug_error(self, msg):
run_code_context = RunCodeContext.loads(msg.content)
code = await DebugError(context=run_code_context, llm=self._llm).run()
code = await DebugError(context=run_code_context, llm=self.llm).run()
await FileRepository.save_file(
filename=run_code_context.test_filename, content=code, relative_path=TEST_CODES_FILE_REPO
)
run_code_context.output = None
self.publish_message(
Message(
content=run_code_context.json(),
content=run_code_context.model_dump_json(),
role=self.profile,
cause_by=DebugError,
sent_from=self,

@@ -159,7 +159,7 @@ class QaEngineer(Role):
code_filters = any_to_str_set({SummarizeCode})
test_filters = any_to_str_set({WriteTest, DebugError})
run_filters = any_to_str_set({RunCode})
for msg in self._rc.news:
for msg in self.rc.news:
# Decide what to do based on observed msg type, currently defined by human,
# might potentially be moved to _think, that is, let the agent decide for itself
if msg.cause_by in code_filters:

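A quick illustration of the `any_to_str_set` filters used above. Since RFC 116, `cause_by` carries a dotted class path as a plain string, so routing reduces to set membership (the exact module path below is an assumption):

    from metagpt.actions.summarize_code import SummarizeCode  # module path assumed
    from metagpt.utils.common import any_to_str_set

    code_filters = any_to_str_set({SummarizeCode})
    print(code_filters)  # e.g. {"metagpt.actions.summarize_code.SummarizeCode"}
    # routing then looks like: if msg.cause_by in code_filters: ...
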
@@ -1,5 +1,6 @@
#!/usr/bin/env python
"""
@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false indicates that further reasoning cannot continue.
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, change the data type of
the `cause_by` value in the `Message` to a string to support the new message distribution feature.
"""

@@ -40,10 +41,21 @@ class Researcher(Role):
if self.language not in ("en-us", "zh-cn"):
logger.warning(f"The language `{self.language}` has not been tested, it may not work.")

async def _think(self) -> bool:
if self.rc.todo is None:
self._set_state(0)
return True

if self.rc.state + 1 < len(self.states):
self._set_state(self.rc.state + 1)
else:
self.rc.todo = None
return False

async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
todo = self._rc.todo
msg = self._rc.memory.get(k=1)[0]
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
todo = self.rc.todo
msg = self.rc.memory.get(k=1)[0]
if isinstance(msg.instruct_content, Report):
instruct_content = msg.instruct_content
topic = instruct_content.topic

@@ -67,14 +79,14 @@ class Researcher(Role):
else:
summaries = instruct_content.summaries
summary_text = "\n---\n".join(f"url: {url}\nsummary: {summary}" for (url, summary) in summaries)
content = await self._rc.todo.run(topic, summary_text, system_text=research_system_text)
content = await self.rc.todo.run(topic, summary_text, system_text=research_system_text)
ret = Message(
content="",
instruct_content=Report(topic=topic, content=content),
role=self.profile,
cause_by=self._rc.todo,
cause_by=self.rc.todo,
)
self._rc.memory.add(ret)
self.rc.memory.add(ret)
return ret

def research_system_text(self, topic, current_task: Action) -> str:

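For reference, the join in `_act` above emits one `url:`/`summary:` pair per block, separated by `---` lines; a tiny self-contained demonstration (sample data invented):

    summaries = [("https://a.example", "first summary"), ("https://b.example", "second summary")]
    summary_text = "\n---\n".join(f"url: {url}\nsummary: {summary}" for (url, summary) in summaries)
    print(summary_text)
    # url: https://a.example
    # summary: first summary
    # ---
    # url: https://b.example
    # summary: second summary
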
@@ -4,13 +4,14 @@
@Time : 2023/5/11 14:42
@Author : alexanderwu
@File : role.py
@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false indicates that further reasoning cannot continue.
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116:
1. Merge the `recv` functionality into the `_observe` function. Future message reading operations will be
consolidated within the `_observe` function.
2. Standardize the message filtering for string label matching. Role objects can access the message labels
they've subscribed to through the `subscribed_tags` property.
3. Move the message receive buffer from the global variable `self._rc.env.memory` to the role's private variable
`self._rc.msg_buffer` for easier message identification and asynchronous appending of messages.
3. Move the message receive buffer from the global variable `self.rc.env.memory` to the role's private variable
`self.rc.msg_buffer` for easier message identification and asynchronous appending of messages.
4. Standardize the way messages are passed: `publish_message` sends messages out, while `put_message` places
messages into the Role object's private message receive buffer. There are no other message transmit methods.
5. Standardize the parameters for the `run` function: the `test_message` parameter is used for testing purposes

@@ -23,21 +24,21 @@ from __future__ import annotations

from enum import Enum
from pathlib import Path
from typing import Any, Iterable, Set, Type
from typing import Any, Iterable, Optional, Set, Type

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator

from metagpt.actions import Action, ActionOutput
from metagpt.actions.action import action_subclass_registry
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.const import SERDESER_PATH
from metagpt.llm import LLM, HumanProvider
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.schema import Message, MessageQueue
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message, MessageQueue, SerializationMixin
from metagpt.utils.common import (
any_to_name,
any_to_str,
import_class,
read_json_file,

@@ -90,8 +91,10 @@ class RoleReactMode(str, Enum):
class RoleContext(BaseModel):
"""Role Runtime Context"""

model_config = ConfigDict(arbitrary_types_allowed=True)

# # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison`
env: "Environment" = Field(default=None, exclude=True)
env: "Environment" = Field(default=None, exclude=True) # # avoid circular import
# TODO judge if ser&deser
msg_buffer: MessageQueue = Field(
default_factory=MessageQueue, exclude=True

@@ -107,9 +110,6 @@ class RoleContext(BaseModel):
) # see `Role._set_react_mode` for definitions of the following two attributes
max_react_loop: int = 1

class Config:
arbitrary_types_allowed = True

def check(self, role_id: str):
# if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory:
# self.long_term_memory.recover_memory(role_id, self)

@@ -118,7 +118,7 @@ class RoleContext(BaseModel):

@property
def important_memory(self) -> list[Message]:
"""Get the information corresponding to the watched actions"""
"""Retrieve information corresponding to the attention action."""
return self.memory.get_by_actions(self.watch)

@property

@@ -126,12 +126,11 @@ class RoleContext(BaseModel):
return self.memory.get()


role_subclass_registry = {}


class Role(BaseModel):
class Role(SerializationMixin, is_polymorphic_base=True):
"""Role/Agent"""

model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])

name: str = ""
profile: str = ""
goal: str = ""

@@ -139,84 +138,40 @@ class Role(BaseModel):
desc: str = ""
is_human: bool = False

_llm: BaseGPTAPI = Field(default_factory=LLM) # Each role has its own LLM, use different system message
_role_id: str = ""
_states: list[str] = []
_actions: list[Action] = []
_rc: RoleContext = Field(default_factory=RoleContext)
llm: BaseLLM = Field(default_factory=LLM, exclude=True) # Each role has its own LLM, use different system message
role_id: str = ""
states: list[str] = []
actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
rc: RoleContext = Field(default_factory=RoleContext)
subscription: set[str] = set()

# builtin variables
recovered: bool = False # to tag if a recovered role
latest_observed_msg: Message = None # record the latest observed message when interrupted
builtin_class_name: str = ""

_private_attributes = {
"_llm": None,
"_role_id": _role_id,
"_states": [],
"_actions": [],
"_rc": RoleContext(),
"_subscription": set(),
}
latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted

__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`

class Config:
arbitrary_types_allowed = True
exclude = ["_llm"]
@model_validator(mode="after")
def check_subscription(self):
if not self.subscription:
self.subscription = {any_to_str(self), self.name} if self.name else {any_to_str(self)}
return self

def __init__(self, **kwargs: Any):
for index in range(len(kwargs.get("_actions", []))):
current_action = kwargs["_actions"][index]
if isinstance(current_action, dict):
item_class_name = current_action.get("builtin_class_name", None)
for name, subclass in action_subclass_registry.items():
registery_class_name = subclass.__fields__["builtin_class_name"].default
if item_class_name == registery_class_name:
current_action = subclass(**current_action)
break
kwargs["_actions"][index] = current_action
def __init__(self, **data: Any):
# --- avoid PydanticUndefinedAnnotation name 'Environment' is not defined #
from metagpt.environment import Environment

super().__init__(**kwargs)
Environment
# ------
Role.model_rebuild()
super().__init__(**data)

# On initializing private variables, see https://github.com/pydantic/pydantic/issues/655
self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider()
self._private_attributes["_role_id"] = str(self._setting)
self.subscription = {any_to_str(self), self.name} if self.name else {any_to_str(self)}

for key in self._private_attributes.keys():
if key in kwargs:
object.__setattr__(self, key, kwargs[key])
if key == "_rc":
_rc = RoleContext(**kwargs["_rc"])
object.__setattr__(self, "_rc", _rc)
else:
if key == "_rc":
# # Warning, if use self._private_attributes["_rc"],
# # self._rc will be a shared object between roles, so init one or reset it inside `_reset`
object.__setattr__(self, key, RoleContext())
else:
object.__setattr__(self, key, self._private_attributes[key])

self._llm.system_prompt = self._get_prefix()

# deserialize child classes dynamically for inherited `role`
object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
self.__fields__["builtin_class_name"].default = self.__class__.__name__

if "actions" in kwargs:
self._init_actions(kwargs["actions"])

self._watch(kwargs.get("watch") or [UserRequirement])

def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
role_subclass_registry[cls.__name__] = cls
self.llm.system_prompt = self._get_prefix()
self._watch(data.get("watch") or [UserRequirement])

def _reset(self):
object.__setattr__(self, "_states", [])
object.__setattr__(self, "_actions", [])
self.states = []
self.actions = []

@property
def _setting(self):

@@ -229,12 +184,12 @@ class Role(BaseModel):
else stg_path
)

role_info = self.dict(exclude={"_rc": {"memory": True, "msg_buffer": True}, "_llm": True})
role_info = self.model_dump(exclude={"rc": {"memory": True, "msg_buffer": True}, "llm": True})
role_info.update({"role_class": self.__class__.__name__, "module_name": self.__module__})
role_info_path = stg_path.joinpath("role_info.json")
write_json_file(role_info_path, role_info)

self._rc.memory.serialize(stg_path) # serialize role's memory alone
self.rc.memory.serialize(stg_path) # serialize role's memory alone

@classmethod
def deserialize(cls, stg_path: Path) -> "Role":

@@ -258,13 +213,13 @@ class Role(BaseModel):
action.set_prefix(self._get_prefix())

def refresh_system_message(self):
self._llm.system_prompt = self._get_prefix()
self.llm.system_prompt = self._get_prefix()

def set_recovered(self, recovered: bool = False):
self.recovered = recovered

def set_memory(self, memory: Memory):
self._rc.memory = memory
self.rc.memory = memory

def init_actions(self, actions):
self._init_actions(actions)

@@ -274,7 +229,7 @@ class Role(BaseModel):
for idx, action in enumerate(actions):
if not isinstance(action, Action):
## default initialization
i = action(name="", llm=self._llm)
i = action(name="", llm=self.llm)
else:
if self.is_human and not isinstance(action.llm, HumanProvider):
logger.warning(

@@ -283,10 +238,9 @@ class Role(BaseModel):
f"try passing in Action classes instead of initialized instances"
)
i = action
# i.set_env(self._rc.env)
self._init_action_system_message(i)
self._actions.append(i)
self._states.append(f"{idx}. {action}")
self.actions.append(i)
self.states.append(f"{idx}. {action}")

def _set_react_mode(self, react_mode: str, max_react_loop: int = 1):
"""Set strategy of the Role reacting to observed Message. Variation lies in how

@@ -305,17 +259,20 @@ class Role(BaseModel):
Defaults to 1, i.e. _think -> _act (-> return result and end)
"""
assert react_mode in RoleReactMode.values(), f"react_mode must be one of {RoleReactMode.values()}"
self._rc.react_mode = react_mode
self.rc.react_mode = react_mode
if react_mode == RoleReactMode.REACT:
self._rc.max_react_loop = max_react_loop
self.rc.max_react_loop = max_react_loop

def _watch(self, actions: Iterable[Type[Action]] | Iterable[Action]):
"""Watch Actions of interest. Role will select Messages caused by these Actions from its personal message
buffer during _observe.
"""
self._rc.watch = {any_to_str(t) for t in actions}
self.rc.watch = {any_to_str(t) for t in actions}
# check RoleContext after adding watch actions
self._rc.check(self._role_id)
self.rc.check(self.role_id)

def is_watch(self, caused_by: str):
return caused_by in self.rc.watch

def subscribe(self, tags: Set[str]):
"""Used to receive Messages with certain tags from the environment. Message will be put into personal message

@@ -323,23 +280,28 @@ class Role(BaseModel):
or profile.
"""
self.subscription = tags
if self._rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
self._rc.env.set_subscription(self, self.subscription)
if self.rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
self.rc.env.set_subscription(self, self.subscription)

def _set_state(self, state: int):
"""Update the current state."""
self._rc.state = state
logger.debug(f"actions={self._actions}, state={state}")
self._rc.todo = self._actions[self._rc.state] if state >= 0 else None
self.rc.state = state
logger.debug(f"actions={self.actions}, state={state}")
self.rc.todo = self.actions[self.rc.state] if state >= 0 else None

def set_env(self, env: "Environment"):
"""Set the environment in which the role works. The role can talk to the environment and can also receive
messages by observing."""
self._rc.env = env
self.rc.env = env
if env:
env.set_subscription(self, self.subscription)
self.refresh_system_message() # add env message to system message

@property
def action_count(self):
"""Return number of action"""
return len(self.actions)

def _get_prefix(self):
"""Get the role prefix"""
if self.desc:

@@ -350,36 +312,38 @@ class Role(BaseModel):
if self.constraints:
prefix += CONSTRAINT_TEMPLATE.format(**{"constraints": self.constraints})

if self._rc.env and self._rc.env.desc:
other_role_names = ", ".join(self._rc.env.role_names())
env_desc = f"You are in {self._rc.env.desc} with roles({other_role_names})."
if self.rc.env and self.rc.env.desc:
other_role_names = ", ".join(self.rc.env.role_names())
env_desc = f"You are in {self.rc.env.desc} with roles({other_role_names})."
prefix += env_desc
return prefix

async def _think(self) -> None:
"""Think about what to do and decide on the next action"""
if len(self._actions) == 1:
async def _think(self) -> bool:
"""Consider what to do and decide on the next course of action. Return false if nothing can be done."""
if len(self.actions) == 1:
# If there is only one action, then only this one can be performed
self._set_state(0)
return
if self.recovered and self._rc.state >= 0:
self._set_state(self._rc.state) # action to run from recovered state
self.recovered = False # avoid max_react_loop out of work
return

return True

if self.recovered and self.rc.state >= 0:
self._set_state(self.rc.state) # action to run from recovered state
self.set_recovered(False) # avoid max_react_loop out of work
return True

prompt = self._get_prefix()
prompt += STATE_TEMPLATE.format(
history=self._rc.history,
states="\n".join(self._states),
n_states=len(self._states) - 1,
previous_state=self._rc.state,
history=self.rc.history,
states="\n".join(self.states),
n_states=len(self.states) - 1,
previous_state=self.rc.state,
)

next_state = await self._llm.aask(prompt)
next_state = await self.llm.aask(prompt)
next_state = extract_state_value_from_output(next_state)
logger.debug(f"{prompt=}")

if (not next_state.isdigit() and next_state != "-1") or int(next_state) not in range(-1, len(self._states)):
if (not next_state.isdigit() and next_state != "-1") or int(next_state) not in range(-1, len(self.states)):
logger.warning(f"Invalid answer of state, {next_state=}, will be set to -1")
next_state = -1
else:

@@ -387,73 +351,66 @@ class Role(BaseModel):
if next_state == -1:
logger.info(f"End actions with {next_state=}")
self._set_state(next_state)
return True

async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
response = await self._rc.todo.run(self._rc.history)
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
response = await self.rc.todo.run(self.rc.history)
if isinstance(response, (ActionOutput, ActionNode)):
msg = Message(
content=response.content,
instruct_content=response.instruct_content,
role=self._setting,
cause_by=self._rc.todo,
cause_by=self.rc.todo,
sent_from=self,
)
elif isinstance(response, Message):
msg = response
else:
msg = Message(content=response, role=self.profile, cause_by=self._rc.todo, sent_from=self)
self._rc.memory.add(msg)
msg = Message(content=response, role=self.profile, cause_by=self.rc.todo, sent_from=self)
self.rc.memory.add(msg)

return msg

def _find_news(self, observed: list[Message], existed: list[Message]) -> list[Message]:
news = []
# Warning, remove `id` here to make it work for recover
observed_pure = [msg.dict(exclude={"id": True}) for msg in observed]
existed_pure = [msg.dict(exclude={"id": True}) for msg in existed]
for idx, new in enumerate(observed_pure):
if (new["cause_by"] in self._rc.watch or self.name in new["send_to"]) and new not in existed_pure:
news.append(observed[idx])
return news

async def _observe(self, ignore_memory=False) -> int:
"""Prepare new messages for processing from the message buffer and other sources."""
# Read unprocessed messages from the msg buffer.
news = self._rc.msg_buffer.pop_all()
news = []
if self.recovered:
news = [self.latest_observed_msg] if self.latest_observed_msg else []
else:
self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg

if not news:
news = self.rc.msg_buffer.pop_all()
# Store the read messages in your own memory to prevent duplicate processing.
old_messages = [] if ignore_memory else self._rc.memory.get()
self._rc.memory.add_batch(news)
old_messages = [] if ignore_memory else self.rc.memory.get()
self.rc.memory.add_batch(news)
# Filter out messages of interest.
self._rc.news = self._find_news(news, old_messages)
self.rc.news = [
n for n in news if (n.cause_by in self.rc.watch or self.name in n.send_to) and n not in old_messages
]
self.latest_observed_msg = self.rc.news[-1] if self.rc.news else None # record the latest observed msg

# Design Rules:
# If you need to further categorize Message objects, you can do so using the Message.set_meta function.
# msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer.
news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news]
news_text = [f"{i.role}: {i.content[:20]}..." for i in self.rc.news]
if news_text:
logger.debug(f"{self._setting} observed: {news_text}")
return len(self._rc.news)
return len(self.rc.news)

def publish_message(self, msg):
"""If the role belongs to env, then the role's messages will be broadcast to env"""
if not msg:
return
if not self._rc.env:
if not self.rc.env:
# If env does not exist, do not publish the message
return
self._rc.env.publish_message(msg)
self.rc.env.publish_message(msg)

def put_message(self, message):
"""Place the message into the Role object's private message buffer."""
if not message:
return
self._rc.msg_buffer.push(message)
self.rc.msg_buffer.push(message)

async def _react(self) -> Message:
"""Think first, then act, until the Role _think it is time to stop and requires no more todo.

@@ -462,22 +419,22 @@ class Role(BaseModel):
"""
actions_taken = 0
rsp = Message(content="No actions taken yet") # will be overwritten after Role _act
while actions_taken < self._rc.max_react_loop:
while actions_taken < self.rc.max_react_loop:
# think
await self._think()
if self._rc.todo is None:
if self.rc.todo is None:
break
# act
logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}")
logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}")
rsp = await self._act() # does this rsp need to go through publish_message?
actions_taken += 1
return rsp # return output from the last action

async def _act_by_order(self) -> Message:
"""switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ..."""
start_idx = self._rc.state if self._rc.state >= 0 else 0 # action to run from recovered state
rsp = Message(content="No actions taken yet") # return default message if _actions=[]
start_idx = self.rc.state if self.rc.state >= 0 else 0 # action to run from recovered state
rsp = Message(content="No actions taken yet") # return default message if actions=[]
for i in range(start_idx, len(self.states)):
self._set_state(i)
rsp = await self._act()
return rsp # return output from the last action

@@ -489,35 +446,18 @@ class Role(BaseModel):

async def react(self) -> Message:
"""Entry to one of three strategies by which Role reacts to the observed Message"""
if self._rc.react_mode == RoleReactMode.REACT:
if self.rc.react_mode == RoleReactMode.REACT:
rsp = await self._react()
elif self._rc.react_mode == RoleReactMode.BY_ORDER:
elif self.rc.react_mode == RoleReactMode.BY_ORDER:
rsp = await self._act_by_order()
elif self._rc.react_mode == RoleReactMode.PLAN_AND_ACT:
elif self.rc.react_mode == RoleReactMode.PLAN_AND_ACT:
rsp = await self._plan_and_act()
self._set_state(state=-1) # current reaction is complete, reset state to -1 and todo back to None
return rsp

# # Replaced by run()
# def recv(self, message: Message) -> None:
# """add message to history."""
# # self._history += f"\n{message}"
# # self._context = self._history
# if message in self._rc.memory.get():
# return
# self._rc.memory.add(message)

# # Replaced by run()
# async def handle(self, message: Message) -> Message:
# """Receive information and reply with actions"""
# # logger.debug(f"{self.name=}, {self.profile=}, {message.role=}")
# self.recv(message)
#
# return await self._react()

def get_memories(self, k=0) -> list[Message]:
"""A wrapper to return the most recent k memories of this role, return all when k=0"""
return self._rc.memory.get(k=k)
return self.rc.memory.get(k=k)

@role_raise_decorator
async def run(self, with_message=None) -> Message | None:

@@ -542,7 +482,7 @@ class Role(BaseModel):
rsp = await self.react()

# Reset the next action to be taken.
self._rc.todo = None
self.rc.todo = None
# Send the response message to the Environment object to have it relay the message to the subscribers.
self.publish_message(rsp)
return rsp

@@ -550,4 +490,21 @@ class Role(BaseModel):
@property
def is_idle(self) -> bool:
"""If true, all actions have been executed."""
return not self._rc.news and not self._rc.todo and self._rc.msg_buffer.empty()
return not self.rc.news and not self.rc.todo and self.rc.msg_buffer.empty()

async def think(self) -> Action:
"""The exported `think` function"""
await self._think()
return self.rc.todo

async def act(self) -> ActionOutput:
"""The exported `act` function"""
msg = await self._act()
return ActionOutput(content=msg.content, instruct_content=msg.instruct_content)

@property
def todo(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
if self.actions:
return any_to_name(self.actions[0])
return ""

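Putting the renamed pieces together, a hedged sketch of the standardized message flow described in the RFC 116 notes above (role and requirement text are illustrative; an Environment would normally drive this loop):

    # inside an async context:
    role = ProductManager()
    role.put_message(Message(content="Write a CLI snake game"))  # lands in rc.msg_buffer
    rsp = await role.run()  # _observe drains the buffer; react() loops _think -> _act; run() publishes rsp
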
@@ -8,6 +8,8 @@

from typing import Optional

from pydantic import Field

from metagpt.actions import SearchAndSummarize, UserRequirement
from metagpt.document_store.base_store import BaseStore
from metagpt.roles import Role

@@ -15,16 +17,17 @@ from metagpt.tools import SearchEngineType


class Sales(Role):
name: str = "Xiaomei"
profile: str = "Retail sales guide"
desc: str = "I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I "
"will answer questions only based on the information in the knowledge base."
"If I feel that you can't get the answer from the reference material, then I will directly reply that"
" I don't know, and I won't tell you that this is from the knowledge base,"
"but pretend to be what I know. Note that each of my replies will be replied in the tone of a "
"professional guide"
name: str = "John Smith"
profile: str = "Retail Sales Guide"
desc: str = (
"As a Retail Sales Guide, my name is John Smith. I specialize in addressing customer inquiries with "
"expertise and precision. My responses are based solely on the information available in our knowledge"
" base. In instances where your query extends beyond this scope, I'll honestly indicate my inability "
"to provide an answer, rather than speculate or assume. Please note, each of my replies will be "
"delivered with the professionalism and courtesy expected of a seasoned sales guide."
)

store: Optional[BaseStore] = None
store: Optional[BaseStore] = Field(default=None, exclude=True)

def __init__(self, **kwargs):
super().__init__(**kwargs)

@@ -57,19 +57,19 @@ class Searcher(Role):

async def _act_sp(self) -> Message:
"""Performs the search action in a single process."""
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
response = await self._rc.todo.run(self._rc.memory.get(k=0))
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
response = await self.rc.todo.run(self.rc.memory.get(k=0))

if isinstance(response, (ActionOutput, ActionNode)):
msg = Message(
content=response.content,
instruct_content=response.instruct_content,
role=self.profile,
cause_by=self._rc.todo,
cause_by=self.rc.todo,
)
else:
msg = Message(content=response, role=self.profile, cause_by=self._rc.todo)
self._rc.memory.add(msg)
msg = Message(content=response, role=self.profile, cause_by=self.rc.todo)
self.rc.memory.add(msg)
return msg

async def _act(self) -> Message:

@@ -7,19 +7,19 @@
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message filtering.
"""
from typing import Any, Type
from typing import Any, Callable, Union

from pydantic import Field
from semantic_kernel import Kernel
from semantic_kernel.planning import SequentialPlanner
from semantic_kernel.planning.action_planner.action_planner import ActionPlanner
from semantic_kernel.planning.basic_planner import BasicPlanner
from semantic_kernel.planning.basic_planner import BasicPlanner, Plan

from metagpt.actions import UserRequirement
from metagpt.actions.execute_task import ExecuteTask
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.make_sk_kernel import make_sk_kernel

@@ -41,17 +41,17 @@ class SkAgent(Role):
goal: str = "Execute task based on passed in task description"
constraints: str = ""

plan: Any = None
plan: Plan = Field(default=None, exclude=True)
planner_cls: Any = None
planner: Any = None
llm: BaseGPTAPI = Field(default_factory=LLM)
planner: Union[BasicPlanner, SequentialPlanner, ActionPlanner] = None
llm: BaseLLM = Field(default_factory=LLM)
kernel: Kernel = Field(default_factory=Kernel)
import_semantic_skill_from_directory: Type[Kernel.import_semantic_skill_from_directory] = None
import_skill: Type[Kernel.import_skill] = None
import_semantic_skill_from_directory: Callable = Field(default=None, exclude=True)
import_skill: Callable = Field(default=None, exclude=True)

def __init__(self, **kwargs) -> None:
def __init__(self, **data: Any) -> None:
"""Initializes the SkAgent role with given attributes."""
super().__init__(**kwargs)
super().__init__(**data)
self._init_actions([ExecuteTask()])
self._watch([UserRequirement])
self.kernel = make_sk_kernel()

@@ -71,10 +71,10 @@ class SkAgent(Role):
self._set_state(0)
# oddly enough, the interface is inconsistent between planners
if isinstance(self.planner, BasicPlanner):
self.plan = await self.planner.create_plan_async(self._rc.important_memory[-1].content, self.kernel)
self.plan = await self.planner.create_plan_async(self.rc.important_memory[-1].content, self.kernel)
logger.info(self.plan.generated_plan)
elif any(isinstance(self.planner, cls) for cls in [SequentialPlanner, ActionPlanner]):
self.plan = await self.planner.create_plan_async(self._rc.important_memory[-1].content)
self.plan = await self.planner.create_plan_async(self.rc.important_memory[-1].content)

async def _act(self) -> Message:
# oddly enough, the interface is inconsistent between planners

@@ -85,6 +85,6 @@ class SkAgent(Role):
result = (await self.plan.invoke_async()).result
logger.info(result)

msg = Message(content=result, role=self.profile, cause_by=self._rc.todo)
self._rc.memory.add(msg)
msg = Message(content=result, role=self.profile, cause_by=self.rc.todo)
self.rc.memory.add(msg)
return msg

118 metagpt/roles/teacher.py Normal file

@@ -0,0 +1,118 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/7/27
@Author : mashenquan
@File : teacher.py
@Desc : Used by Agent Store
@Modified By: mashenquan, 2023/8/22. A definition has been provided for the return value of _think: returning false indicates that further reasoning cannot continue.

"""

import re

import aiofiles

from metagpt.actions import UserRequirement
from metagpt.actions.write_teaching_plan import TeachingPlanBlock, WriteTeachingPlanPart
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import any_to_str


class Teacher(Role):
"""Support configurable teacher roles,
with native and teaching languages being replaceable through configurations."""

name: str = "Lily"
profile: str = "{teaching_language} Teacher"
goal: str = "writing a {language} teaching plan part by part"
constraints: str = "writing in {language}"
desc: str = ""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = WriteTeachingPlanPart.format_value(self.name)
self.profile = WriteTeachingPlanPart.format_value(self.profile)
self.goal = WriteTeachingPlanPart.format_value(self.goal)
self.constraints = WriteTeachingPlanPart.format_value(self.constraints)
self.desc = WriteTeachingPlanPart.format_value(self.desc)

async def _think(self) -> bool:
"""Everything will be done part by part."""
if not self.actions:
if not self.rc.news or self.rc.news[0].cause_by != any_to_str(UserRequirement):
raise ValueError("Lesson content invalid.")
actions = []
print(TeachingPlanBlock.TOPICS)
for topic in TeachingPlanBlock.TOPICS:
act = WriteTeachingPlanPart(context=self.rc.news[0].content, topic=topic, llm=self.llm)
actions.append(act)
self._init_actions(actions)

if self.rc.todo is None:
self._set_state(0)
return True

if self.rc.state + 1 < len(self.states):
self._set_state(self.rc.state + 1)
return True

self.rc.todo = None
return False

async def _react(self) -> Message:
ret = Message(content="")
while True:
await self._think()
if self.rc.todo is None:
break
logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}")
msg = await self._act()
if ret.content != "":
ret.content += "\n\n\n"
ret.content += msg.content
logger.info(ret.content)
await self.save(ret.content)
return ret

async def save(self, content):
"""Save teaching plan"""
filename = Teacher.new_file_name(self.course_title)
pathname = CONFIG.workspace_path / "teaching_plan"
pathname.mkdir(exist_ok=True)
pathname = pathname / filename
try:
async with aiofiles.open(str(pathname), mode="w", encoding="utf-8") as writer:
await writer.write(content)
except Exception as e:
logger.error(f"Save failed:{e}")
logger.info(f"Save to:{pathname}")

@staticmethod
def new_file_name(lesson_title, ext=".md"):
"""Create a related file name based on `lesson_title` and `ext`."""
# Define the special characters that need to be replaced.
illegal_chars = r'[#@$%!*&\\/:*?"<>|\n\t \']'
# Replace the special characters with underscores.
filename = re.sub(illegal_chars, "_", lesson_title) + ext
return re.sub(r"_+", "_", filename)

@property
def course_title(self):
"""Return course title of teaching plan"""
default_title = "teaching_plan"
for act in self.actions:
if act.topic != TeachingPlanBlock.COURSE_TITLE:
continue
if act.rsp is None:
return default_title
title = act.rsp.lstrip("# \n")
if "\n" in title:
ix = title.index("\n")
title = title[0:ix]
return title

return default_title

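A self-contained check of `new_file_name` behavior (expected output derived from the two regexes above):

    import re

    illegal_chars = r'[#@$%!*&\\/:*?"<>|\n\t \']'
    filename = re.sub(illegal_chars, "_", "Lesson 1: Hello, World!") + ".md"
    print(re.sub(r"_+", "_", filename))  # -> Lesson_1_Hello,_World_.md
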
@@ -34,9 +34,9 @@ class TutorialAssistant(Role):
constraints: str = "Strictly follow Markdown's syntax, with neat and standardized layout"
language: str = "Chinese"

topic = ""
main_title = ""
total_content = ""
topic: str = ""
main_title: str = ""
total_content: str = ""

def __init__(self, **kwargs):
super().__init__(**kwargs)

@@ -71,9 +71,9 @@ class TutorialAssistant(Role):
Returns:
A message containing the result of the action.
"""
todo = self._rc.todo
todo = self.rc.todo
if type(todo) is WriteDirectory:
msg = self._rc.memory.get(k=1)[0]
msg = self.rc.memory.get(k=1)[0]
self.topic = msg.content
resp = await todo.run(topic=self.topic)
logger.info(resp)

@@ -90,4 +90,5 @@ class TutorialAssistant(Role):
msg = await super().react()
root_path = TUTORIAL_PATH / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
await File.write(root_path, f"{self.main_title}.md", self.total_content.encode("utf-8"))
msg.content = str(root_path / f"{self.main_title}.md")
return msg

@@ -23,9 +23,17 @@ from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Type, TypedDict, TypeVar
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union

from pydantic import BaseModel, Field
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
field_serializer,
field_validator,
)
from pydantic_core import core_schema

from metagpt.config import CONFIG
from metagpt.const import (

@@ -46,7 +54,74 @@ from metagpt.utils.serialize import (
)


class RawMessage(TypedDict):
class SerializationMixin(BaseModel):
"""
Polymorphic subclass serialization / deserialization mixin.
- First of all, we need to know that pydantic is not designed for polymorphism.
- If Engineer is a subclass of Role, it would be serialized as Role. If we want to serialize it as Engineer, we need
to add the class name to Engineer. So Engineer needs to inherit SerializationMixin.

More details:
- https://docs.pydantic.dev/latest/concepts/serialization/
- https://github.com/pydantic/pydantic/discussions/7008 discusses how to avoid `__get_pydantic_core_schema__`
"""

__is_polymorphic_base = False
__subclasses_map__ = {}

@classmethod
def __get_pydantic_core_schema__(
cls, source: type["SerializationMixin"], handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
schema = handler(source)
og_schema_ref = schema["ref"]
schema["ref"] += ":mixin"

return core_schema.no_info_before_validator_function(
cls.__deserialize_with_real_type__,
schema=schema,
ref=og_schema_ref,
serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__),
)

@classmethod
def __serialize_add_class_type__(
cls,
value,
handler: core_schema.SerializerFunctionWrapHandler,
) -> Any:
ret = handler(value)
if not len(cls.__subclasses__()):
# only leaf subclasses add `__module_class_name`
ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}"
return ret

@classmethod
def __deserialize_with_real_type__(cls, value: Any):
if not isinstance(value, dict):
return value

if not cls.__is_polymorphic_base or (len(cls.__subclasses__()) and "__module_class_name" not in value):
# add the right condition to init the base class directly, like Action()
return value
module_class_name = value.get("__module_class_name", None)
if module_class_name is None:
raise ValueError("Missing field: __module_class_name")

class_type = cls.__subclasses_map__.get(module_class_name, None)

if class_type is None:
raise TypeError(f"Trying to instantiate {module_class_name}, which is not defined yet.")

return class_type(**value)

def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
cls.__is_polymorphic_base = is_polymorphic_base
cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
super().__init_subclass__(**kwargs)

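A minimal sketch of what the mixin buys (hypothetical classes; pydantic v2 assumed). Serializing a subclass embeds `__module_class_name`, so validating against the polymorphic base reconstructs the concrete subclass:

    from metagpt.schema import SerializationMixin

    class Animal(SerializationMixin, is_polymorphic_base=True):
        name: str = ""

    class Dog(Animal):
        pass

    data = Dog(name="rex").model_dump()   # includes "__module_class_name": "...Dog"
    restored = Animal.model_validate(data)
    assert isinstance(restored, Dog)
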
class SimpleMessage(BaseModel):
content: str
role: str


@@ -102,33 +177,64 @@ class Documents(BaseModel):
class Message(BaseModel):
"""list[<role>: <content>]"""

id: str # According to Section 2.2.3.1.1 of RFC 135
id: str = Field(default="", validate_default=True) # According to Section 2.2.3.1.1 of RFC 135
content: str
instruct_content: BaseModel = None
instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True)
role: str = "user" # system / user / assistant
cause_by: str = ""
sent_from: str = ""
send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL})
cause_by: str = Field(default="", validate_default=True)
sent_from: str = Field(default="", validate_default=True)
send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)

def __init__(self, content: str = "", **kwargs):
ic = kwargs.get("instruct_content", None)
@field_validator("id", mode="before")
@classmethod
def check_id(cls, id: str) -> str:
return id if id else uuid.uuid4().hex

@field_validator("instruct_content", mode="before")
@classmethod
def check_instruct_content(cls, ic: Any) -> BaseModel:
if ic and not isinstance(ic, BaseModel) and "class" in ic:
# compatible with custom-defined ActionOutput
mapping = actionoutput_str_to_mapping(ic["mapping"])

actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
ic_new = ic_obj(**ic["value"])
kwargs["instruct_content"] = ic_new
ic = ic_obj(**ic["value"])
return ic

kwargs["id"] = kwargs.get("id", uuid.uuid4().hex)
kwargs["content"] = kwargs.get("content", content)
kwargs["cause_by"] = any_to_str(
kwargs.get("cause_by", import_class("UserRequirement", "metagpt.actions.add_requirement"))
)
kwargs["sent_from"] = any_to_str(kwargs.get("sent_from", ""))
kwargs["send_to"] = any_to_str_set(kwargs.get("send_to", {MESSAGE_ROUTE_TO_ALL}))
super(Message, self).__init__(**kwargs)
@field_validator("cause_by", mode="before")
@classmethod
def check_cause_by(cls, cause_by: Any) -> str:
return any_to_str(cause_by if cause_by else import_class("UserRequirement", "metagpt.actions.add_requirement"))

@field_validator("sent_from", mode="before")
@classmethod
def check_sent_from(cls, sent_from: Any) -> str:
return any_to_str(sent_from if sent_from else "")

@field_validator("send_to", mode="before")
@classmethod
def check_send_to(cls, send_to: Any) -> set:
return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})

@field_serializer("instruct_content", mode="plain")
def ser_instruct_content(self, ic: BaseModel) -> Union[str, None]:
ic_dict = None
if ic:
# compatible with custom-defined ActionOutput
schema = ic.model_json_schema()
# `Documents` contain definitions
if "definitions" not in schema:
# TODO refine with nested BaseModel
mapping = actionoutout_schema_to_mapping(schema)
mapping = actionoutput_mapping_to_str(mapping)

ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
return ic_dict

def __init__(self, content: str = "", **data: Any):
data["content"] = data.get("content", content)
super().__init__(**data)

def __setattr__(self, key, val):
"""Override `@property.setter`, convert non-string parameters into string parameters."""

@ -142,28 +248,11 @@ class Message(BaseModel):
|
|||
             new_val = val
         super().__setattr__(key, new_val)

-    def dict(self, *args, **kwargs) -> "DictStrAny":
-        """overwrite the `dict` to dump dynamic pydantic model"""
-        obj_dict = super(Message, self).dict(*args, **kwargs)
-        ic = self.instruct_content
-        if ic:
-            # compatible with custom-defined ActionOutput
-            schema = ic.schema()
-            # `Documents` contain definitions
-            if "definitions" not in schema:
-                # TODO refine with nested BaseModel
-                mapping = actionoutout_schema_to_mapping(schema)
-                mapping = actionoutput_mapping_to_str(mapping)
-
-            obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
-        return obj_dict
-
     def __str__(self):
         # prefix = '-'.join([self.role, str(self.cause_by)])
         if self.instruct_content:
-            return f"{self.role}: {self.instruct_content.dict()}"
-        else:
-            return f"{self.role}: {self.content}"
+            return f"{self.role}: {self.instruct_content.model_dump()}"
+        return f"{self.role}: {self.content}"

     def __repr__(self):
         return self.__str__()

@@ -174,14 +263,25 @@ class Message(BaseModel):

     def dump(self) -> str:
         """Convert the object to json string"""
-        return self.json(exclude_none=True)
+        return self.model_dump_json(exclude_none=True, warnings=False)

     @staticmethod
     @handle_exception(exception_type=JSONDecodeError, default_return=None)
     def load(val):
         """Convert the json string to object."""
-        i = json.loads(val)
-        return Message(**i)
+
+        try:
+            m = json.loads(val)
+            id = m.get("id")
+            if "id" in m:
+                del m["id"]
+            msg = Message(**m)
+            if id:
+                msg.id = id
+            return msg
+        except JSONDecodeError as err:
+            logger.error(f"parse json failed: {val}, error:{err}")
+        return None

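The dump/load pair above is meant to round-trip, including the generated `id`. A quick usage sketch (not part of the diff; assumes a working MetaGPT checkout so `metagpt.schema` imports cleanly):

```python
from metagpt.schema import Message

msg = Message(content="build a 2048 game", role="user")
restored = Message.load(msg.dump())

assert restored is not None
assert restored.id == msg.id  # load() strips the id, then re-applies it
assert restored.content == msg.content
```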

 class UserMessage(Message):

@@ -214,19 +314,9 @@ class AIMessage(Message):
 class MessageQueue(BaseModel):
     """Message queue which supports asynchronous updates."""

-    _queue: Queue = Field(default_factory=Queue)
+    model_config = ConfigDict(arbitrary_types_allowed=True)

-    _private_attributes = {"_queue": Queue()}
-
-    class Config:
-        arbitrary_types_allowed = True
-
-    def __init__(self, **kwargs: Any):
-        for key in self._private_attributes.keys():
-            if key in kwargs:
-                object.__setattr__(self, key, kwargs[key])
-            else:
-                object.__setattr__(self, key, Queue())
+    _queue: Queue = PrivateAttr(default_factory=Queue)

     def pop(self) -> Message | None:
         """Pop one message from the queue."""

@@ -262,16 +352,21 @@ class MessageQueue(BaseModel):
return "[]"
|
||||
|
||||
lst = []
|
||||
msgs = []
|
||||
try:
|
||||
while True:
|
||||
item = await wait_for(self._queue.get(), timeout=1.0)
|
||||
if item is None:
|
||||
break
|
||||
lst.append(item.dict(exclude_none=True))
|
||||
msgs.append(item)
|
||||
lst.append(item.dump())
|
||||
self._queue.task_done()
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("Queue is empty, exiting...")
|
||||
return json.dumps(lst)
|
||||
finally:
|
||||
for m in msgs:
|
||||
self._queue.put_nowait(m)
|
||||
return json.dumps(lst, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def load(data) -> "MessageQueue":
|
||||
|
|
@ -280,7 +375,7 @@ class MessageQueue(BaseModel):
|
|||
         try:
             lst = json.loads(data)
             for i in lst:
-                msg = Message(**i)
+                msg = Message.load(i)
                 queue.push(msg)
         except JSONDecodeError as e:
             logger.warning(f"JSON load failed: {data}, error:{e}")

@@ -302,28 +397,28 @@ class BaseContext(BaseModel, ABC):
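Since `dump()` now serializes each queued item with `Message.dump()` and `load()` feeds each element through `Message.load()`, the two stay symmetric. A sketch of the round trip (illustrative only; the async `dump()` drains an `asyncio.Queue` and re-fills it in its `finally` block):

```python
import asyncio

from metagpt.schema import Message, MessageQueue


async def main():
    mq = MessageQueue()
    mq.push(Message(content="hello"))
    data = await mq.dump()           # JSON array of Message.dump() strings
    restored = MessageQueue.load(data)
    assert restored.pop().content == "hello"


asyncio.run(main())
```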

 class CodingContext(BaseContext):
     filename: str
-    design_doc: Optional[Document]
-    task_doc: Optional[Document]
-    code_doc: Optional[Document]
+    design_doc: Optional[Document] = None
+    task_doc: Optional[Document] = None
+    code_doc: Optional[Document] = None


 class TestingContext(BaseContext):
     filename: str
     code_doc: Document
-    test_doc: Optional[Document]
+    test_doc: Optional[Document] = None


 class RunCodeContext(BaseContext):
     mode: str = "script"
-    code: Optional[str]
+    code: Optional[str] = None
     code_filename: str = ""
-    test_code: Optional[str]
+    test_code: Optional[str] = None
     test_filename: str = ""
     command: List[str] = Field(default_factory=list)
     working_directory: str = ""
     additional_python_paths: List[str] = Field(default_factory=list)
-    output_filename: Optional[str]
-    output: Optional[str]
+    output_filename: Optional[str] = None
+    output: Optional[str] = None


 class RunCodeResult(BaseContext):

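The `= None` additions above are not cosmetic: in pydantic v2 an `Optional[X]` annotation alone makes a field required (it may hold None, but it must be supplied), whereas v1 implied a None default. A two-line illustration:

```python
from typing import Optional

from pydantic import BaseModel


class Doc(BaseModel):
    # v2: without "= None" this field would be required even though it is Optional
    design_doc: Optional[str] = None


Doc()  # ok only because of the explicit default
```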
4 metagpt/strategy/__init__.py Normal file

@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date    : 12/23/2023 4:51 PM
# @Author  : stellahong (stellahong@fuzhi.ai)
# @Desc    :
109 metagpt/strategy/base.py Normal file

@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
# @Date    : 12/25/2023 9:16 PM
# @Author  : stellahong (stellahong@fuzhi.ai)
# @Desc    :
from abc import ABC
from typing import List

from anytree import Node, RenderTree
from pydantic import BaseModel


class BaseParser(BaseModel, ABC):
    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def propose(self, current_state: str, **kwargs) -> str:
        raise NotImplementedError

    def sample(self, current_state: str, **kwargs) -> str:
        raise NotImplementedError

    def value(self, input: str, **kwargs) -> str:
        raise NotImplementedError


class BaseEvaluator(BaseModel, ABC):
    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def status_verify(self, *args, **kwargs):
        raise NotImplementedError


class ThoughtNode(Node):
    """A node representing a thought in the thought tree."""

    name: str = ""
    value: int = 0
    id: int = 0
    valid_status: bool = True

    def update_value(self, value) -> None:
        """Update the value of the thought node."""
        self.value = value

    def update_valid_status(self, status) -> None:
        """Update the validity status of the thought node."""
        self.valid_status = status


class ThoughtTree(RenderTree):
    """A tree structure to represent thoughts."""

    @property
    def all_nodes(self) -> List[ThoughtNode]:
        """
        Get a list of all nodes in the thought tree.

        Returns:
            List[ThoughtNode]: A list containing all nodes in the thought tree.
        """
        all_nodes = [node for _, _, node in self]
        return all_nodes

    def update_node(self, thought: List[dict] = [], current_node: ThoughtNode = None) -> List[ThoughtNode]:
        """
        Update the tree with new thoughts.

        Args:
            thought (List[dict]): A list of dictionaries representing thought information.
            current_node (ThoughtNode): The current node under which new thoughts will be added.

        Returns:
            List[ThoughtNode]: A list of ThoughtNode instances representing the updated tree nodes.
        """
        nodes = []
        for node_info in thought:
            node = ThoughtNode(
                name=node_info["node_state_instruction"], parent=current_node, id=int(node_info["node_id"])
            )
            nodes.append(node)
        return nodes

    def parse_node_path(self, node) -> List[str]:
        """
        Parse and retrieve the hierarchical path of the given thought node.

        This method traverses the parent nodes of the provided 'node' and constructs
        the full path from the root node to the given node.

        Args:
            node: The thought node for which the hierarchical path needs to be parsed.

        Returns:
            List[str]: A list representing the full hierarchical path of the given thought node.
                The list is ordered from the root node to the provided node.
        """
        full_node_path = []
        while node is not None:
            full_node_path.append(node.name)
            node = node.parent
        full_node_path.reverse()
        return full_node_path

    def show(self) -> None:
        """Print the updated tree."""
        print("\nUpdated Tree:")
        for pre, _, node in self:
            print(f"{pre}{node.name}, value: {node.value}, valid_status: {node.valid_status}")
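A small usage sketch for `ThoughtNode`/`ThoughtTree` (illustrative only; assumes `anytree` is installed and the module above is importable):

```python
from metagpt.strategy.base import ThoughtNode, ThoughtTree

root = ThoughtNode("solve 24 with 4 4 6 8")
tree = ThoughtTree(root)
children = tree.update_node(
    [
        {"node_id": "1", "node_state_instruction": "4 + 4 = 8"},
        {"node_id": "2", "node_state_instruction": "8 - 6 = 2"},
    ],
    current_node=root,
)
children[0].update_value(1)
print(tree.parse_node_path(children[0]))  # ['solve 24 with 4 4 6 8', '4 + 4 = 8']
tree.show()
```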
277 metagpt/strategy/tot.py Normal file

@@ -0,0 +1,277 @@
# -*- coding: utf-8 -*-
# @Date    : 12/23/2023 4:51 PM
# @Author  : stellahong (stellahong@fuzhi.ai)
# @Desc    :
from __future__ import annotations

import asyncio
from typing import Any, List, Optional

from pydantic import BaseModel, ConfigDict, Field

from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.strategy.base import ThoughtNode, ThoughtTree
from metagpt.strategy.tot_schema import MethodSelect, Strategy, ThoughtSolverConfig
from metagpt.utils.common import CodeParser

OUTPUT_FORMAT = """
Each output should be strictly a list of nodes, in json format, like this:
```json
    [
        {
            "node_id": str = "unique identifier for a solution, can be an ordinal",
            "node_state_instruction": "specified sample of solution",
        },
        ...
    ]
```
"""


class ThoughtSolverBase(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    thought_tree: Optional[ThoughtTree] = Field(default=None)
    llm: BaseLLM = Field(default_factory=LLM, exclude=True)
    config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self.llm.use_system_prompt = False

    async def solve(self, init_prompt):
        """
        Solve method for subclasses to implement.
        """
        raise NotImplementedError("Subclasses must implement the solve method")

    async def generate_thoughts(self, current_state="", current_node=None) -> List[ThoughtNode]:
        """
        Generate children thoughts based on the current state.

        Args:
            current_state (str): The current state for which thoughts are generated.
            current_node (ThoughtNode): The current node in the thought tree.

        Returns:
            List[ThoughtNode]: List of nodes representing the generated thoughts.
        """
        state_prompt = self.config.parser.propose(
            current_state=current_state, **{"n_generate_sample": self.config.n_generate_sample}
        )
        rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT)
        thoughts = CodeParser.parse_code(block="", text=rsp)
        thoughts = eval(thoughts)
        # FIXME: guard against the LLM ignoring the instruction and generating too many nodes
        # valid_thoughts = [_node for idx, _node in enumerate(thoughts) if idx < self.n_generate_sample]
        return self.thought_tree.update_node(thoughts, current_node=current_node)
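One caveat in `generate_thoughts` above: the parsed block is materialized with `eval`, which will execute arbitrary Python returned by the model. Since OUTPUT_FORMAT already demands JSON, a stricter parse is possible; a sketch (not what the file does, just an alternative):

```python
import json


def parse_thoughts(raw_block: str) -> list:
    """Parse the JSON node list returned by the LLM without eval()."""
    thoughts = json.loads(raw_block)
    if not isinstance(thoughts, list):
        raise ValueError(f"expected a JSON list of nodes, got {type(thoughts).__name__}")
    return thoughts
```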
    async def evaluate_node(self, node, parent_value) -> None:
        """
        Evaluate a node and update its status and value.

        Args:
            node (ThoughtNode): The node to be evaluated.
            parent_value (float): The parent node's value.

        Returns:
            None
        """
        eval_prompt = self.config.parser.value(input=node.name, **{"node_id": node.id})
        evaluation = await self.llm.aask(msg=eval_prompt)

        value = self.config.evaluator(evaluation, **{"node_id": node.id})
        status = self.config.evaluator.status_verify(value)

        node.update_valid_status(status=status)
        # accumulate the score along the path
        node.update_value(parent_value + value)

    def select_nodes(self, thought_nodes: List[ThoughtNode]) -> List[ThoughtNode]:
        """
        Select nodes based on the configured selection method.

        Args:
            thought_nodes (List[ThoughtNode]): List of nodes to be selected.

        Returns:
            List[ThoughtNode]: List of selected nodes.
        """
        # nodes to be selected
        nodes = []
        if self.config.method_select == MethodSelect.SAMPLE:
            raise NotImplementedError
        elif self.config.method_select == MethodSelect.GREEDY:
            nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample]
        for node in thought_nodes:
            if node not in nodes:
                node.parent = None  # prune the node from the tree
        return nodes

    def update_solution(self):
        """
        Select the result with the highest score.

        Returns:
            - List[ThoughtNode]: List of nodes representing the best solution.
            - List[str]: List of node names forming the best solution path.
        """
        best_node = max(self.thought_tree.all_nodes, key=lambda x: x.value, default=None)
        best_solution_path = self.thought_tree.parse_node_path(best_node)
        return [best_node], best_solution_path


class BFSSolver(ThoughtSolverBase):
    async def solve(self, init_prompt=""):
        """
        Solve the problem using Breadth-First Search (BFS) strategy.

        Args:
            init_prompt (str): The initial prompt for the solver.

        Returns:
            List[str]: The best solution path obtained through BFS.
        """
        root = ThoughtNode(init_prompt)
        self.thought_tree = ThoughtTree(root)
        current_nodes = [root]
        for step in range(self.config.max_steps):
            solutions = await self._bfs_build(current_nodes)

            selected_nodes = self.select_nodes(solutions)
            current_nodes = selected_nodes

            self.thought_tree.show()

        best_solution, best_solution_path = self.update_solution()
        logger.info(f"best solution is: {best_solution_path}")
        return best_solution_path

    async def _bfs_build(self, current_nodes):
        """
        Build the thought tree using Breadth-First Search (BFS) strategy.

        Args:
            current_nodes (List[ThoughtNode]): Current nodes to expand.

        Returns:
            List[ThoughtNode]: The solutions obtained after expanding the current nodes.
        """
        tasks = []
        for node in current_nodes:
            current_state = self.config.parser(node.name)
            current_value = node.value
            tasks.append(self.generate_and_evaluate_nodes(current_state, current_value, node))

        thought_nodes_list = await asyncio.gather(*tasks)
        solutions = [child_node for thought_nodes in thought_nodes_list for child_node in thought_nodes]
        return solutions

    async def generate_and_evaluate_nodes(self, current_state, current_value, node):
        thought_nodes = await self.generate_thoughts(current_state, current_node=node)
        await asyncio.gather(
            *(self.evaluate_node(child_node, parent_value=current_value) for child_node in thought_nodes)
        )
        return thought_nodes


class DFSSolver(ThoughtSolverBase):
    async def _dfs(self, root_node):
        """
        Perform Depth-First Search (DFS) on the thought tree.

        Args:
            root_node (ThoughtNode): The root node of the thought tree.

        Returns:
            List[str]: The solution path obtained through DFS.
        """
        impossible_state_cnt = 0
        node = root_node
        for step in range(self.config.max_steps):
            current_state = self.config.parser(node.name)
            current_value = node.value
            thought_nodes = await self.generate_thoughts(current_state, current_node=node)
            await self.evaluate_node(thought_nodes[0], parent_value=current_value)
            if thought_nodes[0].valid_status is False:
                impossible_state_cnt += 1
            if impossible_state_cnt >= 2:
                logger.info("impossible state reached, break")
                break
            node = thought_nodes[0]
        _solution_path = self.thought_tree.parse_node_path(node)
        self.thought_tree.show()

        return _solution_path

    async def solve(self, init_prompt="", root=ThoughtNode("")):
        """
        Solve the problem using Depth-First Search (DFS) strategy.

        Args:
            init_prompt (str): The initial prompt for the solver.

        Returns:
            List[str]: The best solution path obtained through DFS.
        """
        root = ThoughtNode(init_prompt)
        self.thought_tree = ThoughtTree(root)
        for n in range(self.config.n_solution_sample):
            # FIXME: needs backtracking: when the current node turns out invalid, fall back to its parent and grow a new node to keep exploring
            await self._dfs(root)

        best_solution, best_solution_path = self.update_solution()
        logger.info(f"best solution is: {best_solution_path}")
        return best_solution_path


class MCTSSolver(ThoughtSolverBase):
    async def solve(self, init_prompt=""):
        raise NotImplementedError


class TreeofThought(BaseModel):
    config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
    solver: ThoughtSolverBase = Field(default_factory=ThoughtSolverBase)
    strategy: Strategy = Field(default=Strategy.BFS)

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._initialize_solver(self.strategy)

    def _initialize_solver(self, strategy):
        """
        Initialize the solver based on the chosen strategy.

        Args:
            strategy (Strategy): The strategy to use for solving.

        Returns:
            ThoughtSolverBase: An instance of the appropriate solver.
        """
        if strategy == Strategy.BFS:
            self.solver = BFSSolver(config=self.config)
        elif strategy == Strategy.DFS:
            self.solver = DFSSolver(config=self.config)
        elif strategy == Strategy.MCTS:
            self.solver = MCTSSolver(config=self.config)
        else:
            raise NotImplementedError(f"Invalid strategy: {strategy}, only BFS/DFS/MCTS are supported currently!")

    async def solve(self, init_prompt=""):
        """
        Solve the problem using the specified strategy.

        Args:
            init_prompt (str): The initial prompt for the solver.

        Returns:
            Any: The solution obtained using the selected strategy.
        """
        await self.solver.solve(init_prompt)
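An end-to-end sketch of driving the solver (illustrative only; `Game24Parser` and `Game24Evaluator` are hypothetical `BaseParser`/`BaseEvaluator` subclasses you would write for your task, nothing in the file above prescribes them):

```python
import asyncio

from metagpt.strategy.tot import TreeofThought
from metagpt.strategy.tot_schema import Strategy, ThoughtSolverConfig

config = ThoughtSolverConfig(
    n_generate_sample=3,
    parser=Game24Parser(),        # hypothetical BaseParser subclass
    evaluator=Game24Evaluator(),  # hypothetical BaseEvaluator subclass
)
tot = TreeofThought(strategy=Strategy.BFS, config=config)
asyncio.run(tot.solve(init_prompt="4 5 6 10"))
```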
30 metagpt/strategy/tot_schema.py Normal file

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# @Date    : 12/25/2023 9:14 PM
# @Author  : stellahong (stellahong@fuzhi.ai)
# @Desc    :
from enum import Enum

from pydantic import BaseModel, Field

from metagpt.strategy.base import BaseEvaluator, BaseParser


class MethodSelect(Enum):
    SAMPLE = "sample"
    GREEDY = "greedy"


class Strategy(Enum):
    BFS = "BFS"
    DFS = "DFS"
    MCTS = "MCTS"


class ThoughtSolverConfig(BaseModel):
    max_steps: int = 3
    method_select: str = MethodSelect.GREEDY  # ["sample"/"greedy"]
    n_generate_sample: int = 5  # per node
    n_select_sample: int = 3  # per path
    n_solution_sample: int = 5  # only for dfs
    parser: BaseParser = Field(default_factory=BaseParser)
    evaluator: BaseEvaluator = Field(default_factory=BaseEvaluator)

@@ -1,7 +1,7 @@
 import asyncio
 from typing import AsyncGenerator, Awaitable, Callable

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field

 from metagpt.logs import logger
 from metagpt.roles import Role

@@ -33,10 +33,9 @@ class SubscriptionRunner(BaseModel):
     >>> asyncio.run(main())
     """

-    tasks: dict[Role, asyncio.Task] = Field(default_factory=dict)
+    model_config = ConfigDict(arbitrary_types_allowed=True)

-    class Config:
-        arbitrary_types_allowed = True
+    tasks: dict[Role, asyncio.Task] = Field(default_factory=dict)

     async def subscribe(
         self,


@@ -10,8 +10,9 @@

 import warnings
 from pathlib import Path
+from typing import Any

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field

 from metagpt.actions import UserRequirement
 from metagpt.config import CONFIG

@@ -34,32 +35,27 @@ class Team(BaseModel):
     dedicated to env any multi-agent activity, such as collaboratively writing executable code.
     """

+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     env: Environment = Field(default_factory=Environment)
     investment: float = Field(default=10.0)
     idea: str = Field(default="")

-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        if "roles" in kwargs:
-            self.hire(kwargs["roles"])
-        if "env_desc" in kwargs:
-            self.env.desc = kwargs["env_desc"]
-
-    class Config:
-        arbitrary_types_allowed = True
+    def __init__(self, **data: Any):
+        super(Team, self).__init__(**data)
+        if "roles" in data:
+            self.hire(data["roles"])
+        if "env_desc" in data:
+            self.env.desc = data["env_desc"]

     def serialize(self, stg_path: Path = None):
         stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path

         team_info_path = stg_path.joinpath("team_info.json")
-        write_json_file(team_info_path, self.dict(exclude={"env": True}))
+        write_json_file(team_info_path, self.model_dump(exclude={"env": True}))

         self.env.serialize(stg_path.joinpath("environment"))  # save environment alone

     @classmethod
     def recover(cls, stg_path: Path) -> "Team":
         return cls.deserialize(stg_path)

     @classmethod
     def deserialize(cls, stg_path: Path) -> "Team":
         """stg_path = ./storage/team"""
@@ -76,7 +72,6 @@ class Team(BaseModel):
         # recover environment
         environment = Environment.deserialize(stg_path=stg_path.joinpath("environment"))
         team_info.update({"env": environment})
-
         team = Team(**team_info)
         return team

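A sketch of the serialize/recover cycle these methods enable (illustrative only; assumes a configured MetaGPT environment, and the storage path is arbitrary):

```python
from pathlib import Path

from metagpt.roles import ProductManager
from metagpt.team import Team

company = Team(roles=[ProductManager()])
stg_path = Path("./storage/team")
company.serialize(stg_path)        # writes team_info.json plus environment/
restored = Team.recover(stg_path)  # rebuilds the Team and its Environment
```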
@@ -90,9 +85,12 @@ class Team(BaseModel):
         CONFIG.max_budget = investment
         logger.info(f"Investment: ${investment}.")

-    def _check_balance(self):
-        if CONFIG.total_cost > CONFIG.max_budget:
-            raise NoMoneyException(CONFIG.total_cost, f"Insufficient funds: {CONFIG.max_budget}")
+    @staticmethod
+    def _check_balance():
+        if CONFIG.cost_manager.total_cost > CONFIG.cost_manager.max_budget:
+            raise NoMoneyException(
+                CONFIG.cost_manager.total_cost, f"Insufficient funds: {CONFIG.cost_manager.max_budget}"
+            )

     def run_project(self, idea, send_to: str = ""):
         """Run a project from publishing user requirement."""

@@ -100,7 +98,8 @@ class Team(BaseModel):

         # Human requirement.
         self.env.publish_message(
-            Message(role="Human", content=idea, cause_by=UserRequirement, send_to=send_to or MESSAGE_ROUTE_TO_ALL)
+            Message(role="Human", content=idea, cause_by=UserRequirement, send_to=send_to or MESSAGE_ROUTE_TO_ALL),
+            peekable=False,
         )

     def start_project(self, idea, send_to: str = ""):

@@ -117,10 +116,10 @@ class Team(BaseModel):
         return self.run_project(idea=idea, send_to=send_to)

     def _save(self):
-        logger.info(self.json(ensure_ascii=False))
+        logger.info(self.model_dump_json())

     @serialize_decorator
-    async def run(self, n_round=3, idea="", send_to=""):
+    async def run(self, n_round=3, idea="", send_to="", auto_archive=True):
         """Run company until target round or no money"""
         if idea:
             self.run_project(idea=idea, send_to=send_to)

@@ -132,6 +131,5 @@ class Team(BaseModel):
         self._check_balance()

         await self.env.run()
-        if CONFIG.git_repo:
-            CONFIG.git_repo.archive()
+        self.env.archive(auto_archive)
         return self.env.history


@@ -22,3 +22,8 @@ class WebBrowserEngineType(Enum):
     PLAYWRIGHT = "playwright"
     SELENIUM = "selenium"
     CUSTOM = "custom"
+
+    @classmethod
+    def __missing__(cls, key):
+        """Default type conversion"""
+        return cls.CUSTOM

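A side note on this hook: for `Enum`, the fallback Python actually invokes on a failed value lookup is `_missing_` (single underscores), so the `__missing__` method added above may never fire. A minimal sketch of the single-underscore variant:

```python
from enum import Enum


class EngineType(Enum):
    PLAYWRIGHT = "playwright"
    CUSTOM = "custom"

    @classmethod
    def _missing_(cls, value):
        # called by EngineType("whatever") when no member matches
        return cls.CUSTOM


assert EngineType("selenium-grid") is EngineType.CUSTOM
```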

@@ -4,39 +4,102 @@
 @Time    : 2023/6/9 22:22
 @Author  : Leo Xiao
 @File    : azure_tts.py
+@Modified by: mashenquan, 2023/8/17. Azure TTS OAS3 api, which provides text-to-speech functionality
 """
+import base64
+from pathlib import Path
+from uuid import uuid4
+
+import aiofiles
 from azure.cognitiveservices.speech import AudioConfig, SpeechConfig, SpeechSynthesizer

 from metagpt.config import CONFIG
+from metagpt.logs import logger


 class AzureTTS:
-    """https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles"""
+    """Azure Text-to-Speech"""

-    @classmethod
-    def synthesize_speech(cls, lang, voice, role, text, output_file):
-        subscription_key = CONFIG.get("AZURE_TTS_SUBSCRIPTION_KEY")
-        region = CONFIG.get("AZURE_TTS_REGION")
-        speech_config = SpeechConfig(subscription=subscription_key, region=region)
+    def __init__(self, subscription_key, region):
+        """
+        :param subscription_key: key is used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint`
+        :param region: This is the location (or region) of your resource. You may need to use this field when making calls to this API.
+        """
+        self.subscription_key = subscription_key if subscription_key else CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
+        self.region = region if region else CONFIG.AZURE_TTS_REGION
+
+    # Parameter reference: https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles
+    async def synthesize_speech(self, lang, voice, text, output_file):
+        speech_config = SpeechConfig(subscription=self.subscription_key, region=self.region)
         speech_config.speech_synthesis_voice_name = voice
         audio_config = AudioConfig(filename=output_file)
         synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=audio_config)

-        # if voice=="zh-CN-YunxiNeural":
-        ssml_string = f"""
-        <speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='{lang}' xmlns:mstts='http://www.w3.org/2001/mstts'>
-            <voice name='{voice}'>
-                <mstts:express-as style='affectionate' role='{role}'>
-                    {text}
-                </mstts:express-as>
-            </voice>
-        </speak>
-        """
+        # More detail: https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-synthesis-markup-voice
+        ssml_string = (
+            "<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' "
+            f"xml:lang='{lang}' xmlns:mstts='http://www.w3.org/2001/mstts'>"
+            f"<voice name='{voice}'>{text}</voice></speak>"
+        )

-        synthesizer.speak_ssml_async(ssml_string).get()
+        return synthesizer.speak_ssml_async(ssml_string).get()

+    @staticmethod
+    def role_style_text(role, style, text):
+        return f'<mstts:express-as role="{role}" style="{style}">{text}</mstts:express-as>'
+
+    @staticmethod
+    def role_text(role, text):
+        return f'<mstts:express-as role="{role}">{text}</mstts:express-as>'
+
+    @staticmethod
+    def style_text(style, text):
+        return f'<mstts:express-as style="{style}">{text}</mstts:express-as>'


-if __name__ == "__main__":
-    azure_tts = AzureTTS()
-    azure_tts.synthesize_speech("zh-CN", "zh-CN-YunxiNeural", "Boy", "Hello, I am Kaka", "output.wav")
+# Export
+async def oas3_azsure_tts(text, lang="", voice="", style="", role="", subscription_key="", region=""):
+    """Text to speech
+    For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
+
+    :param lang: The value can contain a language code such as en (English), or a locale such as en-US (English - United States). For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
+    :param voice: For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`, `https://speech.microsoft.com/portal/voicegallery`
+    :param style: Speaking style to express different emotions like cheerfulness, empathy, and calm. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
+    :param role: With roles, the same voice can act as a different age and gender. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
+    :param text: The text used for voice conversion.
+    :param subscription_key: key is used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint`
+    :param region: This is the location (or region) of your resource. You may need to use this field when making calls to this API.
+    :return: Returns the Base64-encoded .wav file data if successful, otherwise an empty string.
+
+    """
+    if not text:
+        return ""
+
+    if not lang:
+        lang = "zh-CN"
+    if not voice:
+        voice = "zh-CN-XiaomoNeural"
+    if not role:
+        role = "Girl"
+    if not style:
+        style = "affectionate"
+    if not subscription_key:
+        subscription_key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
+    if not region:
+        region = CONFIG.AZURE_TTS_REGION
+
+    xml_value = AzureTTS.role_style_text(role=role, style=style, text=text)
+    tts = AzureTTS(subscription_key=subscription_key, region=region)
+    filename = Path(__file__).resolve().parent / (str(uuid4()).replace("-", "") + ".wav")
+    try:
+        await tts.synthesize_speech(lang=lang, voice=voice, text=xml_value, output_file=str(filename))
+        async with aiofiles.open(filename, mode="rb") as reader:
+            data = await reader.read()
+            base64_string = base64.b64encode(data).decode("utf-8")
+    except Exception as e:
+        logger.error(f"text:{text}, error:{e}")
+        return ""
+    finally:
+        filename.unlink(missing_ok=True)
+
+    return base64_string

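A usage sketch for the exported helper (illustrative only; assumes valid `AZURE_TTS_*` credentials in the MetaGPT config, and `en-US-JennyNeural` is just one example voice):

```python
import asyncio
import base64

from metagpt.tools.azure_tts import oas3_azsure_tts

b64_wav = asyncio.run(oas3_azsure_tts("Hello, I am Kaka", lang="en-US", voice="en-US-JennyNeural"))
if b64_wav:
    with open("output.wav", "wb") as f:
        f.write(base64.b64decode(b64_wav))
```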

@@ -1,197 +0,0 @@
import inspect
import re
import textwrap
from pathlib import Path
from typing import Callable, Dict, List

import wrapt
from interpreter.core.core import Interpreter

from metagpt.actions.clone_function import (
    CloneFunction,
    run_function_code,
    run_function_script,
)
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.utils.highlight import highlight


def extract_python_code(code: str):
    """Extract code blocks: If the code comments are the same, only the last code block is kept."""
    # Use regular expressions to match comment blocks and related code.
    pattern = r"(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)"
    matches = re.findall(pattern, code, re.DOTALL)

    # Extract the last code block when encountering the same comment.
    unique_comments = {}
    for comment, code_block in matches:
        unique_comments[comment] = code_block

    # concatenate into functional form
    result_code = "\n".join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()])
    header_code = code[: code.find("#")]
    code = header_code + result_code

    logger.info(f"Extract python code: \n {highlight(code)}")

    return code

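To make the deduplication concrete, a tiny worked example (illustrative input; the function keeps only the last code block per repeated comment, plus everything before the first comment):

```python
code = (
    "import pandas as pd\n"
    "# load data\n"
    "df = pd.read_csv('a.csv')\n"
    "# load data\n"
    "df = pd.read_csv('b.csv')\n"
)
print(extract_python_code(code))
# import pandas as pd
# # load data
# df = pd.read_csv('b.csv')   <- only the last "# load data" block survives
```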

class OpenCodeInterpreter(object):
    """https://github.com/KillianLucas/open-interpreter"""

    def __init__(self, auto_run: bool = True) -> None:
        interpreter = Interpreter()
        interpreter.auto_run = auto_run
        interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo"
        interpreter.api_key = CONFIG.openai_api_key
        self.interpreter = interpreter

    def chat(self, query: str, reset: bool = True):
        if reset:
            self.interpreter.reset()
        return self.interpreter.chat(query)

    @staticmethod
    def extract_function(
        query_respond: List, function_name: str, *, language: str = "python", function_format: str = None
    ) -> str:
        """create a function from query_respond."""
        if language not in ("python"):
            raise NotImplementedError(f"Not support to parse language {language}!")

        # set function form
        if function_format is None:
            assert language == "python", f"Expect python language for default function_format, but got {language}."
            function_format = """def {function_name}():\n{code}"""
        # Extract the code module in the open-interpreter respond message.
        # The query_respond of open-interpreter before v0.1.4 is:
        #   [{'role': 'user', 'content': your query string},
        #    {'role': 'assistant', 'content': plan from llm, 'function_call': {
        #       "name": "run_code", "arguments": "{"language": "python", "code": code of first plan},
        #       "parsed_arguments": {"language": "python", "code": code of first plan}
        #    ...]
        if "function_call" in query_respond[1]:
            code = [
                item["function_call"]["parsed_arguments"]["code"]
                for item in query_respond
                if "function_call" in item
                and "parsed_arguments" in item["function_call"]
                and "language" in item["function_call"]["parsed_arguments"]
                and item["function_call"]["parsed_arguments"]["language"] == language
            ]
        # The query_respond of open-interpreter v0.1.7 is:
        #   [{'role': 'user', 'message': your query string},
        #    {'role': 'assistant', 'message': plan from llm, 'language': 'python',
        #     'code': code of first plan, 'output': output of first plan code},
        #    ...]
        elif "code" in query_respond[1]:
            code = [
                item["code"]
                for item in query_respond
                if "code" in item and "language" in item and item["language"] == language
            ]
        else:
            raise ValueError(f"Unexpected message format in query_respond: {query_respond[1].keys()}")
        # add indent.
        indented_code_str = textwrap.indent("\n".join(code), " " * 4)
        # Return the code after deduplication.
        if language == "python":
            return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str))


def gen_query(func: Callable, args, kwargs) -> str:
    # Get the annotation of the function as part of the query.
    desc = func.__doc__
    signature = inspect.signature(func)
    # Get the signature of the wrapped function and the assignment of the input parameters as part of the query.
    bound_args = signature.bind(*args, **kwargs)
    bound_args.apply_defaults()
    query = f"{desc}, {bound_args.arguments}, If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ..."
    return query


def gen_template_fun(func: Callable) -> str:
    return f"def {func.__name__}{str(inspect.signature(func))}\n    # here is your code ..."


class OpenInterpreterDecorator(object):
    def __init__(self, save_code: bool = False, code_file_path: str = None, clear_code: bool = False) -> None:
        self.save_code = save_code
        self.code_file_path = code_file_path
        self.clear_code = clear_code

    def _have_code(self, rsp: List[Dict]):
        # Is there any code generated?
        return "code" in rsp[1] and rsp[1]["code"] not in ("", None)

    def _is_faild_plan(self, rsp: List[Dict]):
        # Is it a failed plan?
        func_code = OpenCodeInterpreter.extract_function(rsp, "function")
        # If there is no more than 1 '\n', the plan execution fails.
        if isinstance(func_code, str) and func_code.count("\n") <= 1:
            return True
        return False

    def _check_respond(self, query: str, interpreter: OpenCodeInterpreter, respond: List[Dict], max_try: int = 3):
        for _ in range(max_try):
            # TODO: If no code or a failed plan is generated, execute chat again, repeating no more than max_try times.
            if self._have_code(respond) and not self._is_faild_plan(respond):
                break
            elif not self._have_code(respond):
                logger.warning(f"llm did not return executable code, resend the query: \n{query}")
                respond = interpreter.chat(query)
            elif self._is_faild_plan(respond):
                logger.warning(f"llm did not generate successful plan, resend the query: \n{query}")
                respond = interpreter.chat(query)

        # Post-processing of respond
        if not self._have_code(respond):
            error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        if self._is_faild_plan(respond):
            error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}"
            logger.error(error_msg)
            raise ValueError(error_msg)
        return respond

    def __call__(self, wrapped):
        @wrapt.decorator
        async def wrapper(wrapped: Callable, instance, args, kwargs):
            # Get the decorated function name.
            func_name = wrapped.__name__
            # If the script exists locally and clear_code is not required, execute the function from the script.
            if self.code_file_path and Path(self.code_file_path).is_file() and not self.clear_code:
                return run_function_script(self.code_file_path, func_name, *args, **kwargs)

            # Auto run generate code by using open-interpreter.
            interpreter = OpenCodeInterpreter()
            query = gen_query(wrapped, args, kwargs)
            logger.info(f"query for OpenCodeInterpreter: \n {query}")
            respond = interpreter.chat(query)
            # Make sure the response is as expected.
            respond = self._check_respond(query, interpreter, respond, 3)
            # Assemble the code blocks generated by open-interpreter into a function without parameters.
            func_code = interpreter.extract_function(respond, func_name)
            # Clone the `func_code` into wrapped, that is,
            # keep the `func_code` and wrapped functions with the same input parameter and return value types.
            template_func = gen_template_fun(wrapped)
            cf = CloneFunction()
            code = await cf.run(template_func=template_func, source_code=func_code)
            # Display the generated function in the terminal.
            logger_code = highlight(code, "python")
            logger.info(f"Creating following Python function:\n{logger_code}")
            # execute this function.
            try:
                res = run_function_code(code, func_name, *args, **kwargs)
                if self.save_code and self.code_file_path:
                    cf._save(self.code_file_path, code)
            except Exception as e:
                logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}")
                raise Exception("Could not evaluate Python code", e)
            return res

        return wrapper(wrapped)
152 metagpt/tools/iflytek_tts.py Normal file

@@ -0,0 +1,152 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/8/17
@Author  : mashenquan
@File    : iflytek_tts.py
@Desc    : iFLYTEK TTS OAS3 api, which provides text-to-speech functionality
"""
import base64
import hashlib
import hmac
import json
import uuid
from datetime import datetime
from enum import Enum
from pathlib import Path
from time import mktime
from typing import Optional
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import aiofiles
import websockets as websockets
from pydantic import BaseModel

from metagpt.config import CONFIG
from metagpt.logs import logger


class IFlyTekTTSStatus(Enum):
    STATUS_FIRST_FRAME = 0  # The first frame
    STATUS_CONTINUE_FRAME = 1  # The intermediate frame
    STATUS_LAST_FRAME = 2  # The last frame


class AudioData(BaseModel):
    audio: str
    status: int
    ced: str


class IFlyTekTTSResponse(BaseModel):
    code: int
    message: str
    data: Optional[AudioData] = None
    sid: str


DEFAULT_IFLYTEK_VOICE = "xiaoyan"


class IFlyTekTTS(object):
    def __init__(self, app_id: str, api_key: str, api_secret: str):
        """
        :param app_id: Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`
        :param api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
        :param api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
        """
        self.app_id = app_id or CONFIG.IFLYTEK_APP_ID
        self.api_key = api_key or CONFIG.IFLYTEK_API_KEY
        self.api_secret = api_secret or CONFIG.IFLYTEK_API_SECRET

    async def synthesize_speech(self, text, output_file: str, voice=DEFAULT_IFLYTEK_VOICE):
        url = self._create_url()
        data = {
            "common": {"app_id": self.app_id},
            "business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": voice, "tte": "utf8"},
            "data": {"status": 2, "text": str(base64.b64encode(text.encode("utf-8")), "UTF8")},
        }
        req = json.dumps(data)
        async with websockets.connect(url) as websocket:
            # send request
            await websocket.send(req)

            # receive frames
            async with aiofiles.open(str(output_file), "wb") as writer:
                while True:
                    v = await websocket.recv()
                    rsp = IFlyTekTTSResponse(**json.loads(v))
                    if rsp.data:
                        binary_data = base64.b64decode(rsp.data.audio)
                        await writer.write(binary_data)
                        if rsp.data.status != IFlyTekTTSStatus.STATUS_LAST_FRAME.value:
                            continue
                    break

    def _create_url(self):
        """Create request url"""
        url = "wss://tts-api.xfyun.cn/v2/tts"
        # Generate a timestamp in RFC1123 format
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
        # Perform HMAC-SHA256 encryption
        signature_sha = hmac.new(
            self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256
        ).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")

        authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (
            self.api_key,
            "hmac-sha256",
            "host date request-line",
            signature_sha,
        )
        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
        # Combine the authentication parameters of the request into a dictionary.
        v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
        # Concatenate the authentication parameters to generate the URL.
        url = url + "?" + urlencode(v)
        return url


# Export
async def oas3_iflytek_tts(text: str, voice: str = "", app_id: str = "", api_key: str = "", api_secret: str = ""):
    """Text to speech
    For more details, check out:`https://www.xfyun.cn/doc/tts/online_tts/API.html`

    :param voice: Default `xiaoyan`. For more details, checkout: `https://www.xfyun.cn/doc/tts/online_tts/API.html#%E6%8E%A5%E5%8F%A3%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B`
    :param text: The text used for voice conversion.
    :param app_id: Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`
    :param api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
    :param api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
    :return: Returns the Base64-encoded .mp3 file data if successful, otherwise an empty string.

    """
    if not app_id:
        app_id = CONFIG.IFLYTEK_APP_ID
    if not api_key:
        api_key = CONFIG.IFLYTEK_API_KEY
    if not api_secret:
        api_secret = CONFIG.IFLYTEK_API_SECRET
    if not voice:
        voice = CONFIG.IFLYTEK_VOICE or DEFAULT_IFLYTEK_VOICE

    filename = Path(__file__).parent / (uuid.uuid4().hex + ".mp3")
    try:
        tts = IFlyTekTTS(app_id=app_id, api_key=api_key, api_secret=api_secret)
        await tts.synthesize_speech(text=text, output_file=str(filename), voice=voice)
        async with aiofiles.open(str(filename), mode="rb") as reader:
            data = await reader.read()
            base64_string = base64.b64encode(data).decode("utf-8")
    except Exception as e:
        logger.error(f"text:{text}, error:{e}")
        base64_string = ""
    finally:
        filename.unlink(missing_ok=True)

    return base64_string
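A usage sketch mirroring the Azure helper (illustrative only; assumes valid `IFLYTEK_*` credentials in the MetaGPT config):

```python
import asyncio
import base64

from metagpt.tools.iflytek_tts import oas3_iflytek_tts

b64_mp3 = asyncio.run(oas3_iflytek_tts(text="Hello, world"))
if b64_mp3:
    with open("output.mp3", "wb") as f:
        f.write(base64.b64decode(b64_mp3))
```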
26 metagpt/tools/metagpt_oas3_api_svc.py Normal file

@@ -0,0 +1,26 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/8/17
@Author  : mashenquan
@File    : metagpt_oas3_api_svc.py
@Desc    : MetaGPT OpenAPI Specification 3.0 REST API service
"""

from pathlib import Path

import connexion


def oas_http_svc():
    """Start the OAS 3.0 OpenAPI HTTP service"""
    print("http://localhost:8080/oas3/ui/")
    specification_dir = Path(__file__).parent.parent.parent / ".well-known"
    app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir))
    app.add_api("metagpt_oas3_api.yaml")
    app.add_api("openapi.yaml")
    app.run(port=8080)


if __name__ == "__main__":
    oas_http_svc()
Some files were not shown because too many files have changed in this diff.